import math

import torch
import torch.nn.functional as F
from diffusers.models import AutoencoderKL
from transformers import T5EncoderModel, T5Tokenizer
from safetensors.torch import load_model

from diffusion import IDDPM, DPMS
from diffusion.utils.misc import read_config
from diffusion.model.nets import PixArtMS_XL_2, ControlPixArtMSMVHalfWithEncoder
from diffusion.utils.data import ASPECT_RATIO_512_TEST
from utils.camera import get_camera_poses
from utils.postprocess import adaptive_instance_normalization, wavelet_reconstruction
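

# Enhancer wraps a PixArt-Sigma backbone (PixArtMS_XL_2) with a multi-view
# control branch (ControlPixArtMSMVHalfWithEncoder). Low-quality renders are
# partially re-noised with the IDDPM forward process, encoded as a control
# signal together with the camera poses, and re-sampled with DPM-Solver under
# classifier-free guidance; epipolar constraints and camera distances keep the
# generated views consistent with the supplied camera geometry.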
class Enhancer:
    def __init__(self, model_path, config_path):
        self.config = read_config(config_path)
        self.image_size = self.config.image_size
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.weight_dtype = torch.float16
        self._load_model(model_path, self.config.pipeline_load_from)

    def _load_model(self, model_path, pipeline_load_from):
        self.tokenizer = T5Tokenizer.from_pretrained(pipeline_load_from, subfolder="tokenizer")
        self.text_encoder = T5EncoderModel.from_pretrained(
            pipeline_load_from, subfolder="text_encoder", torch_dtype=self.weight_dtype
        ).to(self.device)
        self.vae = AutoencoderKL.from_pretrained(
            pipeline_load_from, subfolder="vae", torch_dtype=self.weight_dtype
        ).to(self.device)
        del self.vae.encoder  # only the VAE decoder is used

        # only a fixed latent size is supported currently
        latent_size = self.image_size // 8
        lewei_scale = {512: 1, 1024: 2}
        model_kwargs = {
            "model_max_length": self.config.model_max_length,
            "qk_norm": self.config.qk_norm,
            "kv_compress_config": self.config.kv_compress_config if self.config.kv_compress else None,
            "micro_condition": self.config.micro_condition,
            "use_crossview_module": getattr(self.config, "use_crossview_module", False),
        }
        model = PixArtMS_XL_2(
            input_size=latent_size, pe_interpolation=lewei_scale[self.image_size], **model_kwargs
        ).to(self.device)
        model = ControlPixArtMSMVHalfWithEncoder(model).to(self.weight_dtype).to(self.device)
        load_model(model, model_path)
        model.eval()
        self.model = model

        self.noise_maker = IDDPM(str(self.config.train_sampling_steps))

    def _encode_prompt(self, text_prompt, n_views):
        txt_tokens = self.tokenizer(
            text_prompt,
            max_length=self.config.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        caption_embs = self.text_encoder(
            txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask
        )[0][:, None]
        emb_masks = txt_tokens.attention_mask
        # repeat the prompt embedding and mask once per view
        caption_embs = caption_embs.repeat_interleave(n_views, dim=0).to(self.weight_dtype)
        emb_masks = emb_masks.repeat_interleave(n_views, dim=0).to(self.weight_dtype)
        return caption_embs, emb_masks
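
    # Main entry point: mv_imgs is a batch of n_views low-quality renders,
    # expected in [0, 1] (they are resized to 512x512 and mapped to [-1, 1]
    # below); c2ws are the per-view camera-to-world poses consumed by
    # get_camera_poses; noise_level controls how much diffusion noise is
    # injected before the views are re-sampled.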
    def inference(self, mv_imgs, c2ws, prompt="", fov=math.radians(49.1), noise_level=120,
                  cfg_scale=4.5, sample_steps=20, color_shift=None):
        mv_imgs = F.interpolate(mv_imgs, size=(512, 512), mode="bilinear", align_corners=False)
        n_views = mv_imgs.shape[0]

        # the PixArt-Sigma backbone expects input tensors in the range [-1, 1]
        mv_imgs = 2. * mv_imgs - 1.
        original_mv_imgs = mv_imgs.clone().to(self.device)

        if noise_level == 0:
            noise_level = torch.zeros((n_views,)).long().to(self.device)
        else:
            # partially re-noise the inputs with the IDDPM forward process
            noise_level = noise_level * torch.ones((n_views,)).long().to(self.device)
            mv_imgs = self.noise_maker.q_sample(mv_imgs.to(self.device), noise_level - 1)

        cur_camera_pose, epipolar_constrains, cam_distances = get_camera_poses(
            c2ws=c2ws, fov=fov, h=mv_imgs.size(-2), w=mv_imgs.size(-1)
        )
        epipolar_constrains = epipolar_constrains.to(self.device)
        cam_distances = cam_distances.to(self.weight_dtype).to(self.device)

        caption_embs, emb_masks = self._encode_prompt(prompt, n_views)
        # null (unconditional) text embedding for classifier-free guidance
        null_y = self.model.y_embedder.y_embedding[None].repeat(n_views, 1, 1)[:, None]

        latent_size_h, latent_size_w = mv_imgs.size(-2) // 8, mv_imgs.size(-1) // 8
        z = torch.randn(n_views, 4, latent_size_h, latent_size_w, device=self.device)
        z_lq = self.model.encode(
            mv_imgs.to(self.weight_dtype).to(self.device),
            cur_camera_pose.to(self.weight_dtype).to(self.device),
            n_views=n_views,
        )

        # conditioning tensors are duplicated because DPM-Solver batches the
        # conditional and unconditional (classifier-free guidance) branches
        model_kwargs = dict(
            c=torch.cat([z_lq] * 2),
            data_info={},
            mask=emb_masks,
            noise_level=torch.cat([noise_level] * 2),
            epipolar_constrains=torch.cat([epipolar_constrains] * 2),
            cam_distances=torch.cat([cam_distances] * 2),
            n_views=n_views,
        )
        dpm_solver = DPMS(
            self.model.forward_with_dpmsolver,
            condition=caption_embs,
            uncondition=null_y,
            cfg_scale=cfg_scale,
            model_kwargs=model_kwargs,
        )
        samples = dpm_solver.sample(
            z,
            steps=sample_steps,
            order=2,
            skip_type="time_uniform",
            method="multistep",
            disable_progress_ui=False,
        )
        samples = samples.to(self.weight_dtype)
        output_mv_imgs = self.vae.decode(samples / self.vae.config.scaling_factor).sample

        # optional color correction against the low-quality inputs (AdaIN or wavelet)
        if color_shift == "adain":
            for i, output_mv_img in enumerate(output_mv_imgs):
                output_mv_imgs[i] = adaptive_instance_normalization(
                    output_mv_img.unsqueeze(0), original_mv_imgs[i:i + 1]
                ).squeeze(0)
        elif color_shift == "wavelet":
            for i, output_mv_img in enumerate(output_mv_imgs):
                output_mv_imgs[i] = wavelet_reconstruction(
                    output_mv_img.unsqueeze(0), original_mv_imgs[i:i + 1]
                ).squeeze(0)

        # map back from [-1, 1] to [0, 1]
        output_mv_imgs = torch.clamp((output_mv_imgs + 1.) / 2., 0, 1)
        torch.cuda.empty_cache()
        return output_mv_imgs
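

# Illustrative usage sketch, not part of the original file: the checkpoint and
# config paths, the image files, and the identity camera poses below are
# placeholder assumptions; substitute real values for actual use.
if __name__ == "__main__":
    from PIL import Image
    from torchvision.transforms.functional import to_pil_image, to_tensor

    # hypothetical paths to a trained checkpoint and its config
    enhancer = Enhancer("checkpoints/enhancer.safetensors", "configs/enhancer_config.py")

    # stack n_views RGB images into an (n_views, 3, H, W) tensor in [0, 1]
    view_paths = ["view_0.png", "view_1.png", "view_2.png", "view_3.png"]
    mv_imgs = torch.stack([to_tensor(Image.open(p).convert("RGB")) for p in view_paths])

    # per-view camera-to-world matrices, assumed shape (n_views, 4, 4);
    # identity poses are used here purely as a placeholder
    c2ws = torch.eye(4).unsqueeze(0).repeat(len(view_paths), 1, 1)

    with torch.no_grad():
        enhanced = enhancer.inference(
            mv_imgs, c2ws, prompt="a photo of an object",
            noise_level=120, cfg_scale=4.5, sample_steps=20, color_shift="adain",
        )

    # enhanced: (n_views, 3, 512, 512) in [0, 1]
    for i, img in enumerate(enhanced):
        to_pil_image(img.float().cpu()).save(f"enhanced_view_{i}.png")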