| import sys | |
| sys.path.append("..") | |
| import os | |
| import argparse | |
| from torch.utils.data import random_split | |
| from src.datamodule import VLSP2020TarDataset, VLSP2020Dataset | |
def prepare_tar_dataset(
    data_dir: str,
    dest_dir: str,
    train_size: int = 42_000,
    seed: "int | None" = None,
):
    """Randomly split the VLSP2020 dataset into train/val subsets and write each as a tar file.

    ``train_size`` randomly chosen samples are written to
    ``<dest_dir>/vlsp2020_train_set.tar``; every remaining sample is written to
    ``<dest_dir>/vlsp2020_val_set.tar``.

    Args:
        data_dir: Directory passed to ``VLSP2020Dataset``.
        dest_dir: Output directory for the two tar files; created if missing.
        train_size: Number of samples in the training split. The validation
            split gets the remainder. The validation size was previously
            hard-coded as 14_427, so the split only worked for a dataset of
            exactly 56_427 samples.
        seed: If given, seeds the generator used by ``random_split`` so the
            partition is reproducible across runs. ``None`` (the default)
            keeps torch's default global RNG, as before.

    Raises:
        ValueError: If ``train_size`` is negative or larger than the dataset.
    """
    dts = VLSP2020Dataset(data_dir)
    total = len(dts)
    if not 0 <= train_size <= total:
        raise ValueError(
            f"train_size={train_size} is out of range for a dataset of {total} samples"
        )
    # Derive the validation size instead of hard-coding it so that
    # sum(lengths) == len(dataset) always holds, as random_split requires.
    val_size = total - train_size

    split_kwargs = {}
    if seed is not None:
        import torch  # local import: only needed to build a seeded generator

        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dts, [train_size, val_size], **split_kwargs)

    os.makedirs(dest_dir, exist_ok=True)
    VLSP2020TarDataset(os.path.join(dest_dir, "vlsp2020_train_set.tar")).convert(
        train_set
    )
    VLSP2020TarDataset(os.path.join(dest_dir, "vlsp2020_val_set.tar")).convert(val_set)
def main(argv=None):
    """Command-line entry point: parse arguments and build the train/val tar files.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original behavior. Passing a
            list lets the CLI be exercised without patching ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Randomly split the VLSP2020 dataset into train/val sets and "
            "write each set to a tar file."
        )
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Input directory read by VLSP2020Dataset.",
    )
    parser.add_argument(
        "--dest_dir",
        type=str,
        required=True,
        help="Output directory for vlsp2020_train_set.tar and vlsp2020_val_set.tar.",
    )
    args = parser.parse_args(argv)
    prepare_tar_dataset(args.data_dir, args.dest_dir)