File size: 3,498 Bytes
f3ff4f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gsplat as gs
import numpy as np
import torch
import torch.nn.functional as F
from easydict import EasyDict as edict


class GSplatRenderer:
    """Render a Gaussian-splat model to an image via ``gsplat.rasterization``.

    Attributes:
        pipe: Pipeline flags stored on the instance. Nothing in this class
            reads them.
        rendering_options: Per-renderer settings. ``resolution`` (side length
            of the square output), ``ssaa`` (supersampling factor) and
            ``bg_color`` (``'random'`` or a value accepted by
            ``torch.tensor``) are used by :meth:`render`. ``near`` and ``far``
            are stored but not consulted by :meth:`render`.
        bg_color: Background tensor used by the most recent :meth:`render`
            call, or ``None`` before the first call.
    """

    def __init__(self, rendering_options=None) -> None:
        """Create a renderer with default settings, optionally overridden.

        Args:
            rendering_options: Mapping of option names to values that is
                merged over the defaults. ``None`` keeps every default.
        """
        self.pipe = edict({
            "kernel_size": 0.1,
            "convert_SHs_python": False,
            "compute_cov3D_python": False,
            "scale_modifier": 1.0,
            "debug": False,
            "use_mip_gaussian": True
        })
        self.rendering_options = edict({
            "resolution": None,
            "near": None,
            "far": None,
            "ssaa": 1,
            "bg_color": 'random',
        })
        # ``None`` is used as the default instead of ``{}`` so that no dict
        # object is shared between constructor calls.
        if rendering_options is not None:
            self.rendering_options.update(rendering_options)
        self.bg_color = None

    def _pick_background(self, device) -> torch.Tensor:
        """Build the background colour tensor for a single render call.

        Args:
            device: Device on which the tensor is created.

        Returns:
            A float32 tensor. For ``bg_color == 'random'`` it has three
            elements, all zeros or all ones with equal probability; otherwise
            it is ``torch.tensor`` applied to the configured value.
        """
        bg_option = self.rendering_options["bg_color"]
        # The isinstance guard keeps the comparison unambiguous when the
        # option holds an array-like value instead of a string.
        if isinstance(bg_option, str) and bg_option == 'random':
            fill = 1.0 if np.random.rand() < 0.5 else 0.0
            return torch.full((3,), fill, dtype=torch.float32, device=device)
        return torch.tensor(bg_option, dtype=torch.float32, device=device)

    def render(
            self,
            gaussian,
            extrinsics: torch.Tensor,
            intrinsics: torch.Tensor,
            colors_overwrite: torch.Tensor = None
    ) -> edict:
        """Rasterize ``gaussian`` from one camera and return the colour image.

        Args:
            gaussian: Object exposing ``get_xyz``, ``get_rotation``,
                ``get_scaling``, ``get_opacity`` and ``get_features``.
            extrinsics: Camera matrix; a leading batch dimension is added and
                it is passed to gsplat as ``viewmats``.
            intrinsics: Camera intrinsics indexed as a 2-D matrix. The focal
                and principal-point entries are multiplied by the render
                width/height, so they are presumably normalised by image
                size -- TODO confirm against the caller.
            colors_overwrite: Optional per-Gaussian colours used instead of
                ``sigmoid(gaussian.get_features)``.

        Returns:
            An ``edict`` with a single key ``'color'`` holding the rendered
            image in channel-first layout. When ``ssaa > 1`` the image is
            rendered at ``resolution * ssaa`` and resized down to
            ``resolution`` x ``resolution``.

        Raises:
            ValueError: If the ``resolution`` rendering option is unset.
        """
        resolution = self.rendering_options["resolution"]
        if resolution is None:
            raise ValueError(
                "rendering_options['resolution'] must be set before render()"
            )
        ssaa = self.rendering_options["ssaa"]

        # The render target is square and enlarged by the supersampling factor.
        height = width = resolution * ssaa

        means = gaussian.get_xyz

        # The background is sampled once per call and created on the same
        # device as the Gaussians rather than on the default CUDA device.
        self.bg_color = self._pick_background(means.device)

        # Convert the intrinsics to pixel units on a copy so the caller's
        # tensor is not modified, then add the camera batch dimension.
        Ks_scaled = intrinsics.clone()
        Ks_scaled[0, 0] *= width
        Ks_scaled[1, 1] *= height
        Ks_scaled[0, 2] *= width
        Ks_scaled[1, 2] *= height
        Ks_scaled = Ks_scaled.unsqueeze(0)

        # NOTE(review): the clip planes are fixed here; the ``near``/``far``
        # rendering options are not used -- confirm this is intended.
        near_plane = 0.01
        far_plane = 1000.0

        if colors_overwrite is not None:
            colors = colors_overwrite.unsqueeze(0)
        else:
            colors = torch.sigmoid(gaussian.get_features.squeeze(1)).unsqueeze(0)

        # Rasterize with gsplat
        # NOTE(review): scales are divided by intrinsics[0, 0] (the unscaled
        # focal entry) -- confirm this normalisation is intended.
        render_colors, render_alphas, meta = gs.rasterization(
            means=means,
            quats=F.normalize(gaussian.get_rotation, dim=-1),
            scales=gaussian.get_scaling / intrinsics[0, 0],
            opacities=gaussian.get_opacity.squeeze(-1),
            colors=colors,
            viewmats=extrinsics.unsqueeze(0),
            Ks=Ks_scaled,
            width=width,
            height=height,
            near_plane=near_plane,
            far_plane=far_plane,
            radius_clip=3.0,
            eps2d=0.3,
            render_mode="RGB",
            backgrounds=self.bg_color.unsqueeze(0),
            camera_model="pinhole"
        )

        # Take the single camera's first three channels and move the channel
        # axis to the front.
        rendered_image = render_colors[0, ..., 0:3].permute(2, 0, 1)

        # Apply supersampling if needed
        if ssaa > 1:
            # squeeze(0) removes only the batch axis added for interpolate;
            # a bare squeeze() would also drop spatial axes of size 1.
            rendered_image = F.interpolate(
                rendered_image[None],
                size=(resolution, resolution),
                mode='bilinear',
                align_corners=False,
                antialias=True
            ).squeeze(0)

        return edict({'color': rendered_image})