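"""Run the QwenImage foundation pipeline over the CrispEdit edit datasets and save
each sample's output dict to a numbered .pt file under --outdir.

Expects the input parquet files to be laid out as
{indir}/{edit_type}_{NNNNN}.parquet (see EDIT_TYPES below).

Example invocation (the script filename here is a placeholder):

    python run_regression.py --indir /data/CrispEdit --outdir /data/regression_output \
        --imsize 512 --steps 50
"""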
import argparse
from pathlib import Path

import torch
import tqdm
from datasets import concatenate_datasets, interleave_datasets, load_dataset

from qwenimage.datamodels import QwenConfig
from qwenimage.foundation import QwenImageFoundationSaveInterm


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument("--imsize", type=int, default=512)
    parser.add_argument("--indir", type=str, default="/data/CrispEdit")
    parser.add_argument("--outdir", type=str, default="/data/regression_output")
    parser.add_argument("--steps", type=int, default=50)
    args = parser.parse_args()

    total_per = 10  # number of parquet files loaded per edit type

    EDIT_TYPES = [
        "color",
        "style",
        "replace",
        "remove",
        "add",
        "motion change",
        "background change",
    ]

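    # Build one concatenated dataset per edit type from its `total_per` parquet files.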
    all_edit_datasets = []
    for edit_type in EDIT_TYPES:
        to_concat = []
        for ds_n in range(total_per):
            ds = load_dataset("parquet", data_files=f"{args.indir}/{edit_type}_{ds_n:05d}.parquet", split="train")
            to_concat.append(ds)
        edit_type_concat = concatenate_datasets(to_concat)
        all_edit_datasets.append(edit_type_concat)

    # Interleave for a consistent ordering when indexing; also allows extending with more edit types later.
    join_ds = interleave_datasets(all_edit_datasets)

    save_base_dir = Path(args.outdir)
    save_base_dir.mkdir(exist_ok=True, parents=True)

    # Foundation pipeline; vae_image_size is the target image area in pixels (imsize * imsize).
    foundation = QwenImageFoundationSaveInterm(QwenConfig(vae_image_size=args.imsize * args.imsize))

    # Skip samples before --start-index; enumerate keeps the global index so output filenames stay stable.
    dataset_to_process = join_ds.select(range(args.start_index, len(join_ds)))

    for idx, input_data in enumerate(tqdm.tqdm(dataset_to_process), start=args.start_index):

        # Run the foundation pipeline on the input image and edit instruction at the requested resolution and step count.
        output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
            image=[input_data["input_img"]],
            prompt=input_data["instruction"],
            vae_image_override=args.imsize * args.imsize,
            latent_size_override=args.imsize * args.imsize,
            num_inference_steps=args.steps,
        ))

        # One output file per sample, named by its global dataset index.
        torch.save(output_dict, save_base_dir / f"{idx:06d}.pt")


if __name__ == "__main__":
    main()