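# Regression-output generator: runs the QwenImageFoundationSaveInterm pipeline over the
# CrispEdit parquet shards and saves each per-sample output dict to --outdir.
#
# Example invocation (a sketch; the script filename is an assumption, the flags are the
# ones defined in main()):
#   python generate_regression_outputs.py --indir /data/CrispEdit --outdir /data/regression_output \
#       --imsize 512 --steps 50 --start-index 0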
import argparse
from pathlib import Path

import torch
import tqdm
from datasets import concatenate_datasets, interleave_datasets, load_dataset

from qwenimage.datamodels import QwenConfig
from qwenimage.foundation import QwenImageFoundationSaveInterm
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument("--imsize", type=int, default=512)
    parser.add_argument("--indir", type=str, default="/data/CrispEdit")
    parser.add_argument("--outdir", type=str, default="/data/regression_output")
    parser.add_argument("--steps", type=int, default=50)
    args = parser.parse_args()

    # Number of parquet shards per edit type.
    total_per = 10
    EDIT_TYPES = [
        "color",
        "style",
        "replace",
        "remove",
        "add",
        "motion change",
        "background change",
    ]

    # Load and concatenate the shards for each edit type.
    all_edit_datasets = []
    for edit_type in EDIT_TYPES:
        to_concat = []
        for ds_n in range(total_per):
            ds = load_dataset(
                "parquet",
                data_files=f"{args.indir}/{edit_type}_{ds_n:05d}.parquet",
                split="train",
            )
            to_concat.append(ds)
        edit_type_concat = concatenate_datasets(to_concat)
        all_edit_datasets.append(edit_type_concat)

    # Interleave the per-type datasets: consistent ordering for indexing, also allow extension.
    join_ds = interleave_datasets(all_edit_datasets)

    save_base_dir = Path(args.outdir)
    save_base_dir.mkdir(exist_ok=True, parents=True)

    foundation = QwenImageFoundationSaveInterm(QwenConfig(vae_image_size=args.imsize * args.imsize))

    # Resume support: skip samples before --start-index while keeping the global index
    # so output filenames stay aligned with dataset positions.
    dataset_to_process = join_ds.select(range(args.start_index, len(join_ds)))
    for idx, input_data in enumerate(tqdm.tqdm(dataset_to_process), start=args.start_index):
        output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
            image=[input_data["input_img"]],
            prompt=input_data["instruction"],
            vae_image_override=args.imsize * args.imsize,
            latent_size_override=args.imsize * args.imsize,
            num_inference_steps=args.steps,
        ))
        # Save one output dict per sample, named by its global dataset index.
        torch.save(output_dict, save_base_dir / f"{idx:06d}.pt")


if __name__ == "__main__":
    main()