Spaces:
Sleeping
Sleeping
| from contextlib import nullcontext | |
| import torch | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers import ModelMixin | |
| from torch import Tensor | |
| from .pvcnn.pvcnn import PVCNN2 | |
| from .pvcnn.pvcnn_plus_plus import PVCNN2PlusPlus | |
| from .simple.simple_model import SimplePointModel | |
class PointCloudModel(ModelMixin, ConfigMixin):
    """Point-cloud network wrapping one of several interchangeable backbones.

    The backbone is selected by ``model_type`` and exposed through a common
    channels-last interface: ``forward`` takes ``(B, N, in_channels)`` and
    returns ``(B, N, out_channels)``. ``__init__`` is wrapped with diffusers'
    ``register_to_config`` so the constructor arguments are recorded in
    ``self.config`` and can be saved/restored with ``save_pretrained`` /
    ``from_pretrained``.
    """

    # Accepted values of ``model_type`` (validated in ``__init__``).
    _SUPPORTED_MODEL_TYPES = ('pvcnn', 'pvcnnplusplus', 'simple')

    @staticmethod
    def _init_near_zero(layer) -> None:
        """Redraw ``layer.bias`` then ``layer.weight`` in place from N(0, 1e-6)."""
        # Bias is drawn before weight to keep the same RNG draw order as the
        # per-branch initialisation this helper replaces.
        layer.bias.data.normal_(0, 1e-6)
        layer.weight.data.normal_(0, 1e-6)

    @register_to_config
    def __init__(
        self,
        model_type: str = 'pvcnn',
        in_channels: int = 3,
        out_channels: int = 3,
        embed_dim: int = 64,
        dropout: float = 0.1,
        width_multiplier: int = 1,
        voxel_resolution_multiplier: int = 1,
    ):
        """Build the backbone selected by ``model_type``.

        Args:
            model_type: One of ``'pvcnn'`` (``PVCNN2``), ``'pvcnnplusplus'``
                (``PVCNN2PlusPlus``) or ``'simple'`` (``SimplePointModel``).
            in_channels: Number of per-point input channels. ``in_channels - 3``
                of them are passed to the backbone as ``extra_feature_channels``
                (the first 3 are presumably xyz coordinates).
            out_channels: Number of per-point output channels, passed to the
                backbone as ``num_classes``.
            embed_dim: Embedding width passed to the backbone.
            dropout: Dropout rate; only forwarded for ``'pvcnn'``.
            width_multiplier: Channel-width multiplier; only forwarded for
                ``'pvcnn'``.
            voxel_resolution_multiplier: Voxel-resolution multiplier; only
                forwarded for ``'pvcnn'``.

        Raises:
            NotImplementedError: If ``model_type`` is not supported.
            ValueError: If ``in_channels`` is smaller than 3, which would give
                a negative ``extra_feature_channels``.
        """
        super().__init__()
        if model_type not in self._SUPPORTED_MODEL_TYPES:
            raise NotImplementedError(
                f"Unsupported model_type {model_type!r}; expected one of "
                f"{self._SUPPORTED_MODEL_TYPES}"
            )
        if in_channels < 3:
            raise ValueError(
                f"in_channels must be >= 3 (got {in_channels}); "
                f"extra_feature_channels = in_channels - 3 would be negative"
            )
        self.model_type = model_type
        extra_feature_channels = in_channels - 3

        if model_type == 'pvcnn':
            # NOTE(review): CUDA autocast with target dtype float32 — presumably
            # keeps the PVCNN ops in full precision when forward() runs inside an
            # outer mixed-precision autocast region; confirm.
            self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
            self.model = PVCNN2(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=extra_feature_channels,
                dropout=dropout,
                width_multiplier=width_multiplier,
                voxel_resolution_multiplier=voxel_resolution_multiplier,
            )
            # Near-zero init of the last classifier layer, presumably so the
            # network's initial prediction starts close to zero.
            self._init_near_zero(self.model.classifier[-1])
        elif model_type == 'pvcnnplusplus':
            self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
            # dropout, width_multiplier and voxel_resolution_multiplier are not
            # forwarded to this backbone.
            self.model = PVCNN2PlusPlus(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=extra_feature_channels,
            )
            self._init_near_zero(self.model.output_projection[-1])
        else:  # 'simple' (already validated above)
            # No autocast override for the simple backbone.
            self.autocast_context = nullcontext()
            self.model = SimplePointModel(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=extra_feature_channels,
            )
            self._init_near_zero(self.model.output_projection)

    def forward(self, inputs: Tensor, t: Tensor, ret_feats=False) -> Tensor:
        """Run the backbone on a channels-last point cloud.

        Args:
            inputs: Tensor of shape ``(B, N, in_channels)``.
            t: Passed to the backbone unchanged (presumably the diffusion
                timestep — confirm against the backbone).
            ret_feats: If True, also return the backbone's features.

        Returns:
            A tensor of shape ``(B, N, out_channels)``. If ``ret_feats`` is
            True, a tuple ``(pred, feats)`` instead, where ``feats`` is
            returned exactly as the backbone produced it (not transposed).
        """
        # NOTE(review): the same context object from __init__ is re-entered on
        # every call; confirm that is safe if forward() runs concurrently
        # (e.g. DataParallel replicas share it).
        with self.autocast_context:
            # Backbones take channels-first input: (B, N, C) -> (B, C, N).
            channels_first = inputs.transpose(1, 2)
            if ret_feats:
                pred, feats = self.model(channels_first, t, ret_feats=True)
                return pred.transpose(1, 2), feats
            return self.model(channels_first, t, ret_feats=False).transpose(1, 2)