import PIL
import numpy as np
import torch
from PIL import Image
from einops import rearrange, repeat
from itertools import islice
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import autocast
from torchvision.utils import make_grid
from tqdm import tqdm

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
def chunk(it, size):
    """Split an iterable into tuples of at most `size` items."""
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def torch_gc():
    """Release cached GPU memory between runs."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    if torch.cuda.is_available():
        # Half precision halves VRAM use; keep full precision on CPU,
        # where float16 inference is poorly supported.
        model.cuda()
        model.half()
    model.eval()
    return model
def load_img(image, W, H):
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    image = image.resize((int(W), int(H)), resample=PIL.Image.LANCZOS)
    print(f"resized input image to ({W}, {H})")
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW
    image = torch.from_numpy(image)
    return 2. * image - 1.  # rescale from [0, 1] to [-1, 1]
class AppModel:
    def __init__(self):
        self.config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
        self.model = load_model_from_config(self.config, "models/ldm/stable-diffusion-v1/model.ckpt")
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.device = device
        self.model = self.model.to(device)
        self.sampler = PLMSSampler(self.model)      # text-to-image sampler
        self.img_sampler = DDIMSampler(self.model)  # image-to-image sampler
        self.C = 4  # latent channels
        self.f = 8  # downsampling factor
    def run_with_prompt(self, seed, prompt, n_samples, W, H, scale, ddim_steps, strength=0., init_img=None):
        torch_gc()
        seed_everything(seed)
        ddim_eta = 0.0
        assert prompt is not None
        print(f"Prompt: {prompt}")
        batch_size = n_samples
        data = [batch_size * [prompt]]
        start_code = None
        n_rows = int(n_samples ** 0.5)
        precision_scope = autocast
        if init_img is None:
            # Text-to-image: sample latents from pure noise with the PLMS sampler.
            with torch.no_grad():
                with precision_scope(device_type='cuda', dtype=torch.float16):
                    with self.model.ema_scope():
                        all_samples = list()
                        for prompts in tqdm(data, desc="data"):
                            torch_gc()
                            uc = None
                            if scale != 1.0:
                                # Empty-prompt embedding for classifier-free guidance.
                                uc = self.model.get_learned_conditioning(batch_size * [""])
                            if isinstance(prompts, tuple):
                                prompts = list(prompts)
                            c = self.model.get_learned_conditioning(prompts)
                            shape = [self.C, H // self.f, W // self.f]
                            samples_ddim, _ = self.sampler.sample(S=ddim_steps,
                                                                  conditioning=c,
                                                                  batch_size=n_samples,
                                                                  shape=shape,
                                                                  verbose=False,
                                                                  unconditional_guidance_scale=scale,
                                                                  unconditional_conditioning=uc,
                                                                  eta=ddim_eta,
                                                                  x_T=start_code)
                            x_samples_ddim = self.model.decode_first_stage(samples_ddim)
                            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                            for x_sample in x_samples_ddim:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                image = Image.fromarray(x_sample.astype(np.uint8))
                                all_samples.append(image)
                        # additionally, build a grid image
                        grid = torch.stack([x_samples_ddim], 0)
                        grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                        grid = make_grid(grid, nrow=n_rows)
                        # to image
                        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                        grid = grid.astype(np.uint8)
            torch_gc()
            return grid, all_samples
        else:
            # Image-to-image: encode the init image to latent space, noise it
            # for t_enc steps, then denoise with the DDIM sampler.
            init_image = load_img(init_img, W, H).to(self.device)
            init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
            torch_gc()
            with precision_scope(device_type='cuda', dtype=torch.float16):
                init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image))  # move to latent space
            torch_gc()
            sampler = self.img_sampler
            sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
            assert 0. <= strength < 1., 'can only work with strength in [0.0, 1.0)'
            t_enc = int(strength * ddim_steps)  # higher strength = more noise, less fidelity to the init image
            print(f"target t_enc is {t_enc} steps")
            with torch.no_grad():
                with precision_scope(device_type='cuda', dtype=torch.float16):
                    with self.model.ema_scope():
                        all_samples = list()
                        for prompts in tqdm(data, desc="data"):
                            uc = None
                            if scale != 1.0:
                                uc = self.model.get_learned_conditioning(batch_size * [""])
                            if isinstance(prompts, tuple):
                                prompts = list(prompts)
                            c = self.model.get_learned_conditioning(prompts)
                            # encode (scaled latent)
                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(self.device))
                            # decode it
                            samples = sampler.decode(z_enc, c, t_enc,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=uc)
                            x_samples = self.model.decode_first_stage(samples)
                            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                            for x_sample in x_samples:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                image = Image.fromarray(x_sample.astype(np.uint8))
                                all_samples.append(image)
                        # additionally, build a grid image
                        grid = torch.stack([x_samples], 0)
                        grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                        grid = make_grid(grid, nrow=n_rows)
                        # to image
                        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                        grid = grid.astype(np.uint8)
            torch_gc()
            return grid, all_samples
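

# The original file ends by calling main(), which is never defined. The
# minimal driver below is a hypothetical sketch to make the module runnable
# standalone; the seed, prompt, resolution, guidance scale, and step count
# are illustrative assumptions, not values from the original source.
def main():
    model = AppModel()
    grid, samples = model.run_with_prompt(
        seed=42,                # assumed example seed
        prompt="a photograph of an astronaut riding a horse",  # assumed example prompt
        n_samples=1,
        W=512, H=512,           # native resolution of Stable Diffusion v1
        scale=7.5,              # common classifier-free guidance scale
        ddim_steps=50,
    )
    Image.fromarray(grid).save("grid.png")
    for i, image in enumerate(samples):
        image.save(f"sample_{i:02d}.png")
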
if __name__ == "__main__":
    main()