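"""Generate reference outputs from the CrispEdit edit datasets for regression comparison.

Loads the per-edit-type parquet shards, interleaves them into a single dataset
with a stable ordering, runs each example through the Qwen image editing
pipeline, and writes every output dict to <outdir>/<index>.pt.

Example invocation (the script filename is illustrative):

    python generate_regression_outputs.py --imsize 512 --steps 50 \
        --indir /data/CrispEdit --outdir /data/regression_output
"""
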
import argparse
from pathlib import Path

import torch
import tqdm
from datasets import concatenate_datasets, interleave_datasets, load_dataset

from qwenimage.datamodels import QwenConfig
from qwenimage.foundation import QwenImageFoundationSaveInterm


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument("--imsize", type=int, default=512)
    parser.add_argument("--indir", type=str, default="/data/CrispEdit")
    parser.add_argument("--outdir", type=str, default="/data/regression_output")
    parser.add_argument("--steps", type=int, default=50)
    args = parser.parse_args()

    total_per = 10
    EDIT_TYPES = [
        "color",
        "style",
        "replace",
        "remove",
        "add",
        "motion change",
        "background change",
    ]
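
    # Each edit type is split across `total_per` parquet shards,
    # named e.g. color_00000.parquet ... color_00009.parquet.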
    all_edit_datasets = []
    for edit_type in EDIT_TYPES:
        to_concat = []
        for ds_n in range(total_per):
            ds = load_dataset(
                "parquet",
                data_files=f"{args.indir}/{edit_type}_{ds_n:05d}.parquet",
                split="train",
            )
            to_concat.append(ds)
        edit_type_concat = concatenate_datasets(to_concat)
        all_edit_datasets.append(edit_type_concat)

    # Interleave for a consistent ordering when indexing; also allows extending the data later.
    join_ds = interleave_datasets(all_edit_datasets)
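    # With the default round-robin strategy, example i comes from EDIT_TYPES[i % len(EDIT_TYPES)].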

    save_base_dir = Path(args.outdir)
    save_base_dir.mkdir(exist_ok=True, parents=True)

    foundation = QwenImageFoundationSaveInterm(QwenConfig(vae_image_size=args.imsize * args.imsize))
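
    # --start-index lets an interrupted run resume; original dataset indices are kept for the output filenames.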
    dataset_to_process = join_ds.select(range(args.start_index, len(join_ds)))

    for idx, input_data in enumerate(tqdm.tqdm(dataset_to_process), start=args.start_index):
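        # Both overrides mirror vae_image_size above (imsize squared), presumably a
        # pixel-area target rather than a side length.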
        output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
            image=[input_data["input_img"]],
            prompt=input_data["instruction"],
            vae_image_override=args.imsize * args.imsize,
            latent_size_override=args.imsize * args.imsize,
            num_inference_steps=args.steps,
        ))
        torch.save(output_dict, save_base_dir / f"{idx:06d}.pt")


if __name__ == "__main__":
    main()