| | import sys |
| | import time |
| | import torch |
| | import os |
| |
|
| | import json |
| | import argparse |
| | sys.path.append(os.getcwd()) |
| | from diffusers import LDMPipeline, DDIMScheduler, PNDMScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler |
| | from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler |
| | from scheduler.scheduling_ddim_lm import DDIMLMScheduler |
| |
|
# Maps each accepted --dtype name to its torch dtype. The keys double as the
# argparse choices, so an unrecognised name fails at parse time instead of
# silently leaving dtype as None.
_TORCH_DTYPES = {
    'fp32': torch.float32,
    'fp64': torch.float64,
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
}

# Sampler names accepted by --sampler_type; each has a branch in
# _configure_scheduler.
_SAMPLER_TYPES = ('pndm', 'ddim_lm', 'ddim', 'dpm++', 'dpm', 'dpm_lm', 'unipc')


def _parse_args(argv=None):
    """Parse the command-line options of the sampling script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="sampling script for CelebA-HQ.")
    parser.add_argument('--test_num', type=int, default=1)
    parser.add_argument('--start_index', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_inference_steps', type=int, default=20)
    # The old default 'lag' was not in `choices`. argparse does not validate
    # defaults, so a default run silently kept the checkpoint's own scheduler.
    # 'ddim' is used instead.
    parser.add_argument('--sampler_type', type=str, default='ddim', choices=_SAMPLER_TYPES)
    parser.add_argument('--save_dir', type=str, default='/xxx/xxx')
    parser.add_argument('--model_id', type=str, default='/xxx/xxx/ddpm_ema_cifar10')
    parser.add_argument('--lamb', type=float, default=1.0)
    parser.add_argument('--kappa', type=float, default=0.0)
    parser.add_argument('--dtype', type=str, default='fp32', choices=list(_TORCH_DTYPES))
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--file_prefix', type=str, default='cifar10',
                        help="prefix of the saved image filenames (e.g. 'celebahq')")
    return parser.parse_args(argv)


def _configure_scheduler(pipe, sampler_type, lamb, kappa):
    """Replace ``pipe.scheduler`` with the scheduler selected by *sampler_type*.

    The new scheduler is built from the pipeline's current scheduler config.
    For 'ddim_lm' and 'dpm_lm' the ``lamb``/``kappa`` values are attached as
    attributes and ``lm`` is switched on; 'dpm' uses the LM scheduler class
    with ``lm`` switched off.

    Raises:
        ValueError: If *sampler_type* is not a known sampler name.
    """
    config = pipe.scheduler.config
    if sampler_type == 'pndm':
        pipe.scheduler = PNDMScheduler.from_config(config)
    elif sampler_type == 'dpm++':
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(config)
        # NOTE(review): these overrides are assigned onto the built scheduler's
        # config object rather than passed to from_config(); kept as in the
        # original. Confirm the scheduler reads them from config at step time.
        pipe.scheduler.config.solver_order = 3
        pipe.scheduler.config.algorithm_type = "dpmsolver++"
    elif sampler_type == 'dpm_lm':
        pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(config)
        pipe.scheduler.config.solver_order = 3
        pipe.scheduler.config.algorithm_type = "dpmsolver"
        pipe.scheduler.lamb = lamb
        pipe.scheduler.lm = True
        pipe.scheduler.kappa = kappa
    elif sampler_type == 'dpm':
        pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(config)
        pipe.scheduler.config.solver_order = 3
        pipe.scheduler.config.algorithm_type = "dpmsolver"
        pipe.scheduler.lm = False
    elif sampler_type == 'ddim':
        pipe.scheduler = DDIMScheduler.from_config(config)
    elif sampler_type == 'ddim_lm':
        pipe.scheduler = DDIMLMScheduler.from_config(config)
        pipe.scheduler.lamb = lamb
        pipe.scheduler.lm = True
        pipe.scheduler.kappa = kappa
    elif sampler_type == 'unipc':
        pipe.scheduler = UniPCMultistepScheduler.from_config(config)
    else:
        raise ValueError(f"Unknown sampler_type: {sampler_type!r}")


def main():
    """Sample image batches with an LDMPipeline and save them as PNG files.

    For each seed in ``[start_index, start_index + test_num)`` the global torch
    RNG is seeded, one batch of ``batch_size`` images is generated, and each
    image is written to ``save_dir`` as
    ``{file_prefix}_{sampler_type}_inference{steps}_seed{seed}_{i}.png``.
    The wall time of each batch is printed.
    """
    args = _parse_args()
    dtype = _TORCH_DTYPES[args.dtype]
    sampler_type = args.sampler_type
    num_inference_steps = args.num_inference_steps
    save_dir = args.save_dir

    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        pipe = LDMPipeline.from_pretrained(args.model_id, torch_dtype=dtype)
        pipe.unet.to(args.device)
        pipe.vqvae.to(args.device)

        _configure_scheduler(pipe, sampler_type, args.lamb, args.kappa)

        for seed in range(args.start_index, args.start_index + args.test_num):
            print('prepare to sample')
            start_time = time.perf_counter()
            # Seeds the global RNG only; presumably the pipeline draws its
            # initial noise from it since no generator is passed (TODO confirm).
            torch.manual_seed(seed)

            images = pipe(batch_size=args.batch_size,
                          num_inference_steps=num_inference_steps).images

            for i, image in enumerate(images):
                filename = (f"{args.file_prefix}_{sampler_type}_inference"
                            f"{num_inference_steps}_seed{seed}_{i}.png")
                image.save(os.path.join(save_dir, filename))
            print(f"{sampler_type} batch##{seed},done")

            time_difference = time.perf_counter() - start_time
            print(f"The code took {time_difference} seconds to run.")
| |
|
if __name__ == '__main__':
    # Run the sampler only when executed as a script, not when imported.
    main()
| |
|
| |
|