| | import torch |
| | import torch.nn as nn |
| | from models.modules.resunet import ResUnet_DirectAttenMultiImg_Cond |
| | from models.modules.parpoints_encoder import ParPoint_Encoder |
| | from models.modules.PointEMB import PointEmbed |
| | from models.modules.utils import StackedRandomGenerator |
| | from models.modules.diffusion_sampler import edm_sampler |
| | from models.modules.encoder import DiagonalGaussianDistribution |
| | import numpy as np |
class EDMLoss_MultiImgCond:
    """EDM denoising-score-matching loss (Karras et al., 2022) with
    multi-image conditioning.

    Draws a per-sample noise level ``sigma`` from a log-normal distribution
    (``ln sigma ~ N(P_mean, P_std^2)``), perturbs the clean latent, pushes the
    conditioning onto the (wrapped) network, and returns the EDM-weighted
    squared error. The loss is NOT reduced: it has the same shape as the
    input so the caller can decide how to aggregate.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, use_par=False):
        # P_mean / P_std: parameters of the log-normal sigma distribution.
        # sigma_data: assumed std of the clean data (EDM preconditioning).
        # use_par: also feed a partial point cloud to the network.
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.use_par = use_par

    def __call__(self, net, data_batch, classifier_free=False):
        """Compute the per-element EDM loss for one batch.

        ``net`` is expected to be a DataParallel/DDP-style wrapper: all
        conditioning setters live on ``net.module``, while the denoising
        forward pass goes through ``net(...)`` itself.

        When ``classifier_free`` is True, the image condition is zeroed with
        probability 0.5 so the model also learns the unconditional branch
        needed for classifier-free guidance at sampling time.
        """
        inputs = data_batch['input']
        image = data_batch['image']
        proj_mat = data_batch['proj_mat']
        valid_frames = data_batch['valid_frames']
        par_points = data_batch["par_points"]
        category_code = data_batch["category_code"]

        # Per-sample noise level: ln(sigma) ~ N(P_mean, P_std^2),
        # broadcastable to (B, C, H, W).
        rnd_normal = torch.randn([inputs.shape[0], 1, 1, 1], device=inputs.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        # EDM loss weight lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma_data * sigma)^2.
        weight = (sigma ** 2 + self.sigma_data ** 2) / (self.sigma_data * sigma) ** 2
        y = inputs
        n = torch.randn_like(y) * sigma

        if classifier_free and np.random.random() < 0.5:
            # Drop the image condition for classifier-free training.
            # Fix: the original called .cuda() here, hard-coding the device;
            # zeros_like keeps the zeroed condition on the image's own device.
            image = torch.zeros_like(image, dtype=torch.float32)
        net.module.extract_img_feat(image)
        net.module.set_proj_matrix(proj_mat)
        net.module.set_valid_frames(valid_frames)
        net.module.set_category_code(category_code)
        if self.use_par:
            net.module.extract_point_feat(par_points)

        # D_yn is the network's denoised estimate of y from y + n at level sigma.
        D_yn = net(y + n, sigma)
        loss = weight * ((D_yn - y) ** 2)
        return loss
| |
|
class Triplane_Diff_MultiImgCond_EDM(nn.Module):
    """EDM-preconditioned triplane diffusion model conditioned on multiple
    posed images (and optionally a partial point cloud).

    Wraps a denoising backbone with the EDM preconditioning of Karras et
    al. 2022 (c_skip / c_out / c_in / c_noise in ``forward``) and exposes
    an ``edm_sampler``-based ``sample`` method. Conditioning tensors are
    stashed on ``self`` via the ``set_*`` / ``extract_*`` helpers before
    each forward call.
    """

    def __init__(self,opt):
        super().__init__()
        # Triplane latent geometry: sampling produces three reso x reso
        # planes stacked along the height axis (see sample()).
        self.diff_reso=opt['diff_reso']
        self.diff_dim=opt['output_channel']
        self.use_cat_embedding=opt['use_cat_embedding']
        self.use_fp16=False  # fp16 path in forward() exists but is disabled here
        # EDM preconditioning constants. sigma_min/sigma_max are effectively
        # overridden by the explicit values passed to edm_sampler in sample().
        self.sigma_data=0.5
        self.sigma_max=float("inf")
        self.sigma_min=0
        self.use_par=opt['use_par']
        self.triplane_padding=opt['triplane_padding']
        self.block_type=opt['block_type']

        # Only one backbone is supported in this file.
        if opt['backbone']=="resunet_multiimg_direct_atten":
            self.denoise_model=ResUnet_DirectAttenMultiImg_Cond(channel=opt['input_channel'],
                output_channel=opt['output_channel'],use_par=opt['use_par'],par_channel=opt['par_channel'],
                img_in_channels=opt['img_in_channels'],vit_reso=opt['vit_reso'],triplane_padding=self.triplane_padding,
                norm=opt['norm'],use_cat_embedding=self.use_cat_embedding,block_type=self.block_type)
        else:
            raise NotImplementedError
        if opt['use_par']:
            # Encoder stack for the partial point-cloud condition.
            par_emb_dim = opt['par_emb_dim']
            par_args = opt['par_point_encoder']
            self.point_embedder = PointEmbed(hidden_dim=par_emb_dim)
            self.par_points_encoder = ParPoint_Encoder(c_dim=par_args['plane_latent_dim'], dim=par_emb_dim,
                                                       plane_resolution=par_args['plane_reso'],
                                                       unet_kwargs=par_args['unet'])
        # NOTE(review): 16x16 presumably matches the ViT token grid
        # (vit_reso); it is not used by any method visible in this file —
        # confirm its consumer before changing.
        self.unflatten = torch.nn.Unflatten(1, (16, 16))

    def prepare_data(self,data_batch):
        """Assemble the training input dict on GPU.

        Samples a triplane latent from the stored diagonal-Gaussian
        posterior (mean/logvar produced upstream by a VAE encoder) and
        moves all conditioning tensors to CUDA.
        NOTE(review): device is hard-coded to "cuda".
        """
        device=torch.device("cuda")
        means, logvars = data_batch['triplane_mean'].to(device, non_blocking=True), data_batch['triplane_logvar'].to(
            device, non_blocking=True)
        distribution = DiagonalGaussianDistribution(means, logvars)
        # Stochastic sample from the posterior — the diffusion target.
        plane_feat = distribution.sample()

        image=data_batch["image"].to(device)
        proj_mat = data_batch['proj_mat'].to(device, non_blocking=True)
        valid_frames=data_batch["valid_frames"].to(device,non_blocking=True)
        par_points=data_batch["par_points"].to(device,non_blocking=True)
        category_code=data_batch["category_code"].to(device,non_blocking=True)
        input_dict = {"input": plane_feat.float(),
                      "image": image.float(),
                      "par_points":par_points.float(),
                      "proj_mat":proj_mat.float(),
                      "category_code":category_code.float(),
                      "valid_frames":valid_frames.float()}

        return input_dict

    def prepare_sample_data(self,data_batch):
        """Move the conditioning-only batch (no latent target) to GPU for
        sampling. NOTE(review): device is hard-coded to "cuda"."""
        device=torch.device("cuda")
        image=data_batch['image'].to(device, non_blocking=True)
        proj_mat = data_batch['proj_mat'].to(device, non_blocking=True)
        valid_frames = data_batch["valid_frames"].to(device, non_blocking=True)
        par_points = data_batch["par_points"].to(device, non_blocking=True)
        category_code=data_batch["category_code"].to(device,non_blocking=True)
        sample_dict={
            "image":image.float(),
            "proj_mat":proj_mat.float(),
            "valid_frames":valid_frames.float(),
            "category_code":category_code.float(),
            "par_points":par_points.float(),
        }
        return sample_dict

    def prepare_eval_data(self,data_batch):
        """Move evaluation query points and their labels to GPU."""
        device=torch.device("cuda")
        samples=data_batch["points"].to(device, non_blocking=True)
        labels=data_batch['labels'].to(device,non_blocking=True)

        eval_dict={
            "samples":samples,
            "labels":labels,
        }
        return eval_dict

    def extract_point_feat(self,par_points):
        # Embed the partial point cloud, then encode to plane features;
        # cached on self for use inside forward().
        par_emb=self.point_embedder(par_points)
        self.par_feat=self.par_points_encoder(par_points,par_emb)

    def extract_img_feat(self,image):
        # Image features are used as-is (no extra encoder here); cached for forward().
        self.image_emb=image

    def set_proj_matrix(self,proj_matrix):
        # Camera projection matrices for the conditioning views.
        self.proj_matrix=proj_matrix

    def set_valid_frames(self,valid_frames):
        # Per-view validity mask/weights for the multi-image condition.
        self.valid_frames=valid_frames

    def set_category_code(self,category_code):
        # Category conditioning code (used when use_cat_embedding is set).
        self.category_code=category_code

    def forward(self, x, sigma,force_fp32=False):
        """EDM-preconditioned denoiser D(x; sigma).

        ``x`` is the noisy triplane latent; ``sigma`` is broadcast to
        (B, 1, 1, 1). Conditioning must have been set beforehand via
        extract_img_feat / set_proj_matrix / set_valid_frames /
        set_category_code (and extract_point_feat when use_par).
        Returns ``c_skip * x + c_out * F(c_in * x, c_noise)``.
        """
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        # fp16 only on CUDA, when enabled and not forced off; self.use_fp16
        # is False in __init__, so this is fp32 unless toggled externally.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        # Preconditioning coefficients from Karras et al. 2022.
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_noise = sigma.log() / 4

        if self.use_par:
            F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(), self.image_emb, self.proj_matrix,
                                     self.valid_frames,self.category_code,self.par_feat)
        else:
            F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(),self.image_emb,self.proj_matrix,
                                     self.valid_frames,self.category_code)
        assert F_x.dtype == dtype
        # Skip connection keeps the estimate well-scaled across sigma.
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        # Hook used by the EDM sampler; no quantization is applied here.
        return torch.as_tensor(sigma)

    @torch.no_grad()
    def sample(self, input_batch, batch_seeds=None,ret_all=False,num_steps=18):
        """Generate triplane latents with the EDM sampler.

        Seeds: if ``batch_seeds`` is None, per-sample seeds 0..B-1 are used
        (B taken from the image condition); otherwise batch size and device
        follow ``batch_seeds``. Returns whatever ``edm_sampler`` returns
        (all intermediate states when ``ret_all`` is set).
        """
        img_cond=input_batch['image']
        proj_mat=input_batch['proj_mat']
        valid_frames=input_batch["valid_frames"]
        category_code=input_batch["category_code"]
        if img_cond is not None:
            batch_size, device = img_cond.shape[0], img_cond.device
            if batch_seeds is None:
                batch_seeds = torch.arange(batch_size)
        else:
            device = batch_seeds.device
            batch_size = batch_seeds.shape[0]

        # Push conditioning onto self so forward() can see it.
        self.extract_img_feat(img_cond)
        self.set_proj_matrix(proj_mat)
        self.set_valid_frames(valid_frames)
        self.set_category_code(category_code)
        if self.use_par:
            par_points=input_batch["par_points"]
            self.extract_point_feat(par_points)
        # Deterministic per-sample noise: one RNG stream per seed.
        rnd = StackedRandomGenerator(device, batch_seeds)
        # Three planes stacked along the height axis: (B, C, 3*reso, reso).
        latents = rnd.randn([batch_size, self.diff_dim, self.diff_reso*3,self.diff_reso], device=device)

        return edm_sampler(self, latents, randn_like=rnd.randn_like,ret_all=ret_all,sigma_min=0.002, sigma_max=80,num_steps=num_steps)
| |
|
| |
|
| |
|