import inspect
import random
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor
from tqdm import tqdm

from .model_utils import get_num_points, get_custom_betas
from .point_cloud_model import PointCloudModel
from .projection_model import PointCloudProjectionModel

class ConditionalPointCloudDiffusionModel(PointCloudProjectionModel):

    def __init__(
        self,
        beta_start: float,
        beta_end: float,
        beta_schedule: str,
        point_cloud_model: str,
        point_cloud_model_embed_dim: int,
        **kwargs,  # projection arguments
    ):
        super().__init__(**kwargs)

        # Checks
        if not self.predict_shape:
            raise NotImplementedError('Must predict shape if performing diffusion.')

        # Create the diffusion schedulers, which define the sampling timesteps
        self.dm_pred_type = kwargs.get('dm_pred_type', 'epsilon')
        assert self.dm_pred_type in ['epsilon', 'sample']
        scheduler_kwargs = {"prediction_type": self.dm_pred_type}
        if beta_schedule == 'custom':
            scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end)))
        else:
            scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule))
        self.schedulers_map = {
            'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False),
            'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False),
            'pndm': PNDMScheduler(**scheduler_kwargs),
        }
        self.scheduler = self.schedulers_map['ddpm']  # this can be changed for inference
        # Create the point cloud model that processes the point cloud at each diffusion step
        self.init_pcloud_model(kwargs, point_cloud_model, point_cloud_model_embed_dim)

        self.load_sample_init = kwargs.get('load_sample_init', False)
        self.sample_init_scale = kwargs.get('sample_init_scale', 1.0)
        self.test_init_with_gtpc = kwargs.get('test_init_with_gtpc', False)
        self.consistent_center = kwargs.get('consistent_center', False)
        self.cam_noise_std = kwargs.get('cam_noise_std', 0.0)  # std of the noise added to the camera pose, scaled by the diffusion timestep
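
    # Note: the inference scheduler can be swapped after construction, either by
    # reassigning it directly (a minimal sketch; 'ddim' is one of the keys of
    # schedulers_map defined above):
    #
    #   model.scheduler = model.schedulers_map['ddim']
    #
    # or per call, via the `scheduler` argument of forward_sample below.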

    def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
        self.point_cloud_model = PointCloudModel(
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1)  # defaults to 1
        )

    def forward_train(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_intermediate_steps: bool = False,
        **kwargs
    ):
        # Normalize colors and convert to tensor
        x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True)  # this does not pack the point colors
        B, N, D = x_0.shape

        # Sample random noise
        noise = torch.randn_like(x_0)
        if self.consistent_center:
            # modification suggested by https://arxiv.org/pdf/2308.07837.pdf
            noise = noise - torch.mean(noise, dim=1, keepdim=True)

        # Sample a random timestep for each point cloud
        timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
                                 device=self.device, dtype=torch.long)

        # Add diffusion noise to the points (only to the coordinates, not to the features)
        x_t = self.scheduler.add_noise(x_0, noise, timestep)

        # Add timestep-dependent noise to the camera pose
        if self.cam_noise_std > 1e-6:
            # noise magnitude scales with the camera distance and the diffusion timestep
            camera = camera.clone()
            camT = camera.T  # (B, 3)
            dist = torch.sqrt(torch.sum(camT**2, -1, keepdim=True))
            nratio = timestep[:, None] / self.scheduler.num_train_timesteps  # time-dependent noise
            tnoise = torch.randn(B, 3).to(dist.device) / 3. * dist * self.cam_noise_std * nratio
            camera.T = camera.T + tnoise

        # Conditioning: the pixel-aligned features are computed from the noisy points
        x_t_input = self.get_diffu_input(camera, image_rgb, mask, timestep, x_t, **kwargs)

        # Forward
        loss, noise_pred = self.compute_loss(noise, timestep, x_0, x_t_input)

        # Whether to return intermediate steps
        if return_intermediate_steps:
            return loss, (x_0, x_t, noise, noise_pred)

        return loss

    def compute_loss(self, noise, timestep, x_0, x_t_input):
        x_pred = torch.zeros_like(x_0)
        if self.self_conditioning:
            # self-conditioning, from https://openreview.net/pdf?id=3itjR9QxFw
            if random.uniform(0, 1.) > 0.5:
                # with 50% probability, first predict x_pred without gradients,
                # then condition the actual prediction on it
                with torch.no_grad():
                    x_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
            noise_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
        else:
            noise_pred = self.point_cloud_model(x_t_input, timestep)

        # Check
        if noise_pred.shape != noise.shape:
            raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')

        # Loss
        if self.dm_pred_type == 'epsilon':
            loss = F.mse_loss(noise_pred, noise)
        elif self.dm_pred_type == 'sample':
            loss = F.mse_loss(noise_pred, x_0)  # the model predicts the clean sample directly
        else:
            raise NotImplementedError
        return loss, noise_pred
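
    # In the 'epsilon' setting above, the objective is the standard DDPM simple
    # loss, L = E_{x_0, eps, t} || eps - eps_theta(x_t, t) ||^2, where
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps is produced by
    # scheduler.add_noise; in the 'sample' setting the regression target is x_0 itself.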

    def get_diffu_input(self, camera, image_rgb, mask, timestep, x_t, **kwargs):
        """return: (B, N, D), the exact input to the diffusion model; x_t: (B, N, 3)"""
        x_t_input = self.get_input_with_conditioning(x_t, camera=camera,
                                                     image_rgb=image_rgb, mask=mask, t=timestep)
        return x_t_input

    def forward_sample(
        self,
        num_points: int,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        # Optional overrides
        scheduler: Optional[str] = 'ddpm',
        # Inference parameters
        num_inference_steps: Optional[int] = 1000,
        eta: Optional[float] = 0.0,  # for DDIM
        # Whether to return all the intermediate steps in generation
        return_sample_every_n_steps: int = -1,
        # Whether to disable tqdm
        disable_tqdm: bool = False,
        gt_pc: Pointclouds = None,
        **kwargs
    ):
        # Get scheduler from mapping, or use self.scheduler if None
        scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]

        # Get the size of the noise
        N = num_points
        B = 1 if image_rgb is None else image_rgb.shape[0]
        D = self.get_x_T_channel()
        device = self.device if image_rgb is None else image_rgb.device

        sample_from_interm = kwargs.get('sample_from_interm', False)
        interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
        x_t = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
        x_pred = torch.zeros_like(x_t)

        # Set timesteps
        extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)

        # Loop over timesteps
        all_outputs = []
        return_all_outputs = (return_sample_every_n_steps > 0)
        progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)

        for i, t in enumerate(progress_bar):
            add_interm_output = (return_all_outputs and (
                    i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1))

            # Conditioning
            x_t_input = self.get_diffu_input(camera, image_rgb, mask, t, x_t, **kwargs)
            if self.self_conditioning:
                x_t_input = torch.cat([x_t_input, x_pred], -1)  # add self-conditioning

            inference_binary = (i == len(progress_bar) - 1) or add_interm_output
            # One reverse step with conditioning
            x_t = self.reverse_step(extra_step_kwargs, scheduler, t, x_t, x_t_input,
                                    inference_binary=inference_binary)  # (B, N, D), D = 3 or 4
            x_pred = x_t  # for the self-conditioning of the next iteration

            # Append to output list if desired
            if add_interm_output:
                all_outputs.append(x_t)

        # Convert output back into a point cloud, undoing normalization and scaling
        output = self.tensor_to_point_cloud(x_t, denormalize=True, unscale=True)  # converts the points back to the original scale
        if return_all_outputs:
            all_outputs = torch.stack(all_outputs, dim=1)  # (B, sample_steps, N, D)
            all_outputs = [self.tensor_to_point_cloud(o, denormalize=True, unscale=True) for o in all_outputs]

        return (output, all_outputs) if return_all_outputs else output

    def get_x_T_channel(self):
        D = 3 + (self.color_channels if self.predict_color else 0)
        return D

    def initialize_x_T(self, device, gt_pc, shape, interm_steps: int = -1, scheduler=None):
        B, N, D = shape
        # Sample noise initialization
        if interm_steps > 0:
            # Start from an intermediate diffusion step of the GT point cloud
            x_0 = self.point_cloud_to_tensor(gt_pc, normalize=True, scale=True)
            noise = torch.randn(B, N, D, device=device)
            # Always make sure the noise does not change the point cloud center;
            # this is important to reduce the Chamfer distance by about 0.1 cm!
            noise = noise - torch.mean(noise, dim=1, keepdim=True)
            x_t = scheduler.add_noise(x_0, noise, torch.tensor([interm_steps - 1] * B).long().to(device))  # add noise
        else:
            # Sample from a random Gaussian
            x_t = torch.randn(B, N, D, device=device)
            x_t = x_t * self.sample_init_scale  # for test-time scaling
            if self.consistent_center:
                x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
        return x_t
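
    # When interm_steps > 0, diffusers' add_noise gives
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise at
    # t = interm_steps - 1, so the reverse process effectively starts partway
    # through the chain instead of from pure Gaussian noise.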

    def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
        """
        run one reverse step to compute x_t
        :param extra_step_kwargs:
        :param scheduler:
        :param t: [1], diffusion time step
        :param x_t: (B, N, 3)
        :param x_t_input: conditional features (B, N, F)
        :param kwargs: other configurations to run the diffusion step
        :return: denoised x_t
        """
        B = x_t.shape[0]
        # Forward
        noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
        if self.consistent_center:
            assert self.dm_pred_type != 'sample', 'incompatible dm prediction type for CCD!'
            # suggested by the CCD-3DR paper
            noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)
        # Step
        x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample
        if self.consistent_center:
            x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
        return x_t
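
    # Re-centering x_t after scheduler.step is needed because, for t > 0, the
    # DDPM ancestral step adds a fresh Gaussian noise term whose sample mean is
    # not exactly zero, which would otherwise drift the point cloud centroid.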

    def setup_reverse_process(self, eta, num_inference_steps, scheduler):
        """
        Set up the diffusion chain and the extra scheduler kwargs.
        """
        accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {"offset": 1} if accepts_offset else {}
        scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be in [0, 1];
        # it is only used with the DDIMScheduler and is ignored by the other schedulers.
        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {"eta": eta} if accepts_eta else {}
        return extra_step_kwargs

    def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
        """
        A wrapper around the forward method for training and inference
        """
        if isinstance(batch, dict):  # fixes a bug with multiprocessing where the batch becomes a dict
            batch = FrameData(**batch)
        if mode == 'train':
            return self.forward_train(
                pc=batch.sequence_point_cloud,
                camera=batch.camera,
                image_rgb=batch.image_rgb,
                mask=batch.fg_probability,
                **kwargs)
        elif mode == 'sample':
            num_points = kwargs.pop('num_points', get_num_points(batch.sequence_point_cloud))
            return self.forward_sample(
                num_points=num_points,
                camera=batch.camera,
                image_rgb=batch.image_rgb,
                mask=batch.fg_probability,
                **kwargs)
        else:
            raise NotImplementedError()
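

# --- Usage sketch (illustrative only; the constructor values below are
# hypothetical placeholders, not defaults from this repository) ---
#
#   model = ConditionalPointCloudDiffusionModel(
#       beta_start=1e-5, beta_end=8e-3, beta_schedule='linear',
#       point_cloud_model='pvcnn', point_cloud_model_embed_dim=64,
#       **projection_kwargs)
#   loss = model(batch, mode='train')                     # batch: FrameData
#   pred_pc = model(batch, mode='sample', scheduler='ddim',
#                   num_inference_steps=50, eta=0.0)


def _check_consistent_center():
    """Minimal, self-contained sanity check of the consistent-center trick used
    above (plain torch; alpha and sigma are hypothetical stand-ins for
    sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t)): zero-centering the noise keeps
    the centroid of x_t equal to alpha times the centroid of x_0, because
    add_noise forms a linear combination of x_0 and the noise."""
    torch.manual_seed(0)
    x_0 = torch.randn(2, 128, 3)
    noise = torch.randn_like(x_0)
    noise = noise - noise.mean(dim=1, keepdim=True)  # zero-center the noise
    alpha, sigma = 0.9, 0.436
    x_t = alpha * x_0 + sigma * noise
    assert torch.allclose(x_t.mean(dim=1), alpha * x_0.mean(dim=1), atol=1e-5)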