# Activation-collection script: runs SDXL-Turbo (via HookedStableDiffusionXLPipeline)
# over LAION-COCO captions and stores per-block activations as WebDataset tar shards.
| import os | |
| import sys | |
| import io | |
| import tarfile | |
| import torch | |
| import webdataset as wds | |
| import numpy as np | |
| from tqdm import tqdm | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '..')) | |
| from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline | |
| import datetime | |
| from datasets import load_dataset | |
| from torch.utils.data import DataLoader | |
| import diffusers | |
| import fire | |
def main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50):
    """Collect SDXL-Turbo attention-block activations into WebDataset shards.

    Streams captions from the LAION-COCO aesthetic dataset, runs a hooked
    SDXL-Turbo pipeline with caching enabled on a fixed set of attention
    blocks, and writes one tar shard per block (input/output residuals)
    plus one shard of generated images.

    Args:
        save_path: Base directory; a timestamped subdirectory is created in it.
        start_at: First batch index to process (earlier batches are skipped).
        finish_at: Batch index at which to stop (exclusive).
        dataset_batch_size: Number of captions per pipeline call.
    """
    blocks_to_save = [
        'unet.down_blocks.2.attentions.1',
        'unet.mid_block.attentions.0',
        'unet.up_blocks.0.attentions.0',
        'unet.up_blocks.0.attentions.1',
    ]

    # Initialization: streaming dataset (no full download) with a fixed shuffle seed.
    dataset = load_dataset("guangyil/laion-coco-aesthetic", split="train", columns=["caption"], streaming=True).shuffle(seed=42)
    pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
    pipe.to('cuda')
    pipe.set_progress_bar_config(disable=True)

    dataloader = DataLoader(dataset, batch_size=dataset_batch_size)

    # Each run writes into its own timestamped subdirectory so reruns never collide.
    ct = datetime.datetime.now()
    save_path = os.path.join(save_path, str(ct))
    os.makedirs(save_path, exist_ok=True)

    writers = {
        block: wds.TarWriter(f'{save_path}/{block}.tar') for block in blocks_to_save
    }
    writers.update({'images': wds.TarWriter(f'{save_path}/images.tar')})

    def to_kwargs(kwargs_to_save):
        # Turn the JSON-serializable record into actual pipeline kwargs:
        # the integer 'seed' becomes a CPU torch.Generator seeded with it.
        # BUGFIX: previously this seeded from the enclosing loop variable
        # `num_document` instead of the extracted `seed` value.
        kwargs = kwargs_to_save.copy()
        seed = kwargs.pop('seed')
        kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)
        return kwargs

    try:
        for num_document, batch in tqdm(enumerate(dataloader)):
            if num_document < start_at:
                continue
            if num_document >= finish_at:
                break

            # Kept JSON-serializable so it can be stored alongside the tensors.
            kwargs_to_save = {
                'prompt': batch['caption'],
                'positions_to_cache': blocks_to_save,
                'save_input': True,
                'save_output': True,
                'num_inference_steps': 1,
                'guidance_scale': 0.0,
                'seed': num_document,
                'output_type': 'pil'
            }
            kwargs = to_kwargs(kwargs_to_save)

            output, cache = pipe.run_with_cache(
                **kwargs
            )

            # One sample per cached block: the block's output and its residual
            # contribution (output minus input), plus the generation args.
            for block in cache['input'].keys():
                sample = {
                    "__key__": f"sample_{num_document}",
                    "output.pth": cache['output'][block],
                    "diff.pth": cache['output'][block] - cache['input'][block],
                    "gen_args.json": kwargs_to_save
                }
                writers[block].write(sample)

            writers['images'].write({
                "__key__": f"sample_{num_document}",
                "images.npy": np.stack(output.images)
            })
    finally:
        # Close shards even if the loop raises, so tars written so far stay valid.
        for writer in writers.values():
            writer.close()


if __name__ == '__main__':
    fire.Fire(main)