Spaces:
Runtime error
Runtime error
File size: 3,498 Bytes
f3ff4f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gsplat as gs
import numpy as np
import torch
import torch.nn.functional as F
from easydict import EasyDict as edict
class GSplatRenderer:
    """Render Gaussian-splat objects through ``gsplat.rasterization``.

    Args:
        rendering_options: Optional mapping of overrides merged into the
            default options ``resolution``, ``near``, ``far``, ``ssaa`` and
            ``bg_color``. ``None`` (the default) keeps all defaults.
    """

    def __init__(self, rendering_options=None) -> None:
        # NOTE(review): ``pipe`` is not read by ``render`` below; kept so that
        # callers elsewhere that access ``renderer.pipe`` keep working.
        self.pipe = edict({
            "kernel_size": 0.1,
            "convert_SHs_python": False,
            "compute_cov3D_python": False,
            "scale_modifier": 1.0,
            "debug": False,
            "use_mip_gaussian": True
        })
        self.rendering_options = edict({
            "resolution": None,
            "near": None,
            "far": None,
            "ssaa": 1,
            "bg_color": 'random',
        })
        # A ``None`` default (instead of ``{}``) avoids a shared mutable default.
        if rendering_options is not None:
            self.rendering_options.update(rendering_options)
        # Background tensor used by the most recent ``render`` call.
        self.bg_color = None

    def _background_color(self, device) -> torch.Tensor:
        """Return the float32 background color tensor for one render call.

        ``'random'`` yields all zeros or all ones (3 channels) with equal
        probability; any other option value is converted with ``torch.tensor``.
        """
        option = self.rendering_options["bg_color"]
        # Guard with isinstance: comparing an array-valued option to a string
        # would be elementwise and make the truth test ambiguous.
        if isinstance(option, str) and option == 'random':
            color = torch.zeros(3, dtype=torch.float32, device=device)
            if np.random.rand() < 0.5:
                color += 1
            return color
        return torch.tensor(option, dtype=torch.float32, device=device)

    @staticmethod
    def _scale_intrinsics(intrinsics: torch.Tensor, width: int, height: int) -> torch.Tensor:
        """Scale normalized intrinsics to pixel units and add a leading batch dim.

        fx and cx are multiplied by ``width``; fy and cy by ``height``. The
        input is cloned, not modified. Assumes a 3x3 matrix — TODO confirm.
        """
        scaled = intrinsics.clone()
        scaled[0, 0] *= width
        scaled[1, 1] *= height
        scaled[0, 2] *= width
        scaled[1, 2] *= height
        return scaled.unsqueeze(0)

    def render(
        self,
        gaussian,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        colors_overwrite: torch.Tensor = None
    ) -> edict:
        """Rasterize ``gaussian`` from a single camera.

        Args:
            gaussian: Object exposing ``get_xyz``, ``get_rotation``,
                ``get_scaling``, ``get_opacity`` and ``get_features``.
            extrinsics: Camera view matrix passed to gsplat as ``viewmats``
                (presumably 4x4 world-to-camera — confirm against callers).
            intrinsics: Intrinsics whose fx/cx/fy/cy are expressed relative to
                the image size (they are multiplied by width/height here).
            colors_overwrite: Optional per-Gaussian colors used instead of
                ``sigmoid(gaussian.get_features)``.

        Returns:
            ``edict`` with key ``'color'``: a channels-first image of size
            ``resolution x resolution``.

        Raises:
            ValueError: if ``rendering_options['resolution']`` is not set.
        """
        resolution = self.rendering_options["resolution"]
        ssaa = self.rendering_options["ssaa"]
        if resolution is None:
            raise ValueError("rendering_options['resolution'] must be set before rendering")
        # Render at the supersampled size; downsample at the end if ssaa > 1.
        height = width = resolution * ssaa

        means = gaussian.get_xyz
        # Resolve the background once per call, on the same device as the
        # Gaussians so every tensor handed to gsplat is co-located.
        self.bg_color = self._background_color(means.device)
        Ks_scaled = self._scale_intrinsics(intrinsics, width, height)

        if colors_overwrite is not None:
            colors = colors_overwrite.unsqueeze(0)
        else:
            colors = torch.sigmoid(gaussian.get_features.squeeze(1)).unsqueeze(0)

        # NOTE(review): rendering_options["near"] / ["far"] are not applied;
        # these fixed clip planes are used regardless of those options.
        render_colors, _alphas, _meta = gs.rasterization(
            means=means,
            quats=F.normalize(gaussian.get_rotation, dim=-1),
            # NOTE(review): scales are divided by the (normalized) fx entry of
            # ``intrinsics``; intent not evident from this file — confirm.
            scales=gaussian.get_scaling / intrinsics[0, 0],
            opacities=gaussian.get_opacity.squeeze(-1),
            colors=colors,
            viewmats=extrinsics.unsqueeze(0),
            Ks=Ks_scaled,
            width=width,
            height=height,
            near_plane=0.01,
            far_plane=1000.0,
            radius_clip=3.0,
            eps2d=0.3,
            render_mode="RGB",
            backgrounds=self.bg_color.unsqueeze(0),
            camera_model="pinhole",
        )
        # First camera, first three channels, moved to channels-first.
        rendered_image = render_colors[0, ..., 0:3].permute(2, 0, 1)

        if ssaa > 1:
            # Downsample back to the target size. Index [0] rather than
            # .squeeze() so a 1-pixel resolution keeps its spatial dims.
            rendered_image = F.interpolate(
                rendered_image[None],
                size=(resolution, resolution),
                mode='bilinear',
                align_corners=False,
                antialias=True,
            )[0]
        return edict({'color': rendered_image})