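"""Point cloud coloring model: given a point cloud and a conditioning RGB
image, regress per-point colors with a point cloud transformer."""
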
from typing import Optional

import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor

from .point_cloud_transformer_model import PointCloudTransformerModel
from .projection_model import PointCloudProjectionModel


class PointCloudColoringModel(PointCloudProjectionModel):
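    """Regresses per-point RGB colors for a given point cloud, conditioned on
    an input image (with its camera and an optional foreground mask) via the
    projection-based conditioning of PointCloudProjectionModel."""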

    def __init__(
        self,
        point_cloud_model: str,
        point_cloud_model_layers: int,
        point_cloud_model_embed_dim: int,
        **kwargs,  # projection arguments passed to PointCloudProjectionModel
    ):
        super().__init__(**kwargs)

        # Sanity check: this model only predicts colors, never geometry
        if self.predict_shape or not self.predict_color:
            raise NotImplementedError('Must predict color, not shape, for coloring')

        # Point cloud model that maps per-point conditioned features to colors.
        # Coloring is a deterministic single-pass regression (see _forward), so
        # a plain point cloud transformer suffices; no diffusion model is needed.
        self.point_cloud_model = PointCloudTransformerModel(
            num_layers=point_cloud_model_layers,
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
        )

    def _forward(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_point_cloud: bool = False,
        noise_std: float = 0.0,
    ):
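        """Predicts colors for the points of `pc`, conditioned on `image_rgb`.

        Returns the colored point cloud if `return_point_cloud` is True
        (inference); otherwise returns the MSE loss against the ground-truth
        colors (training).
        """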
        # Normalize colors and convert the point cloud to a padded tensor of
        # points and colors, then split the two
        x = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
        x_points, x_colors = x[:, :, :3], x[:, :, 3:]

        # Perturb the input points, presumably so the model is robust to the
        # imperfect geometry of predicted point clouds. TODO: Add to configs.
        x_input = x_points + torch.randn_like(x_points) * noise_std

        # Conditioning: project image features onto the (noised) points. The
        # coloring model runs in a single pass, so there is no diffusion
        # timestep to condition on (t=None).
        x_input = self.get_input_with_conditioning(
            x_input, camera=camera, image_rgb=image_rgb, mask=mask, t=None)

        # Forward pass through the point cloud model
        pred_colors = self.point_cloud_model(x_input)

        # During inference, return the point cloud with the predicted colors
        if return_point_cloud:
            pred_pointcloud = self.tensor_to_point_cloud(
                torch.cat((x_points, pred_colors), dim=2), denormalize=True, unscale=True)
            return pred_pointcloud

        # During training, we have ground-truth colors and return the loss
        loss = F.mse_loss(pred_colors, x_colors)
        return loss

    def forward(self, batch: FrameData, **kwargs):
        """Wraps `_forward`, unpacking the FrameData batch."""
        # With multiprocess data loading, the collated batch can arrive as a
        # plain dict rather than a FrameData instance; rebuild it if needed.
        if isinstance(batch, dict):
            batch = FrameData(**batch)
        return self._forward(
            pc=batch.sequence_point_cloud,
            camera=batch.camera,
            image_rgb=batch.image_rgb,
            mask=batch.fg_probability,
            **kwargs,
        )
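

# Illustrative usage sketch (not from this repo's configs): the model-type
# string, layer count, embedding dim, and the projection kwargs inherited from
# PointCloudProjectionModel below are assumed/hypothetical values.
#
#   model = PointCloudColoringModel(
#       point_cloud_model='transformer',       # hypothetical model_type string
#       point_cloud_model_layers=6,
#       point_cloud_model_embed_dim=512,
#       predict_shape=False,                   # required by the __init__ check
#       predict_color=True,
#       ...,                                   # remaining projection arguments
#   )
#   loss = model(batch)                                          # training
#   colored_pc = model(batch, return_point_cloud=True)           # inference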