# 2-Step-Qwen-Image-Edit / scripts / save_regression_outputs.py
# Elea Zhong
# triplet loss experiments (prelim)
# 6064267
import argparse
from pathlib import Path
import torch
import tqdm
from datasets import concatenate_datasets, load_dataset, interleave_datasets
from qwenimage.datamodels import QwenConfig
from qwenimage.foundation import QwenImageFoundationSaveInterm
def _load_edit_datasets(indir, edit_types, shards_per_type):
    """Load every parquet shard for each edit type and interleave the results.

    Args:
        indir: Directory containing ``{edit_type}_{shard:05d}.parquet`` files.
        edit_types: Edit-type names, in the order they are interleaved.
        shards_per_type: Number of consecutively numbered shards (starting at
            0) to load for each edit type.

    Returns:
        The dataset produced by ``interleave_datasets`` over one concatenated
        dataset per edit type.
    """
    per_type = []
    for edit_type in edit_types:
        shards = [
            load_dataset(
                "parquet",
                data_files=f"{indir}/{edit_type}_{shard:05d}.parquet",
                split="train",
            )
            for shard in range(shards_per_type)
        ]
        per_type.append(concatenate_datasets(shards))
    # consistent ordering for indexing, also allow extension
    return interleave_datasets(per_type)


def main() -> None:
    """Run the base pipeline over the edit dataset and save each output.

    Command-line options:
        --start-index: Row of the joined dataset to start from (default 0).
            Must be non-negative.
        --imsize: Value that is squared and passed as the VAE image size and
            latent size overrides (default 512).
        --indir: Directory holding the input parquet shards.
        --outdir: Directory that receives one ``{idx:06d}.pt`` file per row;
            created if missing.
        --steps: Number of inference steps passed to the pipeline (default 50).

    Each saved file is named after the row's index in the joined dataset, so
    a run restarted with ``--start-index`` keeps writing the same names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument("--imsize", type=int, default=512)
    parser.add_argument("--indir", type=str, default="/data/CrispEdit")
    parser.add_argument("--outdir", type=str, default="/data/regression_output")
    parser.add_argument("--steps", type=int, default=50)
    args = parser.parse_args()

    # Fail fast, before any dataset or model is loaded: a negative start would
    # put negative indices into the select() range below and a minus sign into
    # the output file names derived from the same counter.
    if args.start_index < 0:
        parser.error("--start-index must be non-negative")

    total_per = 10
    EDIT_TYPES = [
        "color",
        "style",
        "replace",
        "remove",
        "add",
        "motion change",
        "background change",
    ]
    join_ds = _load_edit_datasets(args.indir, EDIT_TYPES, total_per)

    save_base_dir = Path(args.outdir)
    save_base_dir.mkdir(exist_ok=True, parents=True)

    # The same squared value configures the model and both per-call overrides.
    pixel_area = args.imsize * args.imsize
    foundation = QwenImageFoundationSaveInterm(QwenConfig(vae_image_size=pixel_area))

    dataset_to_process = join_ds.select(range(args.start_index, len(join_ds)))

    # Start counting at start_index so idx is the row's position in join_ds,
    # not its position within the selected slice.
    for idx, input_data in enumerate(tqdm.tqdm(dataset_to_process), start=args.start_index):
        output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
            image=[input_data["input_img"]],
            prompt=input_data["instruction"],
            vae_image_override=pixel_area,
            latent_size_override=pixel_area,
            num_inference_steps=args.steps,
        ))
        torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()