Spaces:
Runtime error
Runtime error
| from transformers import CLIPTextModel, CLIPTokenizer, logging | |
| from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler | |
| # suppress partial model loading warning | |
| logging.set_verbosity_error() | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as T | |
| import argparse | |
| import numpy as np | |
| from PIL import Image | |
def seed_everything(seed):
    """Seed PyTorch's default and CUDA random number generators with ``seed``.

    The cuDNN ``deterministic`` / ``benchmark`` flags are not modified here.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    """Enumerate sliding-window views over the latent grid of a panorama.

    The pixel dimensions are divided by 8 to obtain the latent grid size, and a
    ``window_size`` x ``window_size`` window is slid over that grid in steps of
    ``stride`` (both expressed in latent cells).

    Args:
        panorama_height: Panorama height in pixels.
        panorama_width: Panorama width in pixels.
        window_size: Side length of each view, in latent cells.
        stride: Step between neighbouring views, in latent cells.

    Returns:
        A list of ``(h_start, h_end, w_start, w_end)`` integer tuples in
        row-major order. At least one view is always returned; when an axis of
        the latent grid is shorter than ``window_size`` the single view along
        that axis still ends at ``window_size``, so callers that slice with it
        rely on slicing to clip to the real extent.
    """
    # Integer arithmetic throughout: the original float division only worked
    # because every result was later passed through int().
    latent_height = int(panorama_height // 8)
    latent_width = int(panorama_width // 8)

    # Clamp to one block per axis. Without the clamp a panorama smaller than
    # the window yields zero or negative block counts, which produced either
    # no views at all or views with negative offsets.
    num_blocks_height = max((latent_height - window_size) // stride + 1, 1)
    num_blocks_width = max((latent_width - window_size) // stride + 1, 1)

    return [
        (row * stride, row * stride + window_size, col * stride, col * stride + window_size)
        for row in range(num_blocks_height)
        for col in range(num_blocks_width)
    ]
class MultiDiffusion(nn.Module):
    """Region-based MultiDiffusion sampler built from Stable Diffusion components.

    Loads a VAE, CLIP tokenizer/text encoder, UNet and DDIM scheduler from one
    model key and denoises a single latent canvas by averaging, per latent
    cell, the mask-weighted predictions of every prompt over overlapping views.

    All inference methods run under ``torch.no_grad()``: the model parameters
    are never trained here, and without it the autograd graph would keep
    growing across every denoising step and view.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        """Load the model components.

        Args:
            device: Device the VAE, text encoder and UNet are moved to.
            sd_version: '2.1', '2.0' or '1.5' select a known checkpoint; any
                other value is used directly as the model key.
            hf_key: Optional model key that overrides ``sd_version``.
        """
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            # For custom models or fine-tunes, allow arbitrary versions to be
            # passed straight through as the model key instead of raising.
            model_key = self.sd_version

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        print(f'[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def get_random_background(self, n_samples):
        """Return the latents of ``n_samples`` constant-colour 512x512 images.

        Each image is filled with one random RGB value and encoded on its own;
        the resulting latents are concatenated along the first dimension.
        """
        # sample random background with a constant rgb value
        backgrounds = torch.rand(n_samples, 3, device=self.device)[:, :, None, None].repeat(1, 1, 512, 512)
        return torch.cat([self.encode_imgs(bg.unsqueeze(0)) for bg in backgrounds])

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt):
        """Encode prompts and negative prompts into one stacked embedding tensor.

        Returns the negative-prompt embeddings followed by the prompt
        embeddings, concatenated along the first dimension.
        """
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Do the same for unconditional embeddings. Truncate here as well so an
        # over-long negative prompt cannot exceed the encoder's maximum length.
        uncond_input = self.tokenizer(negative_prompt, padding='max_length',
                                      max_length=self.tokenizer.model_max_length,
                                      truncation=True, return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        return torch.cat([uncond_embeddings, text_embeddings])

    @torch.no_grad()
    def encode_imgs(self, imgs):
        """Encode images into scaled VAE latents (one sample from the posterior).

        Assumes ``imgs`` holds values in [0, 1]; they are rescaled with
        ``2 * imgs - 1`` before encoding.
        """
        imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * 0.18215
        return latents

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode scaled latents back into images clamped to [0, 1]."""
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    @torch.no_grad()
    def generate(self, masks, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
                 guidance_scale=7.5, bootstrapping=20):
        """Generate one image in which each prompt controls its masked region.

        Args:
            masks: Tensor indexed as ``masks[prompt, :, latent_row, latent_col]``;
                entry 0 is treated as the background and is never bootstrapped.
            prompts: One prompt per mask (a single string is treated as one prompt).
            negative_prompts: One negative prompt per prompt, or a single string
                that is applied to every prompt.
            height: Output height in pixels; the latent canvas uses ``height // 8``.
            width: Output width in pixels; the latent canvas uses ``width // 8``.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight.
            bootstrapping: Number of initial steps during which the area outside
                each foreground mask is replaced by a noised random constant-colour
                background. ``0`` disables bootstrapping.

        Returns:
            The first decoded image as a PIL image.
        """
        # Normalise prompt arguments so the two embedding batches always have the
        # same length; a bare string negative prompt used to yield a batch of one.
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts] * len(prompts)

        # get bootstrapping backgrounds
        # can move this outside of the function to speed up generation. i.e., calculate in init
        # Skipped when bootstrapping is disabled: concatenating zero latents raises.
        bootstrapping_backgrounds = self.get_random_background(bootstrapping) if bootstrapping > 0 else None

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2 * len(prompts), seq_len, dim]

        # Define panorama grid and get views
        latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
        # The same initial noise is reused to noise the bootstrapping backgrounds.
        noise = latent.clone().repeat(len(prompts) - 1, 1, 1, 1)
        views = get_views(height, width)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(self.scheduler.timesteps):
                count.zero_()
                value.zero_()

                for h_start, h_end, w_start, w_end in views:
                    masks_view = masks[:, :, h_start:h_end, w_start:w_end]
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end].repeat(len(prompts), 1, 1, 1)
                    if i < bootstrapping:
                        # Pick a random background per foreground prompt and noise it to timestep t.
                        bg = bootstrapping_backgrounds[torch.randint(0, bootstrapping, (len(prompts) - 1,))]
                        bg = self.scheduler.add_noise(bg, noise[:, :, h_start:h_end, w_start:w_end], t)
                        latent_view[1:] = latent_view[1:] * masks_view[1:] + bg * (1 - masks_view[1:])

                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latent_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # perform guidance
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the denoising step with the reference model
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']

                    # Accumulate the mask-weighted predictions and the total mask weight per latent cell.
                    value[:, :, h_start:h_end, w_start:w_end] += (latents_view_denoised * masks_view).sum(
                        dim=0, keepdim=True)
                    count[:, :, h_start:h_end, w_start:w_end] += masks_view.sum(dim=0, keepdim=True)

                # take the MultiDiffusion step: weighted average where any mask contributed
                latent = torch.where(count > 0, value / count, value)

        # Img latents -> imgs
        imgs = self.decode_latents(latent)  # [1, 3, height, width]
        img = T.ToPILImage()(imgs[0].cpu())
        return img
def preprocess_mask(mask_path, h, w, device):
    """Load a mask image, binarize it at 0.5 and resize it to ``(h, w)``.

    The image is read as 8-bit grayscale, scaled to [0, 1], thresholded so
    every pixel becomes exactly 0 or 1, moved to ``device`` and resized with
    nearest-neighbour interpolation. The result has two leading singleton
    dimensions in front of the ``(h, w)`` grid.
    """
    grayscale = Image.open(mask_path).convert("L")
    scaled = np.array(grayscale, dtype=np.float32) / 255.0
    # Values below 0.5 become 0, everything else becomes 1.
    binary = (scaled >= 0.5).astype(np.float32)
    batched = torch.from_numpy(binary[None, None]).to(device)
    return torch.nn.functional.interpolate(batched, size=(h, w), mode='nearest')
def build_arg_parser():
    """Build the command-line parser for the region-based generation script.

    List-valued options take one or more space-separated values
    (``--fg_prompts "a fox" "a deer"``). They previously used ``type=list``,
    which split a single argument string into individual characters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mask_paths', type=str, nargs='+', required=True,
                        help="one mask image path per foreground prompt")
    parser.add_argument('--bg_prompt', type=str, required=True)
    # e.g. 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
    parser.add_argument('--bg_negative', type=str, default='')
    parser.add_argument('--fg_prompts', type=str, nargs='+', required=True,
                        help="one prompt per mask")
    # e.g. 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
    parser.add_argument('--fg_negative', type=str, nargs='+', default=None,
                        help="one negative prompt per foreground prompt, or a single one shared by all")
    parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0', '2.1'],
                        help="stable diffusion version")
    parser.add_argument('--H', type=int, default=768)
    parser.add_argument('--W', type=int, default=512)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    parser.add_argument('--bootstrapping', type=int, default=20)
    return parser


if __name__ == '__main__':
    parser = build_arg_parser()
    opt = parser.parse_args()

    # Every foreground prompt needs exactly one mask, and one negative prompt.
    if len(opt.fg_prompts) != len(opt.mask_paths):
        parser.error('--fg_prompts and --mask_paths must have the same number of entries')
    if opt.fg_negative is None:
        opt.fg_negative = [''] * len(opt.fg_prompts)
    elif len(opt.fg_negative) == 1:
        opt.fg_negative = opt.fg_negative * len(opt.fg_prompts)
    elif len(opt.fg_negative) != len(opt.fg_prompts):
        parser.error('--fg_negative must have one entry, or one entry per foreground prompt')

    seed_everything(opt.seed)

    device = torch.device('cuda')

    sd = MultiDiffusion(device, opt.sd_version)

    # Masks live on the latent grid, which is 8x smaller than the image.
    fg_masks = torch.cat([preprocess_mask(mask_path, opt.H // 8, opt.W // 8, device) for mask_path in opt.mask_paths])
    # The background is whatever no foreground mask covers; overlapping
    # foreground masks would push it below zero, so clamp at 0.
    bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
    bg_mask[bg_mask < 0] = 0
    masks = torch.cat([bg_mask, fg_masks])

    prompts = [opt.bg_prompt] + opt.fg_prompts
    neg_prompts = [opt.bg_negative] + opt.fg_negative

    img = sd.generate(masks, prompts, neg_prompts, opt.H, opt.W, opt.steps, bootstrapping=opt.bootstrapping)

    # save image
    img.save('out.png')