import torch
import lightning.pytorch as pl
import torch.nn as nn
import numpy as np

from torch.utils.data import DataLoader

from third_party.PartField.partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat
from third_party.PartField.partfield.model.triplane import TriplaneTransformer, get_grid_coord
from third_party.PartField.partfield.model.model_utils import VanillaMLP
from third_party.PartField.partfield.dataloader import Demo_Dataset

class Model(pl.LightningModule):
    def __init__(self, cfg, obj_path):
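        """Predict-only PartField model: encodes a point cloud into triplane
        features and returns the part-feature planes.

        `cfg` fields read by this module: triplane_resolution,
        triplane_channels_low, triplane_channels_high, use_pvcnnonly,
        use_2d_feat, pvcnn, regress_2d_feat, n_sample_each,
        dataset.val_num_workers, dataset.val_batch_size.
        """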
        super().__init__()
        self.obj_path = obj_path

        self.save_hyperparameters()
        self.cfg = cfg
        self.automatic_optimization = False  # manual optimization; irrelevant for predict-only runs
        self.triplane_resolution = cfg.triplane_resolution
        self.triplane_channels_low = cfg.triplane_channels_low
        # Transformer that refines the low-resolution PVCNN triplane
        # (32^2 per plane) into a higher-resolution feature triplane (128^2).
        self.triplane_transformer = TriplaneTransformer(
            input_dim=cfg.triplane_channels_low * 2,
            transformer_dim=1024,
            transformer_layers=6,
            transformer_heads=8,
            triplane_low_res=32,
            triplane_high_res=128,
            triplane_dim=cfg.triplane_channels_high,
        )
        # Small MLP that decodes a 64-channel triplane feature into an SDF value.
        self.sdf_decoder = VanillaMLP(input_dim=64,
                                      output_dim=1,
                                      out_activation="tanh",
                                      n_neurons=64,
                                      n_hidden_layers=6)
        self.use_pvcnn = cfg.use_pvcnnonly
        self.use_2d_feat = cfg.use_2d_feat
        if self.use_pvcnn:
            # Point-cloud encoder that produces the initial low-res triplane;
            # shape_min/shape_length assume inputs normalized to [-1, 1].
            self.pvcnn = TriPlanePC2Encoder(
                cfg.pvcnn,
                device="cuda",
                shape_min=-1,
                shape_length=2,
                use_2d_feat=self.use_2d_feat)
        # nn.Parameter sets requires_grad=True by default.
        self.logit_scale = nn.Parameter(torch.tensor([1.0]))
        self.grid_coord = get_grid_coord(256)
        self.mse_loss = torch.nn.MSELoss()
        self.l1_loss = torch.nn.L1Loss(reduction='none')
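        # The MSE/L1 losses above are defined for training; the predict-only
        # path in this file does not use them.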

        if cfg.regress_2d_feat:
            # Optional decoder that regresses 192-dim 2D features from
            # sampled triplane features.
            self.feat_decoder = VanillaMLP(input_dim=64,
                                           output_dim=192,
                                           out_activation="GELU",
                                           n_neurons=64,
                                           n_hidden_layers=6)
        
    def predict_dataloader(self):
        dataset = Demo_Dataset(self.obj_path)
        dataloader = DataLoader(dataset,
                                num_workers=self.cfg.dataset.val_num_workers,
                                batch_size=self.cfg.dataset.val_batch_size,
                                shuffle=False,
                                pin_memory=True,
                                drop_last=False)
        return dataloader

    @torch.no_grad()
    def predict_step(self, batch, batch_idx):
        N = batch['pc'].shape[0]
        assert N == 1, "demo inference assumes batch size 1"

        # Encode the point cloud into a low-resolution triplane; the point
        # positions are passed a second time as the per-point features.
        pc_feat = self.pvcnn(batch['pc'], batch['pc'])

        # Refine to the high-resolution triplane, then split the channel dim
        # into the 64-channel SDF planes and the remaining part-feature planes.
        planes = self.triplane_transformer(pc_feat)
        sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)
        
        def sample_points(vertices, faces, n_point_per_face):
            # Uniformly sample points on each triangle via random barycentric
            # coordinates; the sqrt on u makes the samples area-uniform.
            # Borrowed from Kaolin:
            # https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/mesh/trianglemesh.py#L43
            n_f = faces.shape[0]
            u = torch.sqrt(torch.rand((n_f, n_point_per_face, 1),
                                      device=vertices.device,
                                      dtype=vertices.dtype))
            v = torch.rand((n_f, n_point_per_face, 1),
                           device=vertices.device,
                           dtype=vertices.dtype)
            w0 = 1 - u
            w1 = u * (1 - v)
            w2 = u * v

            # Gather the three vertices of every face and blend them with the
            # barycentric weights.
            face_v_0 = torch.index_select(vertices, 0, faces[:, 0].reshape(-1))
            face_v_1 = torch.index_select(vertices, 0, faces[:, 1].reshape(-1))
            face_v_2 = torch.index_select(vertices, 0, faces[:, 2].reshape(-1))
            points = w0 * face_v_0.unsqueeze(dim=1) \
                   + w1 * face_v_1.unsqueeze(dim=1) \
                   + w2 * face_v_2.unsqueeze(dim=1)
            return points  # (n_f, n_point_per_face, 3)

        def sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face):
            # Query triplane features in chunks to avoid OOM, then average the
            # features of the n_point_per_face samples belonging to each face.
            n_sample_each = self.cfg.n_sample_each  # chunk size in points
            n_v = tensor_vertices.shape[1]
            n_sample = (n_v + n_sample_each - 1) // n_sample_each  # ceil division avoids an empty trailing chunk
            all_sample = []
            for i_sample in range(n_sample):
                chunk = tensor_vertices[:, i_sample * n_sample_each:(i_sample + 1) * n_sample_each]
                sampled_feature = sample_triplane_feat(part_planes, chunk)
                assert sampled_feature.shape[1] % n_point_per_face == 0
                sampled_feature = sampled_feature.reshape(1, -1, n_point_per_face, sampled_feature.shape[-1])
                sampled_feature = torch.mean(sampled_feature, dim=-2)
                all_sample.append(sampled_feature)
            return torch.cat(all_sample, dim=1)

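        # How the two helpers above fit together (a sketch under assumed
        # shapes, not executed here): per-face part features could be pooled as
        #   pts = sample_points(verts, faces, n_point_per_face=8)    # (F, 8, 3)
        #   feats = sample_and_mean_memory_save_version(
        #       part_planes, pts.reshape(1, -1, 3), 8)               # (1, F, C)
        # where `verts`/`faces` are hypothetical mesh tensors on the same device.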
        part_planes = part_planes.cpu().numpy()
        return part_planes, batch['uid'][0]
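

# --- Usage sketch (an illustration, not part of the original module) ---
# Drives the module through Lightning's predict loop. The OmegaConf-based
# config loading and both paths are assumptions; PartField's real entry
# point builds `cfg` through its own config system.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/demo.yaml")        # hypothetical config path
    model = Model(cfg, obj_path="data/example.obj")  # hypothetical mesh path
    trainer = pl.Trainer(accelerator="gpu", devices=1, logger=False)
    # predict() runs predict_dataloader()/predict_step() defined above;
    # each output is a (part_planes, uid) tuple.
    for part_planes, uid in trainer.predict(model):
        np.save(f"{uid}_part_planes.npy", part_planes)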