| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import sys |
| | import os |
| | from external.pointnet2.pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule |
| | from .utils import zero_module |
| | from .Positional_Embedding import PositionalEmbedding |
| |
|
class Pointnet2Encoder(nn.Module):
    r"""
    Four-stage point cloud encoder built from ``PointnetSAModuleVotes`` layers.

    Each stage consumes the (xyz, features) pair produced by the previous one;
    the per-stage outputs are collected in a dictionary by :meth:`forward`.

    Parameters
    ----------
    input_feature_dim: int
        Number of feature channels that follow the xyz coordinates of each
        point. Used as the first entry of the first stage's ``mlp`` spec.
    npoints: sequence of 4 int
        ``npoint`` argument for stages 1 to 4.
    radius: sequence of 4 float
        ``radius`` argument for stages 1 to 4.
    nsample: sequence of 4 int
        ``nsample`` argument for stages 1 to 4.
    """

    # Defaults are tuples rather than lists so that no mutable object is
    # shared between calls; they are only ever indexed, so lists passed by
    # existing callers keep working.
    def __init__(self, input_feature_dim=0, npoints=(2048, 1024, 512, 256),
                 radius=(0.2, 0.4, 0.6, 1.2), nsample=(64, 32, 16, 8)):
        super().__init__()
        self.sa1 = PointnetSAModuleVotes(
            npoint=npoints[0],
            radius=radius[0],
            nsample=nsample[0],
            mlp=[input_feature_dim, 64, 64, 128],
            use_xyz=True,
            normalize_xyz=True
        )

        self.sa2 = PointnetSAModuleVotes(
            npoint=npoints[1],
            radius=radius[1],
            nsample=nsample[1],
            mlp=[128, 128, 128, 256],
            use_xyz=True,
            normalize_xyz=True
        )

        self.sa3 = PointnetSAModuleVotes(
            npoint=npoints[2],
            radius=radius[2],
            nsample=nsample[2],
            mlp=[256, 256, 256, 512],
            use_xyz=True,
            normalize_xyz=True
        )

        self.sa4 = PointnetSAModuleVotes(
            npoint=npoints[3],
            radius=radius[3],
            nsample=nsample[3],
            mlp=[512, 512, 512, 512],
            use_xyz=True,
            normalize_xyz=True
        )

    def _break_up_pc(self, pc):
        r"""
        Split a point cloud tensor into coordinates and optional features.

        Parameters
        ----------
        pc: torch.Tensor
            Tensor whose last axis holds (x, y, z, features...).
            NOTE(review): the ``transpose(1, 2)`` implies a 3-D
            (B, N, 3 + C) layout -- confirm against callers.

        Returns
        ----------
        xyz: torch.Tensor
            Contiguous copy of the first three channels of the last axis.
        features: torch.Tensor or None
            Remaining channels with axes 1 and 2 swapped, or ``None`` when
            the last axis has no more than three channels.
        """
        xyz = pc[..., 0:3].contiguous()
        features = (
            pc[..., 3:].transpose(1, 2).contiguous()
            if pc.size(-1) > 3 else None
        )

        return xyz, features

    def forward(self, pointcloud, end_points=None):
        r"""
        Run the four set-abstraction stages and collect their outputs.

        Parameters
        ----------
        pointcloud: torch.Tensor
            Point cloud formatted as (x, y, z, features...) on the last axis.
        end_points: dict, optional
            Dictionary to write the results into. A new dictionary is
            created only when this is ``None``; any dictionary that is
            passed in (including an empty one) is updated in place.

        Returns
        ----------
        end_points: dict
            Contains ``org_xyz`` plus ``saK_xyz`` and ``saK_features`` for
            K = 1..4.
        """
        # ``is None`` instead of a truthiness test: an empty dict handed in
        # by the caller must be filled in place, not silently replaced.
        if end_points is None:
            end_points = {}

        xyz, features = self._break_up_pc(pointcloud)

        end_points['org_xyz'] = xyz

        # Each stage returns three values; the third one is not used here.
        xyz1, features1, _ = self.sa1(xyz, features)
        end_points['sa1_xyz'] = xyz1
        end_points['sa1_features'] = features1

        xyz2, features2, _ = self.sa2(xyz1, features1)
        end_points['sa2_xyz'] = xyz2
        end_points['sa2_features'] = features2

        xyz3, features3, _ = self.sa3(xyz2, features2)
        end_points['sa3_xyz'] = xyz3
        end_points['sa3_features'] = features3

        xyz4, features4, _ = self.sa4(xyz3, features3)
        end_points['sa4_xyz'] = xyz4
        end_points['sa4_features'] = features4

        return end_points
| |
|
| |
|
| |
|
class PointUNet(nn.Module):
    r"""
    U-shaped point network that maps a noisy point cloud, a timestep and a
    conditioning point cloud to a 3-channel per-point output.

    The noisy and the conditioning clouds are encoded by two separate
    ``Pointnet2Encoder`` instances (default arguments). The decoder then
    alternates two kinds of ``PointnetFPModule``:

    * ``fpK_cross`` combines the running features with the conditioning
      encoder's features of the same level;
    * ``fpK`` propagates the result to the next finer level of the noisy
      encoder, using that level's features as the skip input.

    The timestep is embedded with ``PositionalEmbedding(256)`` followed by
    two linear layers (256 -> 512 -> 512) with SiLU activations and is added
    to the deepest noisy features.
    """

    def __init__(self):
        super().__init__()

        # Two independent encoders: one for the noisy input, one for the
        # condition.
        self.noisy_encoder = Pointnet2Encoder()
        self.cond_encoder = Pointnet2Encoder()

        # Decoder level 4 -> 3.
        self.fp1_cross = PointnetFPModule(mlp=[512 + 512, 512, 512])
        self.fp1 = PointnetFPModule(mlp=[512 + 512, 512, 512])

        # Decoder level 3 -> 2.
        self.fp2_cross = PointnetFPModule(mlp=[512 + 512, 512, 256])
        self.fp2 = PointnetFPModule(mlp=[256 + 256, 512, 256])

        # Decoder level 2 -> 1.
        self.fp3_cross = PointnetFPModule(mlp=[256 + 256, 256, 128])
        self.fp3 = PointnetFPModule(mlp=[128 + 128, 256, 128])

        # Decoder level 1 -> original points (no skip features at the end).
        self.fp4_cross = PointnetFPModule(mlp=[128 + 128, 128, 128])
        self.fp4 = PointnetFPModule(mlp=[128, 128, 128])

        # Output head: normalise the 128 channels, then project to 3.
        # NOTE(review): zero_module presumably zero-initialises the linear
        # layer so the network starts by predicting zeros -- confirm in
        # .utils.
        head_norm = nn.LayerNorm(128)
        head_proj = zero_module(
            nn.Linear(in_features=128, out_features=3, bias=False)
        )
        self.output_layer = nn.Sequential(head_norm, head_proj)

        # Timestep embedding and its two-layer mapping network.
        self.t_emb_layer = PositionalEmbedding(256)
        self.map_layer0 = nn.Linear(in_features=256, out_features=512)
        self.map_layer1 = nn.Linear(in_features=512, out_features=512)

    def forward(self, noise_points, t, cond_points):
        r"""
        Forward pass of the network

        Parameters
        ----------
        noise_points: torch.Tensor
            Noisy point cloud, formatted as (x, y, z, features...) on the
            last axis; fed to ``noisy_encoder``.
        t: torch.Tensor
            Timestep input passed to ``t_emb_layer``.
            NOTE(review): presumably one value per batch element -- confirm
            against PositionalEmbedding.
        cond_points: torch.Tensor
            Conditioning point cloud in the same format as ``noise_points``;
            fed to ``cond_encoder``.

        Returns
        ----------
        output_points: torch.Tensor
            Output of ``output_layer`` applied to the decoded features after
            swapping axes 1 and 2; the last axis has 3 channels.
            NOTE(review): expected shape (B, N, 3) with N the number of
            input points -- confirm.
        """
        # Timestep code, with a trailing axis so it broadcasts over points.
        time_code = self.t_emb_layer(t)
        time_code = self.map_layer0(time_code)
        time_code = F.silu(time_code)
        time_code = self.map_layer1(time_code)
        time_code = F.silu(time_code)
        time_code = time_code.unsqueeze(2)

        noisy = self.noisy_encoder(noise_points)
        cond = self.cond_encoder(cond_points)

        # Level 4: inject the timestep, mix in the condition, move to level 3.
        deepest = noisy['sa4_features'] + time_code
        fused = self.fp1_cross(
            noisy['sa4_xyz'], cond['sa4_xyz'], deepest, cond['sa4_features']
        )
        fused = self.fp1(
            noisy['sa3_xyz'], noisy['sa4_xyz'], noisy['sa3_features'], fused
        )

        # Level 3: mix in the condition, move to level 2.
        fused = self.fp2_cross(
            noisy['sa3_xyz'], cond['sa3_xyz'], fused, cond['sa3_features']
        )
        fused = self.fp2(
            noisy['sa2_xyz'], noisy['sa3_xyz'], noisy['sa2_features'], fused
        )

        # Level 2: mix in the condition, move to level 1.
        fused = self.fp3_cross(
            noisy['sa2_xyz'], cond['sa2_xyz'], fused, cond['sa2_features']
        )
        fused = self.fp3(
            noisy['sa1_xyz'], noisy['sa2_xyz'], noisy['sa1_features'], fused
        )

        # Level 1: mix in the condition, then propagate to the original
        # points, where there are no skip features (None).
        fused = self.fp4_cross(
            noisy['sa1_xyz'], cond['sa1_xyz'], fused, cond['sa1_features']
        )
        fused = self.fp4(noisy['org_xyz'], noisy['sa1_xyz'], None, fused)

        # Channels go last so LayerNorm/Linear act on the 128 features.
        per_point = fused.transpose(1, 2)
        return self.output_layer(per_point)
| |
|
| |
|
if __name__ == '__main__':
    # Smoke test: build the network on the GPU and push one random batch
    # through it to check that the shapes line up.
    net = PointUNet().cuda().float()
    net = net.eval()
    noise_points = torch.randn(16, 4096, 3).cuda().float()
    cond_points = torch.randn(16, 4096, 3).cuda().float()
    t = torch.randn(16).cuda().float()

    # PointUNet.forward takes (noise_points, t, cond_points); the timestep
    # must be passed, otherwise the call fails with a missing-argument
    # TypeError.
    out = net(noise_points, t, cond_points)
    print(out.shape)