Jiangsan123 committed
Commit 2098a77 · 0 parent(s)

Reinitialize clean repo without large files

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +37 -0
  2. README.md +13 -0
  3. app.py +497 -0
  4. requirements.txt +14 -0
  5. src/__pycache__/cameras.cpython-39.pyc +0 -0
  6. src/__pycache__/config.cpython-38.pyc +0 -0
  7. src/__pycache__/config.cpython-39.pyc +0 -0
  8. src/__pycache__/sparse_voxel_model.cpython-39.pyc +0 -0
  9. src/cameras.py +287 -0
  10. src/config.py +230 -0
  11. src/config_old.py +230 -0
  12. src/dataloader/__pycache__/data_pack.cpython-39.pyc +0 -0
  13. src/dataloader/__pycache__/reader_colmap_dataset.cpython-39.pyc +0 -0
  14. src/dataloader/__pycache__/reader_nerf_dataset.cpython-39.pyc +0 -0
  15. src/dataloader/data_pack.py +232 -0
  16. src/dataloader/reader_colmap_dataset.py +162 -0
  17. src/dataloader/reader_colmap_dataset_or.py +148 -0
  18. src/dataloader/reader_nerf_dataset.py +180 -0
  19. src/dataloader/reader_nerf_dataset_copy.py +170 -0
  20. src/sparse_voxel_gears/__pycache__/adaptive.cpython-39.pyc +0 -0
  21. src/sparse_voxel_gears/__pycache__/constructor.cpython-39.pyc +0 -0
  22. src/sparse_voxel_gears/__pycache__/io.cpython-39.pyc +0 -0
  23. src/sparse_voxel_gears/__pycache__/pooling.cpython-39.pyc +0 -0
  24. src/sparse_voxel_gears/__pycache__/properties.cpython-39.pyc +0 -0
  25. src/sparse_voxel_gears/__pycache__/renderer.cpython-39.pyc +0 -0
  26. src/sparse_voxel_gears/adaptive.py +296 -0
  27. src/sparse_voxel_gears/constructor.py +425 -0
  28. src/sparse_voxel_gears/io.py +156 -0
  29. src/sparse_voxel_gears/pooling.py +68 -0
  30. src/sparse_voxel_gears/properties.py +146 -0
  31. src/sparse_voxel_gears/renderer.py +178 -0
  32. src/sparse_voxel_gears/renderer_copy.py +178 -0
  33. src/sparse_voxel_model.py +67 -0
  34. src/sparse_voxel_model_copy.py +67 -0
  35. src/utils/__pycache__/activation_utils.cpython-39.pyc +0 -0
  36. src/utils/__pycache__/bounding_utils.cpython-39.pyc +0 -0
  37. src/utils/__pycache__/camera_utils.cpython-39.pyc +0 -0
  38. src/utils/__pycache__/colmap_utils.cpython-39.pyc +0 -0
  39. src/utils/__pycache__/fuser_utils.cpython-39.pyc +0 -0
  40. src/utils/__pycache__/image_utils.cpython-39.pyc +0 -0
  41. src/utils/__pycache__/loss_utils.cpython-39.pyc +0 -0
  42. src/utils/__pycache__/marching_cubes_utils.cpython-39.pyc +0 -0
  43. src/utils/__pycache__/mono_utils.cpython-39.pyc +0 -0
  44. src/utils/__pycache__/octree_utils.cpython-39.pyc +0 -0
  45. src/utils/__pycache__/system_utils.cpython-39.pyc +0 -0
  46. src/utils/activation_utils.py +49 -0
  47. src/utils/bounding_utils.py +102 -0
  48. src/utils/camera_utils.py +79 -0
  49. src/utils/colmap_utils.py +62 -0
  50. src/utils/fuser_utils.py +185 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.ply filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Beetle Viz
+ emoji: 😻
+ colorFrom: red
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 5.49.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
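The front matter above is what makes the Space start `app.py` as its entry point. A minimal local-launch sketch, assuming the model folder referenced in `app.py` and the packages from `requirements.txt` are installed (the viewer then serves viser on port 7860, as set in the script's `__main__` block):

```python
# Hedged sketch: run app.py the way the Space does, i.e. as the __main__ module.
# Assumes the model directory hard-coded in app.py and all requirements are available.
import runpy

runpy.run_path("app.py", run_name="__main__")  # blocks inside the viewer's update loop
```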
app.py ADDED
@@ -0,0 +1,497 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Mon Oct 6 10:16:31 2025
+
+ @author: nibio
+ """
+
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+ import os, time
+ import numpy as np
+ import imageio.v3 as iio
+ from scipy.spatial.transform import Rotation
+ from typing import Optional
+
+ import torch
+
+ from src.config import cfg, update_config
+ from src.dataloader.data_pack import DataPack
+ from src.sparse_voxel_model import SparseVoxelModel
+ from src.utils.image_utils import im_tensor2np, viz_tensordepth
+ from src.cameras import MiniCam
+
+ import viser
+ import viser.transforms as tf
+
+
+ def matrix2wxyz(R: np.ndarray) -> np.ndarray:
+     return Rotation.from_matrix(R).as_quat()[[3, 0, 1, 2]]
+
+
+ def wxyz2matrix(wxyz: np.ndarray) -> np.ndarray:
+     return Rotation.from_quat(wxyz[[1, 2, 3, 0]]).as_matrix()
+
+
+ class SVRasterViewer:
+     def __init__(self, cfg):
+
+         # ---------- Data & model ----------
+         data_pack = DataPack(
+             source_path=cfg.data.source_path,
+             image_dir_name=cfg.data.image_dir_name,
+             res_downscale=cfg.data.res_downscale,
+             res_width=cfg.data.res_width,
+             skip_blend_alpha=cfg.data.skip_blend_alpha,
+             alpha_is_white=cfg.model.white_background,
+             data_device=cfg.data.data_device,
+             use_test=cfg.data.eval,
+             test_every=cfg.data.test_every,
+             camera_params_only=True,
+         )
+         self.tr_cam_lst = data_pack.get_train_cameras()
+         self.te_cam_lst = data_pack.get_test_cameras()
+
+         self.scene_center = (
+             np.mean([c.c2w[:3, 3].cpu().numpy() for c in self.tr_cam_lst], axis=0)
+             if len(self.tr_cam_lst)
+             else np.zeros(3, dtype=np.float32)
+         )
+
+         self.voxel_model = SparseVoxelModel(
+             n_samp_per_vox=cfg.model.n_samp_per_vox,
+             sh_degree=cfg.model.sh_degree,
+             ss=cfg.model.ss,
+             white_background=cfg.model.white_background,
+             black_background=cfg.model.black_background,
+         )
+         self.voxel_model.load_iteration(args.model_path, args.iteration)  # args from __main__
+         self.voxel_model.freeze_vox_geo()
+
+         # ---------- UI ----------
+         self.server = viser.ViserServer(port=cfg.port)
+         self.is_connected = False
+
+         self.server.gui.set_panel_label("SVRaster viser")
+         self.server.gui.add_markdown(
+             "**View control:**\n- Mouse drag + scroll\n- WASD + QE keys"
+         )
+         self.fps = self.server.gui.add_text("Rendering FPS", initial_value="-1", disabled=True)
+
+         self.active_sh_degree_slider = self.server.gui.add_slider(
+             "active_sh_degree", min=0, max=self.voxel_model.max_sh_degree, step=1,
+             initial_value=self.voxel_model.active_sh_degree
+         )
+         self.ss_slider = self.server.gui.add_slider("ss", min=0.5, max=2.0, step=0.05, initial_value=self.voxel_model.ss)
+         self.width_slider = self.server.gui.add_slider("width", min=64, max=2048, step=8, initial_value=1024)
+         self.fovx_slider = self.server.gui.add_slider("fovx", min=10, max=150, step=1, initial_value=70)
+         self.near_slider = self.server.gui.add_slider("near", min=0.02, max=10, step=0.01, initial_value=0.2)
+
+         self.render_dropdown = self.server.gui.add_dropdown(
+             "render mode", options=["all", "rgb only", "depth only", "normal only"], initial_value="all"
+         )
+         self.output_dropdown = self.server.gui.add_dropdown(
+             "output", options=["rgb", "alpha", "dmean", "dmed", "dmean2n", "dmed2n", "n"], initial_value="rgb"
+         )
+
+         # ---- Focus & crop controls ----
+         self.alpha_thr_slider = self.server.gui.add_slider(
+             "alpha_threshold", min=0.0, max=0.95, step=0.01, initial_value=0.35
+         )
+         self.keep_closest_slider = self.server.gui.add_slider(
+             "keep_closest_pct", min=0.2, max=1.0, step=0.05, initial_value=0.6
+         )
+         self.hide_outside_checkbox = self.server.gui.add_checkbox(
+             "hide_outside_focus", initial_value=False
+         )
+
+         self.center_btn = self.server.gui.add_button("Center on object")
+         self.reset_btn = self.server.gui.add_button("Reset to first view")
+         self.autoframe_btn = self.server.gui.add_button("Auto-frame (depth)")
+         self.focus_btn = self.server.gui.add_button("Focus foreground")
+         self.rebase_btn = self.server.gui.add_button("Recenter world to focus")
+
+         # ---- State for world rebase / focus mask ----
+         self.world_offset = np.zeros(3, dtype=np.float32)  # world translation applied during render
+         self.focus_center: Optional[np.ndarray] = None
+
+         # ---------- Camera frusta ----------
+         self.tr_frust, self.te_frust = [], []
+
+         def add_frustum(name, cam, color):
+             c2w = cam.c2w.cpu().numpy()
+             frame = self.server.scene.add_camera_frustum(
+                 name,
+                 fov=cam.fovy,
+                 aspect=cam.image_width / cam.image_height,
+                 scale=0.10,
+                 wxyz=matrix2wxyz(c2w[:3, :3]),
+                 position=c2w[:3, 3],
+                 color=color,
+                 visible=False,
+             )
+
+             @frame.on_click
+             def _(event: viser.SceneNodePointerEvent):
+                 client = event.client
+                 with client.atomic():
+                     client.camera.wxyz = event.target.wxyz
+                     client.camera.position = event.target.position
+                 self._camera_lookat(client, self.scene_center)
+
+             return frame
+
+         for i, cam in enumerate(self.tr_cam_lst):
+             self.tr_frust.append(add_frustum(f"/frustum/train/{i:04d}", cam, [0.0, 1.0, 0.0]))
+         for i, cam in enumerate(self.te_cam_lst):
+             self.te_frust.append(add_frustum(f"/frustum/test/{i:04d}", cam, [1.0, 0.0, 0.0]))
+
+         self.show_cam_dropdown = self.server.gui.add_dropdown(
+             "show cameras", options=["none", "train", "test", "all"], initial_value="none"
+         )
+
+         @self.show_cam_dropdown.on_update
+         def _(_):
+             for f in self.tr_frust:
+                 f.visible = self.show_cam_dropdown.value in ["train", "all"]
+             for f in self.te_frust:
+                 f.visible = self.show_cam_dropdown.value in ["test", "all"]
+
+         # ---------- Button handlers ----------
+         @self.center_btn.on_click
+         def _(event: viser.GuiEvent):
+             if event.client:
+                 self._camera_lookat(event.client, self.scene_center)
+
+         @self.reset_btn.on_click
+         def _(event: viser.GuiEvent):
+             client = event.client
+             if not client:
+                 return
+             init = self.tr_cam_lst[0].c2w.cpu().numpy()
+             with client.atomic():
+                 client.camera.wxyz = matrix2wxyz(init[:3, :3])
+                 client.camera.position = init[:3, 3]
+             self._camera_lookat(client, self.scene_center)
+
+         @self.autoframe_btn.on_click
+         def _(event: viser.GuiEvent):
+             if event.client:
+                 self._auto_frame_by_depth(event.client)
+
+         @self.focus_btn.on_click
+         def _(event: viser.GuiEvent):
+             if event.client:
+                 self._focus_foreground(event.client)
+
+         @self.rebase_btn.on_click
+         def _(event: viser.GuiEvent):
+             client = event.client
+             if not client or self.focus_center is None:
+                 print("[rebase] Run 'Focus foreground' first.")
+                 return
+             delta = self.focus_center.astype(np.float32)
+             self.world_offset = self.world_offset + delta  # accumulate translation
+             with client.atomic():
+                 client.camera.position = (np.asarray(client.camera.position) - delta).astype(np.float32)
+             self.scene_center = np.zeros(3, dtype=np.float32)
+             print("[rebase] World recentered; new world_offset:", self.world_offset)
+
+         # ---------- On connect ----------
+         @self.server.on_client_connect
+         def _(client: viser.ClientHandle):
+             init = self.tr_cam_lst[0].c2w.cpu().numpy()
+             with client.atomic():
+                 client.camera.wxyz = matrix2wxyz(init[:3, :3])
+                 client.camera.position = init[:3, 3]
+             ok = self._auto_frame_by_depth(client, quiet=True)
+             if not ok:
+                 self._camera_lookat(client, self.scene_center)
+             self.is_connected = True
+
+         # ---------- Download ----------
+         self.download_button = self.server.gui.add_button("Download view")
+
+         @self.download_button.on_click
+         def _(event: viser.GuiEvent):
+             im, _ = self.render_viser_camera(event.client.camera)
+             event.client.send_file_download(
+                 "svraster_viser.png", iio.imwrite("<bytes>", im, extension=".png")
+             )
+
+     # ---------------- camera utils ----------------
+     def _camera_lookat(
+         self,
+         client: viser.ClientHandle,
+         target: np.ndarray,
+         distance: Optional[float] = None,
+     ):
+         """
+         Point the camera at `target` by writing orientation (wxyz) and position directly.
+         Compatible with Viser builds where camera.look_at is not callable.
+         """
+         target = np.asarray(target, dtype=np.float32)
+         eye = np.asarray(client.camera.position, dtype=np.float32)
+
+         vec = eye - target  # target -> eye
+         norm = np.linalg.norm(vec)
+         if not np.isfinite(norm) or norm < 1e-6:
+             vec = np.array([0, 0, 1.0], dtype=np.float32)
+             norm = 0.5
+
+         d = float(norm if distance is None else distance)
+
+         # Orthonormal basis that looks at the target.
+         fwd = -(vec / max(norm, 1e-6))  # camera forward (eye -> target)
+         up_guess = np.array([0, 1, 0], dtype=np.float32)
+         if abs(np.dot(fwd, up_guess)) > 0.99:
+             up_guess = np.array([1, 0, 0], dtype=np.float32)
+         right = np.cross(up_guess, fwd)
+         right /= max(np.linalg.norm(right), 1e-6)
+         up = np.cross(fwd, right)
+         up /= max(np.linalg.norm(up), 1e-6)
+
+         R = np.stack([right, up, fwd], axis=1).astype(np.float32)
+         new_pos = target - fwd * d
+
+         with client.atomic():
+             client.camera.wxyz = matrix2wxyz(R)
+             client.camera.position = new_pos
+
+     def _auto_frame_by_depth(self, client: viser.ClientHandle, quiet: bool = False) -> bool:
+         """Render once, then use the center-pixel median depth to pick a good pivot."""
+         try:
+             _, _, depth_med = self.render_viser_camera(client.camera, return_depth=True)
+         except Exception as e:
+             if not quiet:
+                 print("[auto-frame] render error:", e)
+             return False
+
+         H, W = depth_med.shape
+         d = float(depth_med[H // 2, W // 2])
+         if not np.isfinite(d) or d <= 0:
+             if not quiet:
+                 print("[auto-frame] invalid depth; falling back")
+             return False
+
+         R = wxyz2matrix(client.camera.wxyz)
+         fwd = R @ np.array([0, 0, 1], dtype=np.float32)
+         target = np.asarray(client.camera.position, dtype=np.float32) + fwd * d
+         self._camera_lookat(client, target, distance=d)
+         if not quiet:
+             print("[auto-frame] success; depth =", d)
+         return True
+
+     # ----------- Focus only the foreground object -----------
+     def _focus_foreground(self, client: viser.ClientHandle):
+         """
+         Use alpha (1 - T) to mask the foreground, keep the closest depths,
+         back-project to world, compute a tight AABB, then center and fit the view.
+         Stores self.focus_center so 'Recenter world to focus' can use it.
+         """
+         try:
+             _, _, depth_med, T = self.render_viser_camera(client.camera, return_depth=True, return_T=True)
+         except Exception as e:
+             print("[focus] render error:", e)
+             return
+
+         alpha = 1.0 - T
+         thr = float(self.alpha_thr_slider.value)
+         mask = (alpha > thr) & np.isfinite(depth_med) & (depth_med > 0)
+
+         if mask.sum() < 50:
+             print("[focus] Not enough foreground; lower alpha_threshold or change view.")
+             return
+
+         # Keep only the closest K% of pixels to drop the outer ring
+         K = float(self.keep_closest_slider.value)
+         dvals = depth_med[mask]
+         q = np.quantile(dvals, K)
+         mask &= depth_med <= q
+         if mask.sum() < 50:
+             print("[focus] Too few pixels after depth filtering; raise keep_closest_pct.")
+             return
+
+         # Back-project masked pixels to world
+         width = int(self.width_slider.value)
+         aspect = max(1e-6, float(client.camera.aspect))
+         height = max(1, int(round(width / aspect)))
+         fovx = np.deg2rad(float(self.fovx_slider.value))
+         fovy = fovx * height / max(width, 1)
+
+         fx = width / (2.0 * np.tan(fovx * 0.5))
+         fy = height / (2.0 * np.tan(fovy * 0.5))
+         cx, cy = (width - 1) / 2.0, (height - 1) / 2.0
+
+         ys, xs = np.where(mask)
+         zs = depth_med[ys, xs].astype(np.float32)
+
+         x_cam = (xs - cx) / fx * zs
+         y_cam = (ys - cy) / fy * zs
+         z_cam = zs
+         P_cam = np.stack([x_cam, y_cam, z_cam], axis=0)  # (3, N)
+
+         R = wxyz2matrix(client.camera.wxyz)
+         t = np.asarray(client.camera.position, dtype=np.float32)[:, None]
+         # Apply the current world rebase so P_world matches what we render
+         t = (t - self.world_offset[:, None]).astype(np.float32)
+
+         P_world = (R @ P_cam) + t  # (3, N)
+
+         pmin = np.min(P_world, axis=1)
+         pmax = np.max(P_world, axis=1)
+         center = (pmin + pmax) * 0.5
+         extent = (pmax - pmin) * 0.5
+
+         # Save for rebase
+         self.focus_center = center.astype(np.float32)
+
+         # Choose a distance that fits the bbox into the view (larger FOV dimension)
+         fovx_deg = float(self.fovx_slider.value)
+         fovy_deg = fovx_deg * height / max(width, 1)
+         fov_rad = np.deg2rad(max(fovx_deg, fovy_deg))
+         radius = float(np.linalg.norm(extent, ord=np.inf))
+         dist = radius / np.tan(max(1e-4, fov_rad * 0.5)) * 1.25  # padding
+
+         # Update the logical scene center for orbiting, then go there
+         self.scene_center = center.astype(np.float32)
+         self._camera_lookat(client, self.scene_center, distance=dist)
+
+         print(f"[focus] bbox half-extent ~{extent}, distance {dist:.3f}")
+
+     # ---------------- rendering ----------------
+     @torch.no_grad()
+     def render_viser_camera(
+         self,
+         camera: viser.CameraHandle,
+         return_depth: bool = False,
+         return_T: bool = False,
+     ):
+         width = int(self.width_slider.value)
+         aspect = max(1e-6, float(camera.aspect))
+         height = max(1, int(round(width / aspect)))
+
+         fovx_deg = float(self.fovx_slider.value)
+         fovy_deg = fovx_deg * height / max(width, 1)
+         near = float(self.near_slider.value)
+
+         c2w = np.eye(4, dtype=np.float32)
+         c2w[:3, :3] = wxyz2matrix(camera.wxyz)
+         c2w[:3, 3] = camera.position
+         # Apply world rebase: moving the *world* by -offset is equivalent to moving the camera by -offset in world coords.
+         c2w[:3, 3] = c2w[:3, 3] - self.world_offset
+
+         minicam = MiniCam(
+             c2w, fovx=np.deg2rad(fovx_deg), fovy=np.deg2rad(fovy_deg),
+             width=width, height=height, near=near
+         )
+
+         self.voxel_model.active_sh_degree = int(self.active_sh_degree_slider.value)
+
+         render_opt = {
+             "ss": self.ss_slider.value,
+             "output_T": True,
+             "output_depth": True,
+             "output_normal": True,
+         }
+         if self.render_dropdown.value == "rgb only":
+             render_opt["output_depth"] = False
+             render_opt["output_normal"] = False
+         elif self.render_dropdown.value == "depth only":
+             render_opt["color_mode"] = "dontcare"
+             render_opt["output_normal"] = False
+         elif self.render_dropdown.value == "normal only":
+             render_opt["color_mode"] = "dontcare"
+             render_opt["output_depth"] = False
+
+         t0 = time.time()
+         try:
+             render_pkg = self.voxel_model.render(minicam, **render_opt)
+         except RuntimeError as e:
+             print("[render] RuntimeError:", e)
+             im = np.ones((height, width, 3), dtype=np.uint8) * 255
+             if return_depth and return_T:
+                 depth_med = np.full((height, width), np.nan, dtype=np.float32)
+                 T = np.ones((height, width), dtype=np.float32)
+                 return im, 0.0, depth_med, T
+             if return_depth:
+                 depth_med = np.full((height, width), np.nan, dtype=np.float32)
+                 return im, 0.0, depth_med
+             if return_T:
+                 T = np.ones((height, width), dtype=np.float32)
+                 return im, 0.0, T
+             return im, 0.0
+         torch.cuda.synchronize()
+         eps = time.time() - t0
+
+         # Choose the output image
+         if self.output_dropdown.value == "dmean":
+             im = viz_tensordepth(render_pkg["depth"][0])
+         elif self.output_dropdown.value == "dmed":
+             im = viz_tensordepth(render_pkg["depth"][2])
+         elif self.output_dropdown.value == "dmean2n":
+             im = im_tensor2np(minicam.depth2normal(render_pkg["depth"][0]) * 0.5 + 0.5)
+         elif self.output_dropdown.value == "dmed2n":
+             im = im_tensor2np(minicam.depth2normal(render_pkg["depth"][2]) * 0.5 + 0.5)
+         elif self.output_dropdown.value == "n":
+             im = im_tensor2np(render_pkg["normal"] * 0.5 + 0.5)
+         elif self.output_dropdown.value == "alpha":
+             im = im_tensor2np(1 - render_pkg["T"].repeat(3, 1, 1))
+         else:
+             im = im_tensor2np(render_pkg["color"])
+
+         depth_med = render_pkg["depth"][2].detach().cpu().numpy()
+         T = render_pkg["T"].detach().cpu().numpy()  # (H, W)
+
+         # Optional image-level masking to hide everything outside the focused object
+         if self.hide_outside_checkbox.value:
+             alpha = 1.0 - T
+             thr = float(self.alpha_thr_slider.value)
+             mask = (alpha > thr) & np.isfinite(depth_med) & (depth_med > 0)
+             if mask.any():
+                 K = float(self.keep_closest_slider.value)
+                 dvals = depth_med[mask]
+                 q = np.quantile(dvals, K)
+                 mask &= depth_med <= q
+             mask3 = np.repeat(mask[..., None], 3, axis=2)
+             bg = np.zeros_like(im)  # black background
+             im = np.where(mask3, im, bg)
+
+         del render_pkg
+
+         if return_depth and return_T:
+             return im, eps, depth_med, T
+         if return_depth:
+             return im, eps, depth_med
+         if return_T:
+             return im, eps, T
+         return im, eps
+
+     # ---------------- server tick ----------------
+     def update(self):
+         if not self.is_connected:
+             return
+         times = []
+         for client in self.server.get_clients().values():
+             im, eps = self.render_viser_camera(client.camera)
+             times.append(eps)
+             client.scene.set_background_image(im, format="jpeg")
+         if times:
+             self.fps.value = f"{round(1 / np.mean(times)):4d}"
+
+
+ if __name__ == "__main__":
+
+     class Args:
+         model_path = "Entimus_imperialis_out_model/2025-1008-1320-c3c8c5"
+         iteration = -1
+         port = 7860  # Hugging Face default port
+
+     args = Args()
+     print("[INFO] Launching SVRaster viewer on Hugging Face...")
+     print(f"[INFO] Model path: {args.model_path}")
+
+     update_config(os.path.join(args.model_path, "config.yaml"))
+     cfg.port = args.port
+
+     svraster_viewer = SVRasterViewer(cfg)
+
+     # Keep the process alive so Hugging Face doesn't stop it
+     while True:
+         svraster_viewer.update()
+         time.sleep(0.01)
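For reference, the `_focus_foreground` handler above derives pinhole intrinsics from the horizontal FOV (fx = W / (2 tan(fovx/2))) and lifts masked depth pixels into world space before fitting the view to their bounding box. A minimal standalone sketch of that back-projection step (the function name and arguments here are illustrative, not part of app.py):

```python
# Standalone sketch of the back-projection used by _focus_foreground (names are illustrative).
import numpy as np

def backproject_mask(depth, mask, fovx_deg, c2w):
    """Lift masked depth pixels to world space with FOV-derived pinhole intrinsics."""
    H, W = depth.shape
    fovx = np.deg2rad(fovx_deg)
    fovy = fovx * H / W                      # same aspect scaling as app.py
    fx = W / (2.0 * np.tan(fovx * 0.5))
    fy = H / (2.0 * np.tan(fovy * 0.5))
    cx, cy = (W - 1) / 2.0, (H - 1) / 2.0
    ys, xs = np.where(mask)
    z = depth[ys, xs]
    pts_cam = np.stack([(xs - cx) / fx * z, (ys - cy) / fy * z, z], axis=0)  # (3, N)
    return c2w[:3, :3] @ pts_cam + c2w[:3, 3:4]                              # (3, N) world points
```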
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch
+ torchvision
+ torchaudio
+ numpy
+ scipy
+ imageio
+ open3d
+ trimesh
+ matplotlib
+ Pillow
+ tqdm
+ huggingface_hub
+ viser==0.1.30
+ gradio==5.2.0
src/__pycache__/cameras.cpython-39.pyc ADDED
Binary file (8.98 kB).
src/__pycache__/config.cpython-38.pyc ADDED
Binary file (3.64 kB).
src/__pycache__/config.cpython-39.pyc ADDED
Binary file (3.67 kB).
src/__pycache__/sparse_voxel_model.cpython-39.pyc ADDED
Binary file (1.85 kB).
src/cameras.py ADDED
@@ -0,0 +1,287 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import numpy as np
+
+ import torch
+ import svraster_cuda
+
+
+ class CameraBase:
+
+     '''
+     Base class of perspective cameras.
+     '''
+
+     def __repr__(self):
+         clsname = self.__class__.__name__
+         fname = f"image_name='{self.image_name}'"
+         res = f"HW=({self.image_height}x{self.image_width})"
+         fov = f"fovx={np.rad2deg(self.fovx):.1f}deg"
+         return f"{clsname}({fname}, {res}, {fov})"
+
+     @property
+     def lookat(self):
+         return self.c2w[:3, 2]
+
+     @property
+     def position(self):
+         return self.c2w[:3, 3]
+
+     @property
+     def down(self):
+         return self.c2w[:3, 1]
+
+     @property
+     def right(self):
+         return self.c2w[:3, 0]
+
+     @property
+     def cx(self):
+         return self.image_width * self.cx_p
+
+     @property
+     def cy(self):
+         return self.image_height * self.cy_p
+
+     @property
+     def pix_size(self):
+         return 2 * self.tanfovx / self.image_width
+
+     @property
+     def tanfovx(self):
+         return np.tan(self.fovx * 0.5)
+
+     @property
+     def tanfovy(self):
+         return np.tan(self.fovy * 0.5)
+
+     def compute_rd(self, wh=None, cxcy=None, device=None):
+         '''Ray directions in world space.'''
+         if wh is None:
+             wh = (self.image_width, self.image_height)
+         if cxcy is None:
+             cxcy = (self.cx * wh[0] / self.image_width, self.cy * wh[1] / self.image_height)
+         rd = svraster_cuda.utils.compute_rd(
+             width=wh[0], height=wh[1],
+             cx=cxcy[0], cy=cxcy[1],
+             tanfovx=self.tanfovx, tanfovy=self.tanfovy,
+             c2w_matrix=self.c2w.cuda())
+         rd = rd.to(self.c2w.device if device is None else device)
+         return rd
+
+     def project(self, pts, return_depth=False):
+         # Return normalized image coordinates in [-1, 1]
+         cam_pts = pts @ self.w2c[:3, :3].T + self.w2c[:3, 3]
+         depth = cam_pts[:, [2]]
+         cam_uv = cam_pts[:, :2] / depth
+         scale_x = 1 / self.tanfovx
+         scale_y = 1 / self.tanfovy
+         shift_x = 2 * self.cx_p - 1
+         shift_y = 2 * self.cy_p - 1
+         cam_uv[:, 0] = cam_uv[:, 0] * scale_x + shift_x
+         cam_uv[:, 1] = cam_uv[:, 1] * scale_y + shift_y
+         if return_depth:
+             return cam_uv, depth
+         return cam_uv
+
+     def depth2pts(self, depth):
+         device = depth.device
+         h, w = depth.shape[-2:]
+         rd = self.compute_rd(wh=(w, h), device=device)
+         return self.position.view(3, 1, 1).to(device) + rd * depth
+
+     def depth2normal(self, depth, ks=3, tol_cos=-1):
+         assert ks % 2 == 1
+         pad = ks // 2
+         ks_1 = ks - 1
+         pts = self.depth2pts(depth)
+         normal_pseudo = torch.zeros_like(pts)
+         dx = pts[:, pad:-pad, ks_1:] - pts[:, pad:-pad, :-ks_1]
+         dy = pts[:, ks_1:, pad:-pad] - pts[:, :-ks_1, pad:-pad]
+         normal_pseudo[:, pad:-pad, pad:-pad] = torch.nn.functional.normalize(torch.cross(dx, dy, dim=0), dim=0)
+
+         if tol_cos > 0:
+             with torch.no_grad():
+                 pts_dir = torch.nn.functional.normalize(pts - self.position.view(3, 1, 1), dim=0)
+                 dot = (normal_pseudo * pts_dir).sum(0)
+                 mask = (dot > tol_cos)
+             normal_pseudo = normal_pseudo * mask
+
+         return normal_pseudo
+
+
+ class Camera(CameraBase):
+     def __init__(
+             self, image_name,
+             w2c, fovx, fovy, cx_p, cy_p,
+             near=0.02,
+             image=None, mask=None, depth=None,
+             sparse_pt=None):
+
+         self.image_name = image_name
+
+         # Camera parameters
+         self.w2c = torch.tensor(w2c, dtype=torch.float32, device="cuda")
+         self.c2w = self.w2c.inverse().contiguous()
+
+         self.fovx = fovx
+         self.fovy = fovy
+
+         # Load frame
+         self.image = image.cpu()
+
+         # Other camera parameters
+         self.image_width = self.image.shape[2]
+         self.image_height = self.image.shape[1]
+         self.cx_p = (0.5 if cx_p is None else cx_p)
+         self.cy_p = (0.5 if cy_p is None else cy_p)
+         self.near = near
+
+         # Load mask and depth if provided
+         self.mask = mask.cpu() if mask is not None else None
+         self.depth = depth.cpu() if depth is not None else None
+
+         # Load sparse depth
+         if sparse_pt is not None:
+             self.sparse_pt = torch.tensor(sparse_pt, dtype=torch.float32, device="cpu")
+         else:
+             self.sparse_pt = None
+
+     def to(self, device):
+         self.image = self.image.to(device)
+         if self.mask is not None:
+             self.mask = self.mask.to(device)
+         if self.depth is not None:
+             self.depth = self.depth.to(device)
+         return self
+
+     def auto_exposure_init(self):
+         self._exposure_A = torch.eye(3, dtype=torch.float32, device="cuda")
+         self._exposure_t = torch.zeros([3, 1, 1], dtype=torch.float32, device="cuda")
+         self.exposure_updated = False
+
+     def auto_exposure_apply(self, image):
+         if self.exposure_updated:
+             image = torch.einsum('ij,jhw->ihw', self._exposure_A, image) + self._exposure_t
+         return image
+
+     def auto_exposure_update(self, ren, ref):
+         self.exposure_updated = True
+         self._exposure_A.requires_grad_()
+         self._exposure_t.requires_grad_()
+         optim = torch.optim.Adam([self._exposure_A, self._exposure_t], lr=1e-3)
+         for _ in range(100):
+             loss = (self.auto_exposure_apply(ren).clamp(0, 1) - ref).abs().mean()
+             loss.backward()
+             optim.step()
+             optim.zero_grad(set_to_none=True)
+         self._exposure_A.requires_grad_(False)
+         self._exposure_t.requires_grad_(False)
+
+     def clone_mini(self):
+         return MiniCam(
+             c2w=self.c2w.clone(),
+             fovx=self.fovx, fovy=self.fovy,
+             width=self.image_width, height=self.image_height,
+             near=self.near,
+             cx_p=self.cx_p, cy_p=self.cy_p)
+
+
+ class MiniCam(CameraBase):
+     def __init__(self,
+                  c2w, fovx, fovy,
+                  width, height,
+                  near=0.02,
+                  cx_p=None, cy_p=None,
+                  image_name="minicam"):
+
+         self.image_name = image_name
+         self.c2w = torch.tensor(c2w).clone().cuda()
+         self.w2c = self.c2w.inverse()
+
+         self.fovx = fovx
+         self.fovy = fovy
+         self.image_width = width
+         self.image_height = height
+         self.cx_p = (0.5 if cx_p is None else cx_p)
+         self.cy_p = (0.5 if cy_p is None else cy_p)
+         self.near = near
+
+         self.depth = None
+         self.mask = None
+
+     def clone_mini(self):
+         return MiniCam(
+             c2w=self.c2w.clone(),
+             fovx=self.fovx, fovy=self.fovy,
+             width=self.image_width, height=self.image_height,
+             near=self.near,
+             cx_p=self.cx_p, cy_p=self.cy_p)
+
+     def move_forward(self, dist):
+         new_position = self.position + dist * self.lookat
+         self.c2w[:3, 3] = new_position
+         self.w2c = self.c2w.inverse()
+         return self
+
+     def move_up(self, dist):
+         return self.move_down(-dist)
+
+     def move_down(self, dist):
+         new_position = self.position + dist * self.down
+         self.c2w[:3, 3] = new_position
+         self.w2c = self.c2w.inverse()
+         return self
+
+     def move_right(self, dist):
+         new_position = self.position + dist * self.right
+         self.c2w[:3, 3] = new_position
+         self.w2c = self.c2w.inverse()
+         return self
+
+     def move_left(self, dist):
+         return self.move_right(-dist)
+
+     def rotate(self, R):
+         self.c2w[:3, :3] = (R @ self.w2c[:3, :3]).T
+         self.w2c = self.c2w.inverse()
+         return self
+
+     def rotate_x(self, rad=None, deg=None):
+         assert rad is None or deg is None, "Can only specify rotation by either rad or deg."
+         if rad is None:
+             rad = np.deg2rad(deg)
+         R = torch.tensor([
+             [1, 0, 0],
+             [0, np.cos(rad), -np.sin(rad)],
+             [0, np.sin(rad), np.cos(rad)],
+         ], dtype=torch.float32, device="cuda")
+         return self.rotate(R)
+
+     def rotate_y(self, rad=None, deg=None):
+         assert rad is None or deg is None, "Can only specify rotation by either rad or deg."
+         if rad is None:
+             rad = np.deg2rad(deg)
+         R = torch.tensor([
+             [np.cos(rad), 0, -np.sin(rad)],
+             [0, 1, 0],
+             [np.sin(rad), 0, np.cos(rad)],
+         ], dtype=torch.float32, device="cuda")
+         return self.rotate(R)
+
+     def rotate_z(self, rad=None, deg=None):
+         assert rad is None or deg is None, "Can only specify rotation by either rad or deg."
+         if rad is None:
+             rad = np.deg2rad(deg)
+         R = torch.tensor([
+             [np.cos(rad), -np.sin(rad), 0],
+             [np.sin(rad), np.cos(rad), 0],
+             [0, 0, 1],
+         ], dtype=torch.float32, device="cuda")
+         return self.rotate(R)
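The `MiniCam` motion helpers above all return `self`, so view manipulations chain naturally, and `project` maps world points to normalized image coordinates in [-1, 1]. A small usage sketch, assuming a CUDA device and an existing camera `cam` (e.g. one returned by `DataPack.get_train_cameras()`); the point tensor is illustrative:

```python
# Usage sketch (assumption: CUDA is available and `cam` is an existing Camera/MiniCam).
import torch

fly = (
    cam.clone_mini()        # detached copy that can be moved freely
       .move_forward(0.5)   # step along the look-at axis
       .rotate_y(deg=15)    # yaw by 15 degrees
       .move_up(0.1)        # raise the viewpoint slightly
)

pts = torch.rand(100, 3, device="cuda")          # illustrative world-space points
uv, depth = fly.project(pts, return_depth=True)  # normalized coords in [-1, 1] plus depth
```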
src/config.py ADDED
@@ -0,0 +1,230 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import argparse
+ from yacs.config import CfgNode
+
+
+ cfg = CfgNode()
+
+ cfg.model = CfgNode(dict(
+     n_samp_per_vox = 1,        # Number of sampled points per visited voxel
+     sh_degree = 3,             # Use 3 * (k+1)^2 params per voxel for view-dependent colors
+     ss = 1.5,                  # Super-sampling rate for anti-aliasing
+     white_background = False,  # Assume white background
+     black_background = False,  # Assume black background
+ ))
+
+ cfg.data = CfgNode(dict(
+     source_path = "",
+     image_dir_name = "images",
+     mask_dir_name = "masks",
+     res_downscale = 0.,
+     res_width = 0,
+     skip_blend_alpha = False,
+     data_device = "cpu",
+     eval = False,
+     test_every = 8,
+ ))
+
+ cfg.bounding = CfgNode(dict(
+     # Define the main (inside) region bounding box.
+     # The default uses the bounding suggested by the dataset, if given.
+     # Otherwise, it automatically chooses between the forward and camera_median modes.
+     # See src/utils/bounding_utils.py for details.
+
+     # default | camera_median | camera_max | forward | pcd
+     bound_mode = "default",
+     bound_scale = 1.0,         # Scaling factor of the bound
+     forward_dist_scale = 1.0,  # For forward mode
+     pcd_density_rate = 0.1,    # For pcd mode
+
+     # Number of Octree levels outside the main foreground region
+     outside_level = 5,
+ ))
+
+ cfg.optimizer = CfgNode(dict(
+     geo_lr = 0.025,
+     sh0_lr = 0.010,
+     shs_lr = 0.00025,
+
+     optim_beta1 = 0.1,
+     optim_beta2 = 0.99,
+     optim_eps = 1e-15,
+
+     lr_decay_ckpt = [19000],
+     lr_decay_mult = 0.1,
+ ))
+
+ cfg.regularizer = CfgNode(dict(
+     # Main photometric loss
+     lambda_photo = 1.0,
+     use_l1 = False,
+     use_huber = False,
+     huber_thres = 0.03,
+
+     # SSIM loss
+     lambda_ssim = 0.02,
+
+     # Sparse depth loss
+     lambda_sparse_depth = 0.0,
+     sparse_depth_until = 10_000,
+
+     # Mask loss
+     lambda_mask = 0.0,
+
+     # Depth Anything v2 loss
+     lambda_depthanythingv2 = 0.0,
+     depthanythingv2_from = 3000,
+     depthanythingv2_end = 20000,
+     depthanythingv2_end_mult = 0.1,
+
+     # MASt3R metric depth loss
+     lambda_mast3r_metric_depth = 0.0,
+     mast3r_repo_path = '',
+     mast3r_metric_depth_from = 0,
+     mast3r_metric_depth_end = 20000,
+     mast3r_metric_depth_end_mult = 0.01,
+
+     # Final transmittance should concentrate to either 0 or 1
+     lambda_T_concen = 0.0,
+
+     # Final transmittance should be 0
+     lambda_T_inside = 0.0,
+
+     # Per-point rgb loss
+     lambda_R_concen = 0.01,
+
+     # Geometric regularization
+     lambda_ascending = 0.0,
+     ascending_from = 0,
+
+     # Distortion loss (encourages distribution concentration along the ray)
+     lambda_dist = 0.1,
+     dist_from = 10000,
+
+     # Consistency loss between rendered normal and normal derived from expected depth
+     lambda_normal_dmean = 0.0,
+     n_dmean_from = 10_000,
+     n_dmean_end = 20_000,
+     n_dmean_ks = 3,
+     n_dmean_tol_deg = 90.0,
+
+     # Consistency loss between rendered normal and normal derived from median depth
+     lambda_normal_dmed = 0.0,
+     n_dmed_from = 3000,
+     n_dmed_end = 20_000,
+
+     # Total variation loss of the density grid
+     lambda_tv_density = 1e-10,
+     tv_from = 0,
+     tv_until = 10000,
+
+     # Data augmentation
+     ss_aug_max = 1.5,
+     rand_bg = False,
+ ))
+
+ cfg.init = CfgNode(dict(
+     # Voxel property initialization
+     geo_init = -10.0,
+     sh0_init = 0.5,
+     shs_init = 0.0,
+
+     sh_degree_init = 3,
+
+     # Init the main inside region with dense voxels
+     init_n_level = 6,  # (2^6)^3 voxels
+
+     # Ratio of the number of voxels used for the outside (background) region
+     init_out_ratio = 2.0,
+ ))
+
+ cfg.procedure = CfgNode(dict(
+     # Schedule
+     n_iter = 20_000,
+     sche_mult = 1.0,
+     seed = 3721,
+
+     # Reset sh
+     reset_sh_ckpt = [-1],
+
+     # Adaptive general setup
+     adapt_from = 1000,
+     adapt_every = 1000,
+
+     # Adaptive voxel pruning
+     prune_until = 18000,
+     prune_thres_init = 0.0001,
+     prune_thres_final = 0.05,
+
+     # Adaptive voxel subdivision
+     subdivide_until = 15000,
+     subdivide_all_until = 0,
+     subdivide_samp_thres = 1.0,  # A voxel's max sampling rate should be larger than this.
+     subdivide_prop = 0.05,
+     subdivide_max_num = 10_000_000,
+ ))
+
+ cfg.auto_exposure = CfgNode(dict(
+     enable = False,
+     auto_exposure_upd_ckpt = [5000, 10000, 15000]
+ ))
+
+ for i_cfg in cfg.values():
+     i_cfg.set_new_allowed(True)
+
+
+ def everytype2bool(v):
+     if v.isnumeric():
+         return bool(int(v))
+     v = v.lower()
+     if v in ['n', 'no', 'none', 'false']:
+         return False
+     return True
+
+
+ def update_argparser(parser):
+     for name in cfg.keys():
+         group = parser.add_argument_group(name)
+         for key, value in getattr(cfg, name).items():
+             t = type(value)
+
+             if t == bool:
+                 group.add_argument(f"--{key}", action='store_false' if value else 'store_true')
+             elif t == list:
+                 group.add_argument(f"--{key}", default=value, type=type(value[0]), nargs="*")
+             elif t == tuple:
+                 group.add_argument(f"--{key}", default=value, type=type(value[0]), nargs=len(value))
+             else:
+                 group.add_argument(f"--{key}", default=value, type=t)
+
+
+ def update_config(cfg_files, cmd_lst=[]):
+     # Update from config files
+     if isinstance(cfg_files, str):
+         cfg_files = [cfg_files]
+     for cfg_path in cfg_files:
+         cfg.merge_from_file(cfg_path)
+
+     if len(cmd_lst) == 0:
+         return
+
+     # Parse the arguments from the command line
+     internal_parser = argparse.ArgumentParser()
+     update_argparser(internal_parser)
+     internal_args = internal_parser.parse_args(cmd_lst)
+
+     # Update from command-line args
+     for name in cfg.keys():
+         cfg_subgroup = getattr(cfg, name)
+         for key in cfg_subgroup.keys():
+             arg_val = getattr(internal_args, key)
+             # Check whether the default value was overridden
+             if internal_parser.get_default(key) != arg_val:
+                 cfg_subgroup[key] = arg_val
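`update_config` first merges YAML files into the global `cfg` and then applies only those command-line values that differ from the argparse defaults. A short usage sketch (the YAML path and override values are made up for illustration):

```python
# Usage sketch for the config system (the path and values are illustrative).
from src.config import cfg, update_config

# Merge a trained model's saved config, then override two fields as if passed on the CLI.
update_config("outputs/scene/config.yaml",
              cmd_lst=["--res_downscale", "2.0", "--sh_degree", "2"])
print(cfg.data.res_downscale, cfg.model.sh_degree)  # 2.0 2
```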
src/config_old.py ADDED
@@ -0,0 +1,230 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import argparse
+ from yacs.config import CfgNode
+
+
+ cfg = CfgNode()
+
+ cfg.model = CfgNode(dict(
+     n_samp_per_vox = 1,        # Number of sampled points per visited voxel
+     sh_degree = 3,             # Use 3 * (k+1)^2 params per voxel for view-dependent colors
+     ss = 1.5,                  # Super-sampling rate for anti-aliasing
+     white_background = False,  # Assume white background
+     black_background = False,  # Assume black background
+ ))
+
+ cfg.data = CfgNode(dict(
+     source_path = "",
+     image_dir_name = "images",
+     res_downscale = 0.,
+     res_width = 0,
+     skip_blend_alpha = False,
+     data_device = "cpu",
+     eval = False,
+     test_every = 8,
+     alpha_is_white = True,
+ ))
+
+ cfg.bounding = CfgNode(dict(
+     # Define the main (inside) region bounding box.
+     # The default uses the bounding suggested by the dataset, if given.
+     # Otherwise, it automatically chooses between the forward and camera_median modes.
+     # See src/utils/bounding_utils.py for details.
+
+     # default | camera_median | camera_max | forward | pcd
+     bound_mode = "default",
+     bound_scale = 1.0,         # Scaling factor of the bound
+     forward_dist_scale = 1.0,  # For forward mode
+     pcd_density_rate = 0.1,    # For pcd mode
+
+     # Number of Octree levels outside the main foreground region
+     outside_level = 5,
+ ))
+
+ cfg.optimizer = CfgNode(dict(
+     geo_lr = 0.025,
+     sh0_lr = 0.010,
+     shs_lr = 0.00025,
+
+     optim_beta1 = 0.1,
+     optim_beta2 = 0.99,
+     optim_eps = 1e-15,
+
+     lr_decay_ckpt = [19000],
+     lr_decay_mult = 0.1,
+ ))
+
+ cfg.regularizer = CfgNode(dict(
+     # Main photometric loss
+     lambda_photo = 1.0,
+     use_l1 = False,
+     use_huber = False,
+     huber_thres = 0.03,
+
+     # SSIM loss
+     lambda_ssim = 0.02,
+
+     # Sparse depth loss
+     lambda_sparse_depth = 0.0,
+     sparse_depth_until = 10_000,
+
+     # Mask loss
+     lambda_mask = 0.0,
+
+     # Depth Anything v2 loss
+     lambda_depthanythingv2 = 0.0,
+     depthanythingv2_from = 3000,
+     depthanythingv2_end = 20000,
+     depthanythingv2_end_mult = 0.1,
+
+     # MASt3R metric depth loss
+     lambda_mast3r_metric_depth = 0.0,
+     mast3r_repo_path = '',
+     mast3r_metric_depth_from = 0,
+     mast3r_metric_depth_end = 20000,
+     mast3r_metric_depth_end_mult = 0.01,
+
+     # Final transmittance should concentrate to either 0 or 1
+     lambda_T_concen = 0.0,
+
+     # Final transmittance should be 0
+     lambda_T_inside = 0.0,
+
+     # Per-point rgb loss
+     lambda_R_concen = 0.01,
+
+     # Geometric regularization
+     lambda_ascending = 0.0,
+     ascending_from = 0,
+
+     # Distortion loss (encourages distribution concentration along the ray)
+     lambda_dist = 0.1,
+     dist_from = 10000,
+
+     # Consistency loss between rendered normal and normal derived from expected depth
+     lambda_normal_dmean = 0.0,
+     n_dmean_from = 10_000,
+     n_dmean_end = 20_000,
+     n_dmean_ks = 3,
+     n_dmean_tol_deg = 90.0,
+
+     # Consistency loss between rendered normal and normal derived from median depth
+     lambda_normal_dmed = 0.0,
+     n_dmed_from = 3000,
+     n_dmed_end = 20_000,
+
+     # Total variation loss of the density grid
+     lambda_tv_density = 1e-10,
+     tv_from = 0,
+     tv_until = 10000,
+
+     # Data augmentation
+     ss_aug_max = 1.5,
+     rand_bg = False,
+ ))
+
+ cfg.init = CfgNode(dict(
+     # Voxel property initialization
+     geo_init = -10.0,
+     sh0_init = 0.5,
+     shs_init = 0.0,
+
+     sh_degree_init = 3,
+
+     # Init the main inside region with dense voxels
+     init_n_level = 6,  # (2^6)^3 voxels
+
+     # Ratio of the number of voxels used for the outside (background) region
+     init_out_ratio = 2.0,
+ ))
+
+ cfg.procedure = CfgNode(dict(
+     # Schedule
+     n_iter = 20_000,
+     sche_mult = 1.0,
+     seed = 3721,
+
+     # Reset sh
+     reset_sh_ckpt = [-1],
+
+     # Adaptive general setup
+     adapt_from = 1000,
+     adapt_every = 1000,
+
+     # Adaptive voxel pruning
+     prune_until = 18000,
+     prune_thres_init = 0.0001,
+     prune_thres_final = 0.05,
+
+     # Adaptive voxel subdivision
+     subdivide_until = 15000,
+     subdivide_all_until = 0,
+     subdivide_samp_thres = 1.0,  # A voxel's max sampling rate should be larger than this.
+     subdivide_prop = 0.05,
+     subdivide_max_num = 10_000_000,
+ ))
+
+ cfg.auto_exposure = CfgNode(dict(
+     enable = False,
+     auto_exposure_upd_ckpt = [5000, 10000, 15000]
+ ))
+
+ for i_cfg in cfg.values():
+     i_cfg.set_new_allowed(True)
+
+
+ def everytype2bool(v):
+     if v.isnumeric():
+         return bool(int(v))
+     v = v.lower()
+     if v in ['n', 'no', 'none', 'false']:
+         return False
+     return True
+
+
+ def update_argparser(parser):
+     for name in cfg.keys():
+         group = parser.add_argument_group(name)
+         for key, value in getattr(cfg, name).items():
+             t = type(value)
+
+             if t == bool:
+                 group.add_argument(f"--{key}", action='store_false' if value else 'store_true')
+             elif t == list:
+                 group.add_argument(f"--{key}", default=value, type=type(value[0]), nargs="*")
+             elif t == tuple:
+                 group.add_argument(f"--{key}", default=value, type=type(value[0]), nargs=len(value))
+             else:
+                 group.add_argument(f"--{key}", default=value, type=t)
+
+
+ def update_config(cfg_files, cmd_lst=[]):
+     # Update from config files
+     if isinstance(cfg_files, str):
+         cfg_files = [cfg_files]
+     for cfg_path in cfg_files:
+         cfg.merge_from_file(cfg_path)
+
+     if len(cmd_lst) == 0:
+         return
+
+     # Parse the arguments from the command line
+     internal_parser = argparse.ArgumentParser()
+     update_argparser(internal_parser)
+     internal_args = internal_parser.parse_args(cmd_lst)
+
+     # Update from command-line args
+     for name in cfg.keys():
+         cfg_subgroup = getattr(cfg, name)
+         for key in cfg_subgroup.keys():
+             arg_val = getattr(internal_args, key)
+             # Check whether the default value was overridden
+             if internal_parser.get_default(key) != arg_val:
+                 cfg_subgroup[key] = arg_val
src/dataloader/__pycache__/data_pack.cpython-39.pyc ADDED
Binary file (5.62 kB).
src/dataloader/__pycache__/reader_colmap_dataset.cpython-39.pyc ADDED
Binary file (4.04 kB).
src/dataloader/__pycache__/reader_nerf_dataset.cpython-39.pyc ADDED
Binary file (3.97 kB).
src/dataloader/data_pack.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import time
11
+ import random
12
+ import numpy as np
13
+
14
+ import torch
15
+
16
+ from src.dataloader.reader_colmap_dataset import read_colmap_dataset
17
+ from src.dataloader.reader_nerf_dataset import read_nerf_dataset
18
+ from src.utils.camera_utils import interpolate_poses
19
+
20
+ from src.cameras import Camera, MiniCam
21
+
22
+
23
+ class DataPack:
24
+
25
+ def __init__(self,
26
+ source_path,
27
+ image_dir_name="images",
28
+ mask_dir_name="masks",
29
+ res_downscale=0.,
30
+ res_width=0,
31
+ skip_blend_alpha=False,
32
+ alpha_is_white=False,
33
+ data_device="cpu",
34
+ use_test=False,
35
+ test_every=8,
36
+ camera_params_only=False):
37
+
38
+ camera_creator = CameraCreator(
39
+ res_downscale=res_downscale,
40
+ res_width=res_width,
41
+ skip_blend_alpha=skip_blend_alpha,
42
+ alpha_is_white=alpha_is_white,
43
+ data_device=data_device,
44
+ camera_params_only=camera_params_only,
45
+ )
46
+
47
+ sparse_path = os.path.join(source_path, "sparse")
48
+ colmap_path = os.path.join(source_path, "colmap", "sparse")
49
+ meta_path1 = os.path.join(source_path, "transforms_train.json")
50
+ meta_path2 = os.path.join(source_path, "transforms.json")
51
+
52
+ # Read images concurrently
53
+ s_time = time.perf_counter()
54
+
55
+ if os.path.exists(sparse_path) or os.path.exists(colmap_path):
56
+ print("Read dataset in COLMAP format.")
57
+ dataset = read_colmap_dataset(
58
+ source_path=source_path,
59
+ image_dir_name=image_dir_name,
60
+ mask_dir_name=mask_dir_name,
61
+ use_test=use_test,
62
+ test_every=test_every,
63
+ camera_creator=camera_creator)
64
+ elif os.path.exists(meta_path1) or os.path.exists(meta_path2):
65
+ print("Read dataset in NeRF format.")
66
+ dataset = read_nerf_dataset(
67
+ source_path=source_path,
68
+ use_test=use_test,
69
+ test_every=test_every,
70
+ camera_creator=camera_creator)
71
+ else:
72
+ raise Exception("Unknown scene type!")
73
+
74
+ e_time = time.perf_counter()
75
+ print(f"Read dataset in {e_time - s_time:.3f} seconds.")
76
+
77
+ self._cameras = {
78
+ 'train': dataset['train_cam_lst'],
79
+ 'test': dataset['test_cam_lst'],
80
+ }
81
+
82
+ ##############################
83
+ # Read additional dataset info
84
+ ##############################
85
+ # If the dataset suggested a scene bound
86
+ self.suggested_bounding = dataset.get('suggested_bounding', None)
87
+
88
+ # If the dataset provide a transformation to other coordinate
89
+ self.to_world_matrix = None
90
+ to_world_path = os.path.join(source_path, 'to_world_matrix.txt')
91
+ if os.path.isfile(to_world_path):
92
+ self.to_world_matrix = np.loadtxt(to_world_path)
93
+
94
+ # If the dataset has a point cloud
95
+ self.point_cloud = dataset.get('point_cloud', None)
96
+
97
+ def get_train_cameras(self):
98
+ return self._cameras['train']
99
+
100
+ def get_test_cameras(self):
101
+ return self._cameras['test']
102
+
103
+ def interpolate_cameras(self, n_frames, starting_id=0, ids=[], step_forward=0):
104
+ cams = self.get_train_cameras()
105
+ if len(ids):
106
+ key_poses = [cams[i].c2w.cpu().numpy() for i in ids]
107
+ else:
108
+ assert starting_id >= 0
109
+ assert starting_id < len(cams)
110
+ cam_pos = torch.stack([cam.position for cam in cams])
111
+ ids = [starting_id]
112
+ for _ in range(3):
113
+ farthest_id = torch.cdist(cam_pos[ids], cam_pos).amin(0).argmax().item()
114
+ ids.append(farthest_id)
115
+ ids[1], ids[2] = ids[2], ids[1]
116
+ key_poses = [cams[i].c2w.cpu().numpy() for i in ids]
117
+
118
+ if step_forward != 0:
119
+ for i in range(len(key_poses)):
120
+ lookat = key_poses[i][:3, 2]
121
+ key_poses[i][:3, 3] += step_forward * lookat
122
+
123
+ interp_poses = interpolate_poses(key_poses, n_frame=n_frames, periodic=True)
124
+
125
+ base_cam = cams[ids[0]]
126
+ interp_cams = [
127
+ MiniCam(
128
+ c2w=pose,
129
+ fovx=base_cam.fovx, fovy=base_cam.fovy,
130
+ width=base_cam.image_width, height=base_cam.image_height)
131
+ for pose in interp_poses]
132
+ return interp_cams
133
+
134
+
135
+ # Create a random sequence of image indices
136
+ def compute_iter_idx(num_data, num_iter):
137
+ tr_iter_idx = []
138
+ while len(tr_iter_idx) < num_iter:
139
+ lst = list(range(num_data))
140
+ random.shuffle(lst)
141
+ tr_iter_idx.extend(lst)
142
+ return tr_iter_idx[:num_iter]
143
+
144
+
145
+ # Function that create Camera instances while parsing dataset
146
+ class CameraCreator:
147
+
148
+ warned = False
149
+
150
+ def __init__(self,
151
+ res_downscale=0.,
152
+ res_width=0,
153
+ skip_blend_alpha=False,
154
+ alpha_is_white=False,
155
+ data_device="cpu",
156
+ camera_params_only=False):
157
+
158
+ self.res_downscale = res_downscale
159
+ self.res_width = res_width
160
+ self.skip_blend_alpha = skip_blend_alpha
161
+ self.alpha_is_white = alpha_is_white
162
+ self.data_device = data_device
163
+ self.camera_params_only = camera_params_only
164
+
165
+ def __call__(self,
166
+ image,
167
+ w2c,
168
+ fovx,
169
+ fovy,
170
+ cx_p=0.5,
171
+ cy_p=0.5,
172
+ sparse_pt=None,
173
+ image_name="",
174
+ mask=None):
175
+
176
+ # Determine target resolution
177
+ if self.res_downscale > 0:
178
+ downscale = self.res_downscale
179
+ elif self.res_width > 0:
180
+ downscale = image.size[0] / self.res_width
181
+ else:
182
+ downscale = 1
183
+
184
+ total_pix = image.size[0] * image.size[1]
185
+ if total_pix > 1200 ** 2 and not self.warned:
186
+ self.warned = True
187
+ suggest_ds = (total_pix ** 0.5) / 1200
188
+ print(f"###################################################################")
189
+ print(f"Image too large. Suggest to use `--res_downscale {suggest_ds:.1f}`.")
190
+ print(f"###################################################################")
191
+
192
+ # Load camera parameters only
193
+ if self.camera_params_only:
194
+ return MiniCam(
195
+ c2w=np.linalg.inv(w2c),
196
+ fovx=fovx, fovy=fovy,
197
+ cx_p=cx_p, cy_p=cy_p,
198
+ width=round(image.size[0] / downscale),
199
+ height=round(image.size[1] / downscale),
200
+ image_name=image_name)
201
+
202
+ # Resize image if needed
203
+ if downscale != 1:
204
+ size = (round(image.size[0] / downscale), round(image.size[1] / downscale))
205
+ image = image.resize(size)
206
+
207
+ # Convert image to tensor
208
+ tensor = torch.tensor(np.array(image), dtype=torch.float32).moveaxis(-1, 0) / 255.0
209
+ if tensor.shape[0] == 4:
210
+ # Blend alpha channel
211
+ tensor, mask = tensor.split([3, 1], dim=0)
212
+ if not self.skip_blend_alpha:
213
+ tensor = tensor * mask + int(self.alpha_is_white) * (1 - mask)
214
+
215
+ # Convert the mask to a tensor if one is provided
216
+ if mask is not None:
217
+ size = tensor.shape[-2:][::-1]
218
+ if mask.size != size:
219
+ mask = mask.resize(size)
220
+ mask = torch.tensor(np.array(mask), dtype=torch.float32) / 255.0
221
+ if len(mask.shape) == 3:
222
+ mask = mask.mean(-1)
223
+ mask = mask[None]
224
+
225
+ return Camera(
226
+ w2c=w2c,
227
+ fovx=fovx, fovy=fovy,
228
+ cx_p=cx_p, cy_p=cy_p,
229
+ image=tensor,
230
+ mask=mask,
231
+ sparse_pt=sparse_pt,
232
+ image_name=image_name)
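Note: a small usage sketch of CameraCreator and compute_iter_idx, assuming a PIL image and a 4x4 world-to-camera matrix are already in hand; the resolution, FoV values, and file name are placeholders.

    import numpy as np
    from PIL import Image

    creator = CameraCreator(res_downscale=2.0, data_device="cpu")
    cam = creator(
        image=Image.new("RGB", (800, 600)),      # stand-in for a loaded photo
        w2c=np.eye(4, dtype=np.float32),
        fovx=1.2, fovy=0.9,
        image_name="frame_000")

    # Shuffled schedule of 1000 training iterations over 100 images.
    tr_iter_idx = compute_iter_idx(num_data=100, num_iter=1000)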
src/dataloader/reader_colmap_dataset.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import json
11
+ import natsort
12
+ import pycolmap
13
+ import numpy as np
14
+ from PIL import Image
15
+ from pathlib import Path
16
+ import concurrent.futures
17
+
18
+ from src.utils.colmap_utils import parse_colmap_pts
19
+ from src.utils.camera_utils import focal2fov
20
+
21
+
22
+ def read_colmap_dataset(source_path, image_dir_name, mask_dir_name, use_test, test_every, camera_creator):
23
+ """
24
+ Read a COLMAP dataset and return cameras, intrinsics, extrinsics, and optional masks.
25
+
26
+ Fixes:
27
+ - Safe image/mask opening using `with Image.open(...)` (no file leaks).
28
+ - Compatible with both old/new pycolmap APIs.
29
+ - Returns PIL.Image objects (for backward compatibility with DataPack).
30
+ """
31
+
32
+ source_path = Path(source_path)
33
+
34
+ # ---------------- Parse COLMAP reconstruction ----------------
35
+ sparse_path = source_path / "sparse" / "0"
36
+ if not sparse_path.exists():
37
+ sparse_path = source_path / "colmap" / "sparse" / "0"
38
+ if not sparse_path.exists():
39
+ raise Exception("Cannot find COLMAP reconstruction (expected sparse/0 or colmap/sparse/0).")
40
+
41
+ sfm = pycolmap.Reconstruction(sparse_path)
42
+ point_cloud = parse_colmap_pts(sfm)
43
+ correspondent = point_cloud.corr
44
+
45
+ # ---------------- Sort key by filename ----------------
46
+ keys = natsort.natsorted(sfm.images.keys(), key=lambda k: sfm.images[k].name)
47
+
48
+ # ---------------- Load all frames ----------------
49
+ todo_lst = []
50
+ for key in keys:
51
+ frame = sfm.images[key]
52
+
53
+ # ---- Load RGB image safely ----
54
+ image_path = source_path / image_dir_name / frame.name
55
+ if not image_path.exists():
56
+ image_path = image_path.with_suffix(".png")
57
+ if not image_path.exists():
58
+ image_path = image_path.with_suffix(".jpg")
59
+ if not image_path.exists():
60
+ image_path = image_path.with_suffix(".JPG")
61
+ if not image_path.exists():
62
+ raise Exception(f"File not found: {str(image_path)}")
63
+
64
+ # safely open and immediately copy to new PIL object (closed after copy)
65
+ with Image.open(image_path) as img:
66
+ image = img.copy() # copy keeps data in memory, closes file handle
67
+
68
+ # ---- Load intrinsics ----
69
+ if frame.camera.model.name == "SIMPLE_PINHOLE":
70
+ focal_x, cx, cy = frame.camera.params
71
+ fovx = focal2fov(focal_x, frame.camera.width)
72
+ fovy = focal2fov(focal_x, frame.camera.height)
73
+ cx_p = cx / frame.camera.width
74
+ cy_p = cy / frame.camera.height
75
+ elif frame.camera.model.name == "PINHOLE":
76
+ focal_x, focal_y, cx, cy = frame.camera.params
77
+ fovx = focal2fov(focal_x, frame.camera.width)
78
+ fovy = focal2fov(focal_y, frame.camera.height)
79
+ cx_p = cx / frame.camera.width
80
+ cy_p = cy / frame.camera.height
81
+ else:
82
+ raise ValueError(
83
+ f"Unsupported COLMAP camera model: {frame.camera.model.name}. "
84
+ "Only undistorted SIMPLE_PINHOLE and PINHOLE are supported."
85
+ )
86
+
87
+ # ---- Load extrinsics (support both pycolmap APIs) ----
88
+ w2c = np.eye(4, dtype=np.float32)
89
+ cam_from_world = getattr(frame, "cam_from_world", None)
90
+ if cam_from_world is not None:
91
+ if callable(cam_from_world):
92
+ # Old pycolmap API
93
+ w2c[:3] = cam_from_world().matrix()
94
+ else:
95
+ # New pycolmap API (Rigid3d object)
96
+ w2c[:3] = cam_from_world.matrix()
97
+ else:
98
+ raise RuntimeError("Cannot find cam_from_world attribute in COLMAP frame.")
99
+
100
+ # ---- Sparse point correspondence ----
101
+ sparse_pt = point_cloud.points[correspondent[frame.name]]
102
+
103
+ # ---- Optional mask ----
104
+ mask = None
105
+ if mask_dir_name is not None:
106
+ mask_path = (source_path / mask_dir_name / frame.name).with_suffix(".png")
107
+ if mask_path.exists():
108
+ with Image.open(mask_path) as m:
109
+ mask = m.copy() # keep PIL.Image for DataPack
110
+
111
+ # ---- Store frame data ----
112
+ todo_lst.append(dict(
113
+ image=image,
114
+ w2c=w2c,
115
+ fovx=fovx,
116
+ fovy=fovy,
117
+ cx_p=cx_p,
118
+ cy_p=cy_p,
119
+ sparse_pt=sparse_pt,
120
+ image_name=image_path.name,
121
+ mask=mask,
122
+ ))
123
+
124
+ # ---------------- Create cameras concurrently ----------------
125
+ import torch
126
+ torch.inverse(torch.eye(3, device="cuda")) # fix PyTorch lazy init bug
127
+
128
+ with concurrent.futures.ThreadPoolExecutor() as executor:
129
+ futures = [executor.submit(camera_creator, **todo) for todo in todo_lst]
130
+ cam_lst = [f.result() for f in futures]
131
+
132
+ # ---------------- Split train/test ----------------
133
+ if use_test:
134
+ train_cam_lst = [cam for i, cam in enumerate(cam_lst) if i % test_every != 0]
135
+ test_cam_lst = [cam for i, cam in enumerate(cam_lst) if i % test_every == 0]
136
+ else:
137
+ train_cam_lst = cam_lst
138
+ test_cam_lst = []
139
+
140
+ # ---------------- Optional bounding box ----------------
141
+ nerf_normalization_path = source_path / "nerf_normalization.json"
142
+ if nerf_normalization_path.is_file():
143
+ with open(nerf_normalization_path) as f:
144
+ nerf_norm = json.load(f)
145
+ suggested_center = np.array(nerf_norm["center"], dtype=np.float32)
146
+ suggested_radius = np.array(nerf_norm["radius"], dtype=np.float32)
147
+ suggested_bounding = np.stack([
148
+ suggested_center - suggested_radius,
149
+ suggested_center + suggested_radius,
150
+ ])
151
+ else:
152
+ suggested_bounding = None
153
+
154
+ # ---------------- Return dataset ----------------
155
+ dataset = {
156
+ "train_cam_lst": train_cam_lst,
157
+ "test_cam_lst": test_cam_lst,
158
+ "suggested_bounding": suggested_bounding,
159
+ "point_cloud": point_cloud,
160
+ }
161
+ return dataset
162
+
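Note: a sketch of how this reader is typically invoked, assuming a COLMAP-processed scene under ./data/scene with an images/ folder and a CUDA device available (the reader warms up torch.inverse on the GPU); the paths and CameraCreator options are placeholders.

    creator = CameraCreator(res_downscale=2.0)
    dataset = read_colmap_dataset(
        source_path="./data/scene",
        image_dir_name="images",
        mask_dir_name=None,          # e.g. "masks" if per-image PNG masks exist
        use_test=True,
        test_every=8,                # every 8th image becomes a test view
        camera_creator=creator)
    train_cams = dataset["train_cam_lst"]
    test_cams = dataset["test_cam_lst"]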
src/dataloader/reader_colmap_dataset_or.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import json
11
+ import natsort
12
+ import pycolmap
13
+ import numpy as np
14
+ from PIL import Image
15
+ from pathlib import Path
16
+ import concurrent.futures
17
+
18
+ from src.utils.colmap_utils import parse_colmap_pts
19
+ from src.utils.camera_utils import focal2fov
20
+
21
+
22
+ def read_colmap_dataset(source_path, image_dir_name, mask_dir_name, use_test, test_every, camera_creator):
23
+
24
+ source_path = Path(source_path)
25
+
26
+ # Parse colmap meta data
27
+ sparse_path = source_path / "sparse" / "0"
28
+ if not sparse_path.exists():
29
+ sparse_path = source_path / "colmap" / "sparse" / "0"
30
+ if not sparse_path.exists():
31
+ raise Exception("Can not find COLMAP reconstruction.")
32
+
33
+ sfm = pycolmap.Reconstruction(sparse_path)
34
+ point_cloud = parse_colmap_pts(sfm)
35
+ correspondent = point_cloud.corr
36
+
37
+ # Sort key by filename
38
+ keys = natsort.natsorted(
39
+ sfm.images.keys(),
40
+ key = lambda k : sfm.images[k].name)
41
+
42
+ # Load all images and cameras
43
+ todo_lst = []
44
+ for key in keys:
45
+
46
+ frame = sfm.images[key]
47
+
48
+ # Load image
49
+ image_path = source_path / image_dir_name / frame.name
50
+ if not image_path.exists():
51
+ image_path = image_path.with_suffix('.png')
52
+ if not image_path.exists():
53
+ image_path = image_path.with_suffix('.jpg')
54
+ if not image_path.exists():
55
+ image_path = image_path.with_suffix('.JPG')
56
+ if not image_path.exists():
57
+ raise Exception(f"File not found: {str(image_path)}")
58
+ image = Image.open(image_path)
59
+
60
+ # Load camera intrinsic
61
+ if frame.camera.model.name == "SIMPLE_PINHOLE":
62
+ focal_x, cx, cy = frame.camera.params
63
+ fovx = focal2fov(focal_x, frame.camera.width)
64
+ fovy = focal2fov(focal_x, frame.camera.height)
65
+ cx_p = cx / frame.camera.width
66
+ cy_p = cy / frame.camera.height
67
+ elif frame.camera.model.name == "PINHOLE":
68
+ focal_x, focal_y, cx, cy = frame.camera.params
69
+ fovx = focal2fov(focal_x, frame.camera.width)
70
+ fovy = focal2fov(focal_y, frame.camera.height)
71
+ cx_p = cx / frame.camera.width
72
+ cy_p = cy / frame.camera.height
73
+ else:
74
+ assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
75
+
76
+ # Load camera extrinsic
77
+ w2c = np.eye(4, dtype=np.float32)
78
+ try:
79
+ w2c[:3] = frame.cam_from_world().matrix()
80
+ except:
81
+ # Older version of pycolmap
82
+ w2c[:3] = frame.cam_from_world.matrix()
83
+
84
+ # Load sparse point
85
+ sparse_pt = point_cloud.points[correspondent[frame.name]]
86
+
87
+ # Load mask if there is
88
+ mask_path = (source_path / mask_dir_name / frame.name).with_suffix('.png')
89
+ if mask_path.exists():
90
+ mask = Image.open(mask_path)
91
+ else:
92
+ mask = None
93
+
94
+ todo_lst.append(dict(
95
+ image=image,
96
+ w2c=w2c,
97
+ fovx=fovx,
98
+ fovy=fovy,
99
+ cx_p=cx_p,
100
+ cy_p=cy_p,
101
+ sparse_pt=sparse_pt,
102
+ image_name=image_path.name,
103
+ mask=mask,
104
+ ))
105
+
106
+ # Load all cameras concurrently
107
+ import torch
108
+ torch.inverse(torch.eye(3, device="cuda")) # Fix module lazy loading bug:
109
+ # https://github.com/pytorch/pytorch/issues/90613
110
+
111
+ with concurrent.futures.ThreadPoolExecutor() as executor:
112
+ futures = [executor.submit(camera_creator, **todo) for todo in todo_lst]
113
+ cam_lst = [f.result() for f in futures]
114
+
115
+ # Split train/test
116
+ if use_test:
117
+ train_cam_lst = [
118
+ cam for i, cam in enumerate(cam_lst)
119
+ if i % test_every != 0]
120
+ test_cam_lst = [
121
+ cam for i, cam in enumerate(cam_lst)
122
+ if i % test_every == 0]
123
+ else:
124
+ train_cam_lst = cam_lst
125
+ test_cam_lst = []
126
+
127
+ # Parse main scene bound if there is
128
+ nerf_normalization_path = os.path.join(source_path, "nerf_normalization.json")
129
+ if os.path.isfile(nerf_normalization_path):
130
+ with open(nerf_normalization_path) as f:
131
+ nerf_normalization = json.load(f)
132
+ suggested_center = np.array(nerf_normalization["center"], dtype=np.float32)
133
+ suggested_radius = np.array(nerf_normalization["radius"], dtype=np.float32)
134
+ suggested_bounding = np.stack([
135
+ suggested_center - suggested_radius,
136
+ suggested_center + suggested_radius,
137
+ ])
138
+ else:
139
+ suggested_bounding = None
140
+
141
+ # Pack dataset
142
+ dataset = {
143
+ 'train_cam_lst': train_cam_lst,
144
+ 'test_cam_lst': test_cam_lst,
145
+ 'suggested_bounding': suggested_bounding,
146
+ 'point_cloud': point_cloud,
147
+ }
148
+ return dataset
src/dataloader/reader_nerf_dataset.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import json
11
+ import pycolmap
12
+ import numpy as np
13
+ from PIL import Image
14
+ from pathlib import Path
15
+ import concurrent.futures
16
+
17
+ from src.utils.colmap_utils import parse_colmap_pts
18
+ from src.utils.camera_utils import fov2focal, focal2fov
19
+
20
+
21
+ def read_nerf_dataset(source_path, test_every, use_test, camera_creator):
22
+
23
+ source_path = Path(source_path)
24
+
25
+ # Load training cameras
26
+ if (source_path / "transforms_train.json").exists():
27
+ train_cam_lst, point_cloud = read_cameras_from_json(
28
+ source_path=source_path,
29
+ meta_fname="transforms_train.json",
30
+ camera_creator=camera_creator)
31
+ else:
32
+ train_cam_lst, point_cloud = read_cameras_from_json(
33
+ source_path=source_path,
34
+ meta_fname="transforms.json",
35
+ camera_creator=camera_creator)
36
+
37
+ # Load testing cameras
38
+ if (source_path / "transforms_test.json").exists():
39
+ test_cam_lst, _ = read_cameras_from_json(
40
+ source_path=source_path,
41
+ meta_fname="transforms_test.json",
42
+ camera_creator=camera_creator)
43
+ elif use_test:
44
+ test_cam_lst = [
45
+ cam for i, cam in enumerate(train_cam_lst)
46
+ if i % test_every == 0]
47
+ train_cam_lst = [
48
+ cam for i, cam in enumerate(train_cam_lst)
49
+ if i % test_every != 0]
50
+ else:
51
+ test_cam_lst = []
52
+
53
+ # Parse the main scene bound if one is provided
54
+ nerf_normalization_path = os.path.join(source_path, "nerf_normalization.json")
55
+ if os.path.isfile(nerf_normalization_path):
56
+ with open(nerf_normalization_path) as f:
57
+ nerf_normalization = json.load(f)
58
+ suggested_center = np.array(nerf_normalization["center"], dtype=np.float32)
59
+ suggested_radius = np.array(nerf_normalization["radius"], dtype=np.float32)
60
+ suggested_bounding = np.stack([
61
+ suggested_center - suggested_radius,
62
+ suggested_center + suggested_radius,
63
+ ])
64
+ else:
65
+ # Assume synthetic blender scene bound
66
+ suggested_bounding = np.array([
67
+ [-1.5, -1.5, -1.5],
68
+ [1.5, 1.5, 1.5],
69
+ ], dtype=np.float32)
70
+
71
+ # Pack dataset
72
+ dataset = {
73
+ 'train_cam_lst': train_cam_lst,
74
+ 'test_cam_lst': test_cam_lst,
75
+ 'suggested_bounding': suggested_bounding,
76
+ 'point_cloud': point_cloud,
77
+ }
78
+ return dataset
79
+
80
+
81
+ def read_cameras_from_json(source_path, meta_fname, camera_creator):
82
+
83
+ with open(source_path / meta_fname) as f:
84
+ meta = json.load(f)
85
+
86
+ # Load COLMAP points if available
87
+ if "colmap" in meta:
88
+ sfm = pycolmap.Reconstruction(source_path / meta["colmap"]["path"])
89
+ if "transform" in meta["colmap"]:
90
+ transform = np.array(meta["colmap"]["transform"])
91
+ else:
92
+ transform = None
93
+ point_cloud = parse_colmap_pts(sfm, transform)
94
+ correspondent = point_cloud.corr
95
+ else:
96
+ point_cloud = None
97
+ correspondent = None
98
+
99
+ # Load global setup
100
+ global_fovx = meta.get("camera_angle_x", 0)
101
+ global_fovy = meta.get("camera_angle_y", 0)
102
+ global_cx_p = parse_principle_point(meta, is_cx=True)
103
+ global_cy_p = parse_principle_point(meta, is_cx=False)
104
+
105
+ # Load all images and cameras
106
+ todo_lst = []
107
+ for frame in meta["frames"]:
108
+
109
+ # Guess the rgb image path and load image
110
+ path_candidates = [
111
+ source_path / frame["file_path"],
112
+ source_path / (frame["file_path"] + '.png'),
113
+ source_path / (frame["file_path"] + '.jpg'),
114
+ source_path / (frame["file_path"] + '.JPG'),
115
+ ]
116
+ for image_path in path_candidates:
117
+ if image_path.exists():
118
+ break
119
+
120
+ if frame.get('heldout', False):
121
+ image = Image.new('RGB', (frame['w'], frame['h']))
122
+ elif image_path.exists():
123
+ image = Image.open(image_path)
124
+ else:
125
+ raise Exception(f"File not found: {str(image_path)}")
126
+
127
+ # Load camera intrinsic
128
+ fovx = frame.get('camera_angle_x', global_fovx)
129
+ cx_p = frame.get('cx_p', global_cx_p)
130
+ cy_p = frame.get('cy_p', global_cy_p)
131
+
132
+ if 'camera_angle_y' in frame:
133
+ fovy = frame['camera_angle_y']
134
+ elif global_fovy > 0:
135
+ fovy = global_fovy
136
+ else:
137
+ fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
138
+
139
+ # Load camera pose
140
+ c2w = np.array(frame["transform_matrix"])
141
+ c2w[:3, 1:3] *= -1 # from opengl y-up-z-back to colmap y-down-z-forward
142
+ w2c = np.linalg.inv(c2w).astype(np.float32)
143
+
144
+ # Load sparse point
145
+ if point_cloud is not None:
146
+ sparse_pt = point_cloud.points[correspondent[image_path.name]]
147
+ else:
148
+ sparse_pt = None
149
+
150
+ todo_lst.append(dict(
151
+ image=image,
152
+ w2c=w2c,
153
+ fovx=fovx,
154
+ fovy=fovy,
155
+ cx_p=cx_p,
156
+ cy_p=cy_p,
157
+ sparse_pt=sparse_pt,
158
+ image_name=image_path.name,
159
+ ))
160
+
161
+ # Load all cameras concurrently
162
+ import torch
163
+ torch.inverse(torch.eye(3, device="cuda")) # Fix module lazy loading bug:
164
+ # https://github.com/pytorch/pytorch/issues/90613
165
+
166
+ with concurrent.futures.ThreadPoolExecutor() as executor:
167
+ futures = [executor.submit(camera_creator, **todo) for todo in todo_lst]
168
+ cam_lst = [f.result() for f in futures]
169
+
170
+ return cam_lst, point_cloud
171
+
172
+
173
+ def parse_principle_point(info, is_cx):
174
+ key = "cx" if is_cx else "cy"
175
+ key_res = "w" if is_cx else "h"
176
+ if f"{key}_p" in info:
177
+ return info[f"{key}_p"]
178
+ if key in info and key_res in info:
179
+ return info[key] / info[key_res]
180
+ return None
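Note: a worked example of the two per-frame conversions performed by this reader, assuming `fov2focal`/`focal2fov` implement the standard pinhole relation focal = W / (2 * tan(fov / 2)); the numbers are illustrative.

    import math
    import numpy as np

    # transforms.json stores camera-to-world in the OpenGL/Blender convention
    # (y up, camera looks down -z); flipping the y and z columns gives the
    # COLMAP convention (y down, camera looks down +z) used downstream.
    c2w = np.eye(4)
    c2w[:3, 3] = [0.0, 0.0, 4.0]           # camera 4 units from the origin
    c2w[:3, 1:3] *= -1
    w2c = np.linalg.inv(c2w).astype(np.float32)

    # fovy is derived from fovx through the shared focal length.
    fovx, W, H = math.radians(60.0), 800, 600
    focal = W / (2 * math.tan(fovx / 2))    # fov2focal(fovx, W)
    fovy = 2 * math.atan(H / (2 * focal))   # focal2fov(focal, H)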
src/dataloader/reader_nerf_dataset_copy.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import json
11
+ import pycolmap
12
+ import numpy as np
13
+ from PIL import Image
14
+ from pathlib import Path
15
+
16
+ from src.utils.colmap_utils import parse_colmap_pts
17
+ from src.utils.camera_utils import fov2focal, focal2fov
18
+
19
+
20
+ def read_nerf_dataset(source_path, test_every, use_test, camera_creator):
21
+
22
+ source_path = Path(source_path)
23
+
24
+ # Load training cameras
25
+ if (source_path / "transforms_train.json").exists():
26
+ train_cam_lst, point_cloud = read_cameras_from_json(
27
+ source_path=source_path,
28
+ meta_fname="transforms_train.json",
29
+ camera_creator=camera_creator)
30
+ else:
31
+ train_cam_lst, point_cloud = read_cameras_from_json(
32
+ source_path=source_path,
33
+ meta_fname="transforms.json",
34
+ camera_creator=camera_creator)
35
+
36
+ # Load testing cameras
37
+ if (source_path / "transforms_test.json").exists():
38
+ test_cam_lst, _ = read_cameras_from_json(
39
+ source_path=source_path,
40
+ meta_fname="transforms_test.json",
41
+ camera_creator=camera_creator)
42
+ elif use_test:
43
+ test_cam_lst = [
44
+ cam for i, cam in enumerate(train_cam_lst)
45
+ if i % test_every == 0]
46
+ train_cam_lst = [
47
+ cam for i, cam in enumerate(train_cam_lst)
48
+ if i % test_every != 0]
49
+ else:
50
+ test_cam_lst = []
51
+
52
+ # Parse main scene bound if there is
53
+ nerf_normalization_path = os.path.join(source_path, "nerf_normalization.json")
54
+ if os.path.isfile(nerf_normalization_path):
55
+ with open(nerf_normalization_path) as f:
56
+ nerf_normalization = json.load(f)
57
+ suggested_center = np.array(nerf_normalization["center"], dtype=np.float32)
58
+ suggested_radius = np.array(nerf_normalization["radius"], dtype=np.float32)
59
+ suggested_bounding = np.stack([
60
+ suggested_center - suggested_radius,
61
+ suggested_center + suggested_radius,
62
+ ])
63
+ else:
64
+ # Assume synthetic blender scene bound
65
+ suggested_bounding = np.array([
66
+ [-1.5, -1.5, -1.5],
67
+ [1.5, 1.5, 1.5],
68
+ ], dtype=np.float32)
69
+
70
+ # Pack dataset
71
+ dataset = {
72
+ 'train_cam_lst': train_cam_lst,
73
+ 'test_cam_lst': test_cam_lst,
74
+ 'suggested_bounding': suggested_bounding,
75
+ 'point_cloud': point_cloud,
76
+ }
77
+ return dataset
78
+
79
+
80
+ def read_cameras_from_json(source_path, meta_fname, camera_creator):
81
+
82
+ with open(source_path / meta_fname) as f:
83
+ meta = json.load(f)
84
+
85
+ # Load COLMAP points if there is
86
+ if "colmap" in meta:
87
+ sfm = pycolmap.Reconstruction(source_path / meta["colmap"]["path"])
88
+ if "transform" in meta["colmap"]:
89
+ transform = np.array(meta["colmap"]["transform"])
90
+ else:
91
+ transform = None
92
+ point_cloud = parse_colmap_pts(sfm, transform)
93
+ correspondent = point_cloud.corr
94
+ else:
95
+ point_cloud = None
96
+ correspondent = None
97
+
98
+ # Load global setup
99
+ global_fovx = meta.get("camera_angle_x", 0)
100
+ global_fovy = meta.get("camera_angle_y", 0)
101
+ global_cx_p = parse_principle_point(meta, is_cx=True)
102
+ global_cy_p = parse_principle_point(meta, is_cx=False)
103
+
104
+ # Load all images and cameras
105
+ cam_lst = []
106
+ for frame in meta["frames"]:
107
+
108
+ # Guess the rgb image path and load image
109
+ path_candidates = [
110
+ source_path / frame["file_path"],
111
+ source_path / (frame["file_path"] + '.png'),
112
+ source_path / (frame["file_path"] + '.jpg'),
113
+ source_path / (frame["file_path"] + '.JPG'),
114
+ ]
115
+ for image_path in path_candidates:
116
+ if image_path.exists():
117
+ break
118
+
119
+ if frame.get('heldout', False):
120
+ image = Image.new('RGB', (frame['w'], frame['h']))
121
+ elif image_path.exists():
122
+ image = Image.open(image_path)
123
+ else:
124
+ raise Exception(f"File not found: {str(image_path)}")
125
+
126
+ # Load camera intrinsic
127
+ fovx = frame.get('camera_angle_x', global_fovx)
128
+ cx_p = frame.get('cx_p', global_cx_p)
129
+ cy_p = frame.get('cy_p', global_cy_p)
130
+
131
+ if 'camera_angle_y' in frame:
132
+ fovy = frame['camera_angle_y']
133
+ elif global_fovy > 0:
134
+ fovy = global_fovy
135
+ else:
136
+ fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
137
+
138
+ # Load camera pose
139
+ c2w = np.array(frame["transform_matrix"])
140
+ c2w[:3, 1:3] *= -1 # from opengl y-up-z-back to colmap y-down-z-forward
141
+ w2c = np.linalg.inv(c2w).astype(np.float32)
142
+
143
+ # Load sparse point
144
+ if point_cloud is not None:
145
+ sparse_pt = point_cloud.points[correspondent[image_path.name]]
146
+ else:
147
+ sparse_pt = None
148
+
149
+ cam_lst.append(camera_creator(
150
+ image=image,
151
+ w2c=w2c,
152
+ fovx=fovx,
153
+ fovy=fovy,
154
+ cx_p=cx_p,
155
+ cy_p=cy_p,
156
+ sparse_pt=sparse_pt,
157
+ image_name=image_path.name,
158
+ ))
159
+
160
+ return cam_lst, point_cloud
161
+
162
+
163
+ def parse_principle_point(info, is_cx):
164
+ key = "cx" if is_cx else "cy"
165
+ key_res = "w" if is_cx else "h"
166
+ if f"{key}_p" in info:
167
+ return info[f"{key}_p"]
168
+ if key in info and key_res in info:
169
+ return info[key] / info[key_res]
170
+ return None
src/sparse_voxel_gears/__pycache__/adaptive.cpython-39.pyc ADDED
Binary file (6.77 kB). View file
 
src/sparse_voxel_gears/__pycache__/constructor.cpython-39.pyc ADDED
Binary file (8 kB). View file
 
src/sparse_voxel_gears/__pycache__/io.cpython-39.pyc ADDED
Binary file (4.76 kB). View file
 
src/sparse_voxel_gears/__pycache__/pooling.cpython-39.pyc ADDED
Binary file (1.81 kB). View file
 
src/sparse_voxel_gears/__pycache__/properties.cpython-39.pyc ADDED
Binary file (5.06 kB). View file
 
src/sparse_voxel_gears/__pycache__/renderer.cpython-39.pyc ADDED
Binary file (3.74 kB). View file
 
src/sparse_voxel_gears/adaptive.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+
11
+ from src.utils import octree_utils
12
+
13
+ '''
14
+ Adaptive sparse voxel pruning and subdivision.
15
+ There are three types of data modes to handle.
16
+
17
+ 1. Per-voxel attribute:
18
+ Each voxel has its own non-trainable data field.
19
+
20
+ 2. Per-voxel parameters:
21
+ Similar to per-voxel attributes, but these are trainable parameters.
22
+
23
+ 3. Grid points parameters:
24
+ The trainable parameters are attached to the eight grid points of each voxel.
25
+ A grid point parameter can be shared by adjacent voxels.
26
+ '''
27
+
28
+ class SVAdaptive:
29
+
30
+ @torch.no_grad()
31
+ def pruning(self, prune_mask):
32
+ '''
33
+ Prune sparse voxels. The grid points are updated accordingly.
34
+
35
+ Input:
36
+ @prune_mask [N] Mask indicating the voxels to prune.
37
+ '''
38
+ if len(prune_mask.shape) == 2:
39
+ assert prune_mask.shape[1] == 1
40
+ prune_mask = prune_mask.squeeze(1)
41
+ assert prune_mask.shape == (self.num_voxels, )
42
+ kept_idx = (~prune_mask).argwhere().squeeze(1)
43
+ if len(kept_idx) == 0:
44
+ return
45
+
46
+ old_vox_key = self.vox_key.clone()
47
+
48
+ # Prune non-trainable per-voxel attributes.
49
+ for name in self.per_voxel_attr_lst:
50
+ ori_attr = getattr(self, name)
51
+ new_attr = mask_cat_perm(ori_attr, kept_idx=kept_idx)
52
+ setattr(self, name, new_attr)
53
+ if name == '_subdiv_p' and ori_attr.grad is not None:
54
+ self._subdiv_p.grad = mask_cat_perm(ori_attr.grad, kept_idx=kept_idx)
55
+ self._subdiv_p.requires_grad_()
56
+ del ori_attr
57
+ torch.cuda.empty_cache()
58
+
59
+ # Prune trainable per-voxel parameters.
60
+ for name in self.per_voxel_param_lst:
61
+ ori_param = getattr(self, name).detach()
62
+ new_param = mask_cat_perm(
63
+ ori_param,
64
+ kept_idx=kept_idx).requires_grad_()
65
+ setattr(self, name, new_param)
66
+ del ori_param, new_param
67
+ torch.cuda.empty_cache()
68
+
69
+ # Prune trainable grid points parameters (on voxel corners).
70
+ for name in self.grid_pts_param_lst:
71
+ ori_grid_pts = getattr(self, name).detach()
72
+
73
+ # Update parameter
74
+ ori_vox_grid_pts_val = ori_grid_pts[old_vox_key]
75
+ new_vox_val = mask_cat_perm(
76
+ ori_vox_grid_pts_val,
77
+ kept_idx=kept_idx)
78
+ new_param = agg_voxel_into_grid_pts(
79
+ self.num_grid_pts, # It's the updated one
80
+ self.vox_key,
81
+ new_vox_val).requires_grad_()
82
+ setattr(self, name, new_param)
83
+ del ori_grid_pts, ori_vox_grid_pts_val, new_vox_val, new_param
84
+ torch.cuda.empty_cache()
85
+
86
+ @torch.no_grad()
87
+ def subdividing(self, subdivide_mask):
88
+ '''
89
+ Subdivide sparse voxels. The grid points are updated accordingly.
90
+
91
+ Input:
92
+ @subdivide_mask [N] Mask indicating the voxels to subdivide.
93
+ '''
94
+ # Compute voxel indices to keep and to subdivide
95
+ if len(subdivide_mask.shape) == 2:
96
+ assert subdivide_mask.shape[1] == 1
97
+ subdivide_mask = subdivide_mask.squeeze(1)
98
+ assert subdivide_mask.shape == (self.num_voxels, )
99
+ kept_idx = (~subdivide_mask).argwhere().squeeze(1)
100
+ subdivide_idx = subdivide_mask.argwhere().squeeze(1)
101
+ if len(subdivide_idx) == 0:
102
+ return
103
+
104
+ old_vox_key = self.vox_key.clone()
105
+
106
+ # Subdivide non-trainable per-voxel attributes.
107
+ octpath, octlevel = octree_utils.gen_children(
108
+ self.octpath[subdivide_idx],
109
+ self.octlevel[subdivide_idx])
110
+
111
+ special_subdiv = dict(
112
+ octpath=octpath,
113
+ octlevel=octlevel,
114
+ )
115
+
116
+ for name in self.per_voxel_attr_lst:
117
+ ori_attr = getattr(self, name)
118
+ if name in special_subdiv:
119
+ subdiv_attr = special_subdiv.pop(name)
120
+ else:
121
+ subdiv_attr = ori_attr[subdivide_idx].repeat_interleave(8, dim=0)
122
+ new_attr = mask_cat_perm(
123
+ ori_attr,
124
+ kept_idx=kept_idx,
125
+ cat_tensor=subdiv_attr)
126
+ setattr(self, name, new_attr)
127
+ if name == '_subdiv_p' and ori_attr.grad is not None:
128
+ self._subdiv_p.grad = mask_cat_perm(
129
+ ori_attr.grad,
130
+ kept_idx=kept_idx,
131
+ cat_tensor=subdiv_attr)
132
+ self._subdiv_p.requires_grad_()
133
+ del ori_attr, subdiv_attr
134
+
135
+ assert len(special_subdiv) == 0
136
+ torch.cuda.empty_cache()
137
+
138
+ # Subdivide trainable per-voxel parameters.
139
+ for name in self.per_voxel_param_lst:
140
+ ori_param = getattr(self, name).detach()
141
+
142
+ # Update parameter
143
+ subdiv_param = ori_param[subdivide_idx].repeat_interleave(8, dim=0)
144
+ new_param = mask_cat_perm(
145
+ ori_param,
146
+ kept_idx=kept_idx,
147
+ cat_tensor=subdiv_param).requires_grad_()
148
+ setattr(self, name, new_param)
149
+ del ori_param, subdiv_param, new_param
150
+ torch.cuda.empty_cache()
151
+
152
+ # Subdivide grid points parameters (on voxel corners).
153
+ for name in self.grid_pts_param_lst:
154
+ ori_grid_pts = getattr(self, name).detach()
155
+
156
+ # Update parameter
157
+ # First, gather the grid_pts values into each voxel.
158
+ # The voxel is then subdivided by trilinear interpolation.
159
+ # Finally, we gather voxel values back to the grid_pts.
160
+ ori_vox_grid_pts_val = ori_grid_pts[old_vox_key]
161
+ subdiv_vox_grid_pts_val = subdivide_by_interp(
162
+ ori_vox_grid_pts_val[subdivide_idx])
163
+ new_vox_val = mask_cat_perm(
164
+ ori_vox_grid_pts_val,
165
+ kept_idx=kept_idx,
166
+ cat_tensor=subdiv_vox_grid_pts_val)
167
+ del ori_grid_pts, ori_vox_grid_pts_val, subdiv_vox_grid_pts_val
168
+
169
+ new_param = agg_voxel_into_grid_pts(
170
+ self.num_grid_pts, # It's the updated one
171
+ self.vox_key,
172
+ new_vox_val).cuda().requires_grad_()
173
+ setattr(self, name, new_param)
174
+ del new_vox_val, new_param
175
+ torch.cuda.empty_cache()
176
+
177
+ @torch.no_grad()
178
+ def sh_degree_add1(self):
179
+ if self.active_sh_degree < self.max_sh_degree:
180
+ self.active_sh_degree += 1
181
+
182
+ @torch.no_grad()
183
+ def compute_training_stat(self, camera_lst):
184
+ '''
185
+ Compute the following statistic of each voxel from the given cameras.
186
+ 1. max_w: the maximum blending weight.
187
+ 2. min_samp_interval: the minimum sampling interval (inverse of maximum sampling rate).
188
+ 3. view_cnt: number of cameras with non-zero blending weight.
189
+
190
+ Input:
191
+ @camera_lst [Camera, ...] A list of cameras.
192
+ '''
193
+ self.freeze_vox_geo()
194
+ max_w = torch.zeros([self.num_voxels, 1], dtype=torch.float32, device="cuda")
195
+ min_samp_interval = torch.full([self.num_voxels, 1], 1e30, dtype=torch.float32, device="cuda")
196
+ view_cnt = torch.zeros([self.num_voxels, 1], dtype=torch.float32, device="cuda")
197
+ for camera in camera_lst:
198
+ max_w_i = self.render(camera, color_mode='dontcare', track_max_w=True)['max_w']
199
+ max_w = torch.maximum(max_w, max_w_i)
200
+
201
+ vis_idx = (max_w_i > 0).squeeze().argwhere().squeeze()
202
+ zdist = ((self.vox_center[vis_idx] - camera.position) * camera.lookat).sum(-1, keepdims=True)
203
+ samp_interval = zdist * camera.pix_size
204
+ min_samp_interval[vis_idx] = torch.minimum(min_samp_interval[vis_idx], samp_interval)
205
+
206
+ view_cnt[vis_idx] += 1
207
+
208
+ stat_pkg = {
209
+ 'max_w': max_w,
210
+ 'min_samp_interval': min_samp_interval,
211
+ 'view_cnt': view_cnt,
212
+ }
213
+ self.unfreeze_vox_geo()
214
+ return stat_pkg
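Note: a sketch of how these statistics are commonly combined with the pruning/subdividing calls above, assuming `model` mixes in SVAdaptive and `train_cameras` is the camera list from the data loader; the thresholds are illustrative, not values prescribed by this repo.

    stat = model.compute_training_stat(train_cameras)
    model.pruning(stat['max_w'] < 1e-4)        # drop voxels that never contribute

    stat = model.compute_training_stat(train_cameras)
    model.subdividing(stat['max_w'] > 0.5)     # refine voxels with strong blending weight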
215
+
216
+
217
+ # Some helpful functions
218
+ def mask_cat_perm(tensor, kept_idx=None, cat_tensor=None, perm=None):
219
+ '''
220
+ Perform tensor masking, concatenation, and permutation.
221
+ '''
222
+ if kept_idx is None and cat_tensor is None and perm is None:
223
+ raise Exception("No op for mask_cat_perm??")
224
+ device = tensor.device
225
+ if kept_idx is not None:
226
+ tensor = tensor[kept_idx.to(device)]
227
+ if cat_tensor is not None:
228
+ tensor = torch.cat([tensor, cat_tensor.to(device)])
229
+ if perm is not None:
230
+ assert len(perm) == len(tensor)
231
+ tensor = tensor[perm.to(device)]
232
+ return tensor.contiguous()
233
+
234
+ def agg_voxel_into_grid_pts(num_grid_pts, vox_key, vox_val, reduce='mean'):
235
+ '''
236
+ Aggregate per-voxel data into their eight grid points.
237
+ Input:
238
+ @num_grid_pts Number of final grid points.
239
+ @vox_key [N, 8] Index to the eight grid points of each voxel.
240
+ @vox_val [N, 8, *] Data of the eight grid points of each voxel.
241
+ Output:
242
+ @new_param [num_grid_pts, *] Grid points data aggregated from vox_val.
243
+ '''
244
+ ch = vox_val.shape[2:]
245
+ device = vox_val.device
246
+ vox_key = vox_key.to(device)
247
+ new_param = torch.zeros([num_grid_pts, *ch], dtype=torch.float32, device=device)
248
+ new_param.index_reduce_(
249
+ dim=0,
250
+ index=vox_key.flatten(),
251
+ source=vox_val.flatten(0,1),
252
+ reduce=reduce,
253
+ include_self=False)
254
+ # Equivalent implementation by old API
255
+ # new_param /= vox_key.flatten().bincount(minlength=num_grid_pts).unsqueeze(-1)
256
+ # new_param.nan_to_num_()
257
+ return new_param.contiguous()
258
+
259
+ def subdivide_by_interp(vox_val):
260
+ '''
261
+ Subdivide grid point data by trilinear interpolation.
262
+ The subdivided children order is the same as those from `_subdivide_attr` and `gen_children`.
263
+ Input:
264
+ @vox_val [N, 8, *] Data of the eight grid points of each voxel.
265
+ Output:
266
+ @new_vox_val [8N, 8, *] Data of the eight grid points of the subdivided voxel.
267
+ '''
268
+ vox_val = vox_val.contiguous()
269
+ main_idx = torch.arange(8, dtype=torch.int64, device=vox_val.device)
270
+ new_vox_val = torch.zeros([len(vox_val), 8, *vox_val.shape[1:]], device=vox_val.device)
271
+ new_vox_val[:, main_idx, main_idx] = vox_val
272
+ new_vox_val[:, main_idx, main_idx^0b001] = 0.5 * (vox_val + vox_val[:, main_idx^0b001])
273
+ new_vox_val[:, main_idx, main_idx^0b010] = 0.5 * (vox_val + vox_val[:, main_idx^0b010])
274
+ new_vox_val[:, main_idx, main_idx^0b100] = 0.5 * (vox_val + vox_val[:, main_idx^0b100])
275
+ new_vox_val[:, main_idx, main_idx^0b011] = 0.25 * (
276
+ vox_val + \
277
+ vox_val[:, main_idx^0b001] + \
278
+ vox_val[:, main_idx^0b010] + \
279
+ vox_val[:, main_idx^0b011]
280
+ )
281
+ new_vox_val[:, main_idx, main_idx^0b101] = 0.25 * (
282
+ vox_val + \
283
+ vox_val[:, main_idx^0b001] + \
284
+ vox_val[:, main_idx^0b100] + \
285
+ vox_val[:, main_idx^0b101]
286
+ )
287
+ new_vox_val[:, main_idx, main_idx^0b110] = 0.25 * (
288
+ vox_val + \
289
+ vox_val[:, main_idx^0b010] + \
290
+ vox_val[:, main_idx^0b100] + \
291
+ vox_val[:, main_idx^0b110]
292
+ )
293
+ new_vox_val[:, main_idx, main_idx^0b111] = vox_val.mean(1, keepdim=True)
294
+
295
+ new_vox_val = new_vox_val.reshape(len(vox_val)*8, *vox_val.shape[1:])
296
+ return new_vox_val.contiguous()
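Note: a tiny numeric check of the corner bookkeeping above; the corner index follows the documented order [0,0,0], [0,0,1], [0,1,0], ..., so XOR-ing an index with 0b111 addresses the diagonally opposite corner. Runs on CPU.

    import torch

    # One parent voxel whose eight corner values are just the corner indices 0..7.
    vox_val = torch.arange(8, dtype=torch.float32).view(1, 8, 1)
    children = subdivide_by_interp(vox_val)            # shape [8, 8, 1]

    # Each child keeps the parent's value at the corner it shares with the parent,
    # and the corner shared by all eight children averages the whole parent cell.
    assert children[3, 3, 0] == vox_val[0, 3, 0]
    assert torch.isclose(children[3, 3 ^ 0b111, 0], vox_val.mean())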
src/sparse_voxel_gears/constructor.py ADDED
@@ -0,0 +1,425 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import numpy as np
11
+ import svraster_cuda
12
+
13
+ from src.utils.activation_utils import rgb2shzero
14
+ from src.utils import octree_utils
15
+
16
+ class SVConstructor:
17
+
18
+ def model_init(self,
19
+ bounding, # Scene bound [min_xyz, max_xyz]
20
+ outside_level, # Number of Octree levels for background
21
+ init_n_level=6, # Starting from (2^init_n_level)^3 voxels
22
+ init_out_ratio=2.0, # Ratio of outside (background) voxel count to inside voxel count
23
+ sh_degree_init=3, # Initial activated sh degree
24
+ geo_init=-10.0, # Init pre-activation density
25
+ sh0_init=0.5, # Init voxel colors in range [0,1]
26
+ shs_init=0.0, # Init coefficients of higher-degree sh
27
+ cameras=None, # Cameras that help voxel allocation
28
+ ):
29
+
30
+ assert outside_level <= svraster_cuda.meta.MAX_NUM_LEVELS
31
+
32
+ # Define scene bound
33
+ center = (bounding[0] + bounding[1]) * 0.5
34
+ extent = max(bounding[1] - bounding[0])
35
+ self.scene_center, self.scene_extent, self.inside_extent = get_scene_bound_tensor(
36
+ center=center, extent=extent, outside_level=outside_level)
37
+
38
+ # Init voxel layout.
39
+ # The world is separated into inside (main foreground) and outside (background) regions.
40
+ in_path, in_level = octlayout_inside_uniform(
41
+ scene_center=self.scene_center,
42
+ scene_extent=self.scene_extent,
43
+ outside_level=outside_level,
44
+ n_level=init_n_level,
45
+ cameras=cameras,
46
+ filter_zero_visiblity=(cameras is not None),
47
+ filter_near=-1)
48
+
49
+ if outside_level == 0:
50
+ # Object centric bounded scenes
51
+ ou_path = torch.empty([0, 1], dtype=in_path.dtype, device="cuda")
52
+ ou_level = torch.empty([0, 1], dtype=in_level.dtype, device="cuda")
53
+ else:
54
+ min_num = len(in_path) * init_out_ratio
55
+ max_level = outside_level + init_n_level
56
+ ou_path, ou_level = octlayout_outside_heuristic(
57
+ scene_center=self.scene_center,
58
+ scene_extent=self.scene_extent,
59
+ outside_level=outside_level,
60
+ cameras=cameras,
61
+ min_num=min_num,
62
+ max_level=max_level,
63
+ filter_near=-1)
64
+
65
+ self.octpath = torch.cat([ou_path, in_path])
66
+ self.octlevel = torch.cat([ou_level, in_level])
67
+
68
+ self.active_sh_degree = min(sh_degree_init, self.max_sh_degree)
69
+
70
+ # Init trainable parameters
71
+ self._geo_grid_pts = torch.full(
72
+ [self.num_grid_pts, 1], geo_init,
73
+ dtype=torch.float32, device="cuda").requires_grad_()
74
+
75
+ self._sh0 = torch.full(
76
+ [self.num_voxels, 3], rgb2shzero(sh0_init),
77
+ dtype=torch.float32, device="cuda").requires_grad_()
78
+
79
+ self._shs = torch.full(
80
+ [self.num_voxels, (self.max_sh_degree+1)**2 - 1, 3], shs_init,
81
+ dtype=torch.float32, device="cuda").requires_grad_()
82
+
83
+ # Subdivision priority tracker
84
+ self._subdiv_p = torch.ones(
85
+ [self.num_voxels, 1],
86
+ dtype=torch.float32, device="cuda").requires_grad_()
87
+
88
+ def octpath_init(self,
89
+ scene_center,
90
+ scene_extent,
91
+ octpath, # Nx1 octpath.
92
+ octlevel, # Nx1 or scalar for the Octree level of each voxel.
93
+
94
+ # The following are model parameters.
95
+ # If the input are tensors, the gradient of rendering can be backprop to them.
96
+ # Otherwise, it creates new trainable tensors.
97
+ rgb=0.5, # Nx3 or scalar for voxel color in range of 0~1.
98
+ shs=0.0, # NxDx3 or scalar for voxel higher-deg sh coefficient.
99
+ density=-10., # Nx8 or Ngridx1 or scalar for voxel density field.
100
+ # The order is [0,0,0] => [0,0,1] => [0,1,0] => [0,1,1] ...
101
+ reduce_density=False, # Whether to merge grid points if density is Nx8.
102
+ ):
103
+
104
+ self.scene_center, self.scene_extent, self.inside_extent = get_scene_bound_tensor(
105
+ center=scene_center, extent=scene_extent)
106
+
107
+ assert torch.is_tensor(octpath)
108
+ octlevel = get_octlevel_tensor(octlevel, num_voxels=len(octpath))
109
+
110
+ self.octpath = octpath.view(-1, 1).contiguous()
111
+ self.octlevel = octlevel.view(-1, 1).contiguous()
112
+ assert len(self.octpath) == len(self.octlevel)
113
+
114
+ # Subdivision priority tracker
115
+ self._subdiv_p = torch.ones(
116
+ [self.num_voxels, 1],
117
+ dtype=torch.float32, device="cuda").requires_grad_()
118
+
119
+ # Set up appearance parameters
120
+ if torch.is_tensor(rgb):
121
+ assert rgb.shape == (self.num_voxels, 3)
122
+ self._sh0 = rgb2shzero(rgb.contiguous().cuda())
123
+ else:
124
+ self._sh0 = torch.full(
125
+ [self.num_voxels, 3], rgb2shzero(rgb),
126
+ dtype=torch.float32, device="cuda").requires_grad_()
127
+
128
+ if torch.is_tensor(shs):
129
+ assert shs.shape == (self.num_voxels, (self.max_sh_degree+1)**2 - 1, 3)
130
+ self._shs = shs.contiguous().cuda()
131
+ else:
132
+ self._shs = torch.full(
133
+ [self.num_voxels, (self.max_sh_degree+1)**2 - 1, 3], shs,
134
+ dtype=torch.float32, device="cuda").requires_grad_()
135
+
136
+ # Setup geometry parameters
137
+ if torch.is_tensor(density):
138
+ if density.shape == (self.num_grid_pts, 1):
139
+ self._geo_grid_pts = density.contiguous().cuda()
140
+ elif density.shape == (self.num_voxels, 8):
141
+ if reduce_density:
142
+ self._geo_grid_pts = torch.zeros(
143
+ [self.num_grid_pts, 1], dtype=torch.float32, device="cuda")
144
+ self._geo_grid_pts.index_reduce_(
145
+ dim=0,
146
+ index=self.vox_key.flatten(),
147
+ source=density.reshape(-1, 1), # keep 2D so the shape matches the [num_grid_pts, 1] target
148
+ reduce="mean",
149
+ include_self=False)
150
+ else:
151
+ self.frozen_vox_geo = density.contiguous().cuda()
152
+ else:
153
+ raise Exception(f"Unexpected density shape. "
154
+ f"It should be either {(self.num_grid_pts,1)} or {(self.num_voxels,8)}")
155
+ else:
156
+ self._geo_grid_pts = torch.full(
157
+ [self.num_grid_pts, 1], density,
158
+ dtype=torch.float32, device="cuda").requires_grad_()
159
+
160
+ def ijkl_init(self,
161
+ scene_center,
162
+ scene_extent,
163
+ ijk, # Nx3 integer coordinates of each voxel.
164
+ octlevel, # Nx1 or scalar for the Octree level of each voxel.
165
+
166
+ # The following are model parameters.
167
+ # If the input are tensors, the gradient of rendering can be backprop to them.
168
+ # Otherwise, it creates new trainable tensors.
169
+ rgb=0.5, # Nx3 or scalar for voxel color in range of 0~1.
170
+ shs=0.0, # NxDx3 or scalar for voxel higher-deg sh coefficient.
171
+ density=-10., # Nx8 or Ngridx1 or scalar for voxel density field.
172
+ # The order is [0,0,0] => [0,0,1] => [0,1,0] => [0,1,1] ...
173
+ reduce_density=False, # Whether to merge grid points if density is Nx8.
174
+ ):
175
+
176
+ scene_center, scene_extent, _ = get_scene_bound_tensor(
177
+ center=scene_center, extent=scene_extent)
178
+
179
+ # Convert ijk integer coordinates plus octlevel to octpath
180
+ octlevel = get_octlevel_tensor(octlevel, num_voxels=len(ijk))
181
+
182
+ assert torch.is_tensor(ijk)
183
+ assert len(ijk.shape) == 2 and ijk.shape[1] == 3
184
+ assert len(ijk) == len(octlevel)
185
+ ijk = ijk.long()
186
+ if (ijk < 0).any():
187
+ raise Exception("xyz out of scene bound")
188
+ if (ijk >= (1 << octlevel.long())).any():
189
+ raise Exception("xyz out of scene bound")
190
+ octpath = svraster_cuda.utils.ijk_2_octpath(ijk, octlevel)
191
+
192
+ self.octpath_init(
193
+ scene_center=scene_center,
194
+ scene_extent=scene_extent,
195
+ octpath=octpath,
196
+ octlevel=octlevel,
197
+ rgb=rgb,
198
+ shs=shs,
199
+ density=density,
200
+ reduce_density=reduce_density)
201
+
202
+ def points_init(self,
203
+ scene_center,
204
+ scene_extent,
205
+ xyz, # Nx3 point coordinates in world space.
206
+ octlevel=None, # Nx1 or scalar for the Octree level of each voxel.
207
+ expected_vox_size=None,
208
+ level_round_mode='nearest',
209
+
210
+ # The following are model parameters.
211
+ # If the input are tensors, the gradient of rendering can be backprop to them.
212
+ # Otherwise, it creates new trainable tensors.
213
+ rgb=0.5, # Nx3 or scalar for voxel color in range of 0~1.
214
+ shs=0.0, # NxDx3 or scalar for voxel higher-deg sh coefficient.
215
+ density=-10., # Nx8 or scalar for voxel density field.
216
+ # The order is [0,0,0] => [0,0,1] => [0,1,0] => [0,1,1] ...
217
+ reduce_density=False, # Whether to merge grid points if density is Nx8.
218
+ ):
219
+
220
+ scene_center, scene_extent, _ = get_scene_bound_tensor(center=scene_center, extent=scene_extent)
221
+
222
+ # Compute voxel level
223
+ if octlevel is not None:
224
+ assert expected_vox_size is None
225
+ octlevel = get_octlevel_tensor(octlevel, num_voxels=len(xyz))
226
+ elif expected_vox_size is not None:
227
+ octlevel_fp32 = octree_utils.vox_size_2_level(scene_extent, expected_vox_size)
228
+ if level_round_mode == "nearest":
229
+ octlevel_fp32 = octlevel_fp32.round()
230
+ elif level_round_mode == "down":
231
+ octlevel_fp32 = octlevel_fp32.floor()
232
+ elif level_round_mode == "up":
233
+ octlevel_fp32 = octlevel_fp32.ceil()
234
+ else:
235
+ raise Exception("Unknonw level_round_mode")
236
+ octlevel_fp32 = octlevel_fp32.clamp(1, svraster_cuda.meta.MAX_NUM_LEVELS)
237
+ octlevel = get_octlevel_tensor(octlevel_fp32.to(torch.int8), num_voxels=len(xyz))
238
+ else:
239
+ raise Exception("Either octlevel or expected_vox_size should be given.")
240
+
241
+ # Transform point to ijk integer coordinate
242
+ scene_min_xyz = scene_center - 0.5 * scene_extent
243
+ vox_size = octree_utils.level_2_vox_size(scene_extent, octlevel)
244
+ ijk = ((xyz - scene_min_xyz) / vox_size).long()
245
+
246
+ # Reduce duplicated tensor
247
+ ijkl = torch.cat([ijk, octlevel], dim=1)
248
+ ijkl_unq, invmap = ijkl.unique(dim=0, return_inverse=True)
249
+ ijk, octlevel = ijkl_unq.split([3, 1], dim=1)
250
+ octlevel = octlevel.to(torch.int8)
251
+
252
+ if torch.is_tensor(rgb):
253
+ assert rgb.shape == (len(invmap), 3)
254
+ new_shape = (len(ijk), 3)
255
+ rgb = torch.zeros(new_shape, dtype=torch.float32, device="cuda").index_reduce_(
256
+ dim=0,
257
+ index=invmap,
258
+ source=rgb,
259
+ reduce="mean",
260
+ include_self=False)
261
+
262
+ if torch.is_tensor(shs):
263
+ assert shs.shape == (len(invmap), (self.max_sh_degree+1)**2 - 1, 3)
264
+ new_shape = (len(ijk), (self.max_sh_degree+1)**2 - 1, 3)
265
+ shs = torch.zeros(new_shape, dtype=torch.float32, device="cuda").index_reduce_(
266
+ dim=0,
267
+ index=invmap,
268
+ source=shs,
269
+ reduce="mean",
270
+ include_self=False)
271
+
272
+ if torch.is_tensor(density):
273
+ assert density.shape == (len(invmap), 8)
274
+ new_shape = (len(ijk), 8)
275
+ density = torch.zeros(new_shape, dtype=torch.float32, device="cuda").index_reduce_(
276
+ dim=0,
277
+ index=invmap,
278
+ source=density,
279
+ reduce="mean",
280
+ include_self=False)
281
+
282
+ # Allocate voxel using ijkl coordinate
283
+ self.ijkl_init(
284
+ scene_center=scene_center,
285
+ scene_extent=scene_extent,
286
+ ijk=ijk,
287
+ octlevel=octlevel,
288
+ rgb=rgb,
289
+ shs=shs,
290
+ density=density,
291
+ reduce_density=reduce_density)
292
+
293
+
294
+ #################################################
295
+ # Helper function
296
+ #################################################
297
+ def get_scene_bound_tensor(center, extent, outside_level=0):
298
+ if torch.is_tensor(center):
299
+ scene_center = center.float().clone().cuda()
300
+ else:
301
+ scene_center = torch.tensor(center, dtype=torch.float32, device="cuda")
302
+
303
+ if torch.is_tensor(extent):
304
+ inside_extent = extent.float().clone().cuda()
305
+ else:
306
+ inside_extent = torch.tensor(extent, dtype=torch.float32, device="cuda")
307
+
308
+ scene_extent = inside_extent * (2 ** outside_level)
309
+
310
+ assert scene_center.shape == (3,)
311
+ assert scene_extent.numel() == 1
312
+
313
+ return scene_center, scene_extent, inside_extent
314
+
315
+ def get_octlevel_tensor(octlevel, num_voxels=None):
316
+ if not torch.is_tensor(octlevel):
317
+ assert np.all(octlevel > 0)
318
+ assert np.all(octlevel <= svraster_cuda.meta.MAX_NUM_LEVELS)
319
+ octlevel = torch.tensor(octlevel, dtype=torch.int8, device="cuda")
320
+ if octlevel.numel() == 1:
321
+ octlevel = octlevel.view(1, 1).repeat(num_voxels, 1).contiguous()
322
+ octlevel = octlevel.reshape(-1, 1)
323
+ assert octlevel.dtype == torch.int8
324
+ assert num_voxels is None or octlevel.numel() == num_voxels
325
+
326
+ return octlevel
327
+
328
+
329
+ #################################################
330
+ # Octree layout construction heuristic
331
+ #################################################
332
+ def octlayout_filtering(octpath, octlevel, scene_center, scene_extent, cameras=None, filter_zero_visiblity=True, filter_near=-1):
333
+
334
+ vox_center, vox_size = octree_utils.octpath_decoding(
335
+ octpath, octlevel,
336
+ scene_center, scene_extent)
337
+
338
+ # Filtering
339
+ kept_mask = torch.ones([len(octpath)], dtype=torch.bool, device="cuda")
340
+ if filter_zero_visiblity:
341
+ assert cameras is not None, "Cameras should be given to filter invisible voxels"
342
+ rate = svraster_cuda.renderer.mark_max_samp_rate(
343
+ cameras, octpath, vox_center, vox_size)
344
+ kept_mask &= (rate > 0)
345
+ if filter_near > 0:
346
+ is_near = svraster_cuda.renderer.mark_near(
347
+ cameras, octpath, vox_center, vox_size, near=filter_near)
348
+ kept_mask &= (~is_near)
349
+ kept_idx = torch.where(kept_mask)[0]
350
+ octpath = octpath[kept_idx]
351
+ octlevel = octlevel[kept_idx]
352
+ return octpath, octlevel
353
+
354
+
355
+ def octlayout_inside_uniform(scene_center, scene_extent, outside_level, n_level, cameras=None, filter_zero_visiblity=True, filter_near=-1):
356
+ octpath, octlevel = octree_utils.gen_octpath_dense(
357
+ outside_level=outside_level,
358
+ n_level_inside=n_level)
359
+
360
+ octpath, octlevel = octlayout_filtering(
361
+ octpath=octpath,
362
+ octlevel=octlevel,
363
+ scene_center=scene_center,
364
+ scene_extent=scene_extent,
365
+ cameras=cameras,
366
+ filter_zero_visiblity=filter_zero_visiblity,
367
+ filter_near=filter_near)
368
+ return octpath, octlevel
369
+
370
+
371
+ def octlayout_outside_heuristic(scene_center, scene_extent, outside_level, cameras, min_num, max_level, filter_near=-1):
372
+
373
+ assert cameras is not None, "Cameras should provided in this mode."
374
+
375
+ # Init by adding one sub-level in each shell level
376
+ octpath = []
377
+ octlevel = []
378
+ for lv in range(1, 1+outside_level):
379
+ path, lv = octree_utils.gen_octpath_shell(
380
+ shell_level=lv,
381
+ n_level_inside=1)
382
+ octpath.append(path)
383
+ octlevel.append(lv)
384
+ octpath = torch.cat(octpath)
385
+ octlevel = torch.cat(octlevel)
386
+
387
+ # Iteratively subdivide voxels with maximum sampling rate
388
+ while True:
389
+ vox_center, vox_size = octree_utils.octpath_decoding(
390
+ octpath, octlevel, scene_center, scene_extent)
391
+ samp_rate = svraster_cuda.renderer.mark_max_samp_rate(
392
+ cameras, octpath, vox_center, vox_size)
393
+
394
+ kept_idx = torch.where((samp_rate > 0))[0]
395
+ octpath = octpath[kept_idx]
396
+ octlevel = octlevel[kept_idx]
397
+ octlevel_mask = (octlevel.squeeze(1) < max_level)
398
+ samp_rate = samp_rate[kept_idx] * octlevel_mask
399
+ vox_size = vox_size[kept_idx]
400
+ still_need_n = (min_num - len(octpath)) // 7
401
+ still_need_n = min(len(octpath), round(still_need_n))
402
+ if still_need_n <= 0:
403
+ break
404
+ rank = samp_rate * (octlevel.squeeze(1) < svraster_cuda.meta.MAX_NUM_LEVELS)
405
+ subdiv_mask = (rank >= rank.sort().values[-still_need_n])
406
+ subdiv_mask &= (octlevel.squeeze(1) < svraster_cuda.meta.MAX_NUM_LEVELS)
407
+ subdiv_mask &= octlevel_mask
408
+ samp_rate *= subdiv_mask
409
+ subdiv_mask &= (samp_rate >= samp_rate.quantile(0.9)) # Subdivide only 10% each iteration
410
+ if subdiv_mask.sum() == 0:
411
+ break
412
+ octpath_children, octlevel_children = octree_utils.gen_children(
413
+ octpath[subdiv_mask], octlevel[subdiv_mask])
414
+ octpath = torch.cat([octpath[~subdiv_mask], octpath_children])
415
+ octlevel = torch.cat([octlevel[~subdiv_mask], octlevel_children])
416
+
417
+ octpath, octlevel = octlayout_filtering(
418
+ octpath=octpath,
419
+ octlevel=octlevel,
420
+ scene_center=scene_center,
421
+ scene_extent=scene_extent,
422
+ cameras=cameras,
423
+ filter_zero_visiblity=True,
424
+ filter_near=filter_near)
425
+ return octpath, octlevel
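Note: a sketch of seeding a model from a point cloud with points_init, assuming `model` is an instance of the sparse-voxel model class elsewhere in this repo that mixes in SVConstructor (its constructor arguments are omitted here), and that a CUDA device plus the svraster_cuda extension are available; the point cloud, scene bound, and voxel size are illustrative.

    import torch

    xyz = torch.rand(10_000, 3, device="cuda") * 2.0 - 1.0   # toy points in [-1, 1]^3
    model.points_init(
        scene_center=[0.0, 0.0, 0.0],
        scene_extent=3.0,
        xyz=xyz,
        expected_vox_size=0.05,     # target voxel edge length; rounded to an Octree level
        rgb=0.5,                    # scalar -> new trainable per-voxel colors
        density=-10.0)              # scalar -> new trainable grid-point densities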
src/sparse_voxel_gears/io.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import re
11
+ import torch
12
+
13
+ from src.utils import octree_utils
14
+
15
+ class SVInOut:
16
+
17
+ def save(self, path, quantize=False):
18
+ '''
19
+ Save the necessary attributes and parameters for reproducing rendering.
20
+ '''
21
+ os.makedirs(os.path.dirname(path), exist_ok=True)
22
+ state_dict = {
23
+ 'active_sh_degree': self.active_sh_degree,
24
+ 'scene_center': self.scene_center.data.contiguous(),
25
+ 'inside_extent': self.inside_extent.data.contiguous(),
26
+ 'scene_extent': self.scene_extent.data.contiguous(),
27
+ 'octpath': self.octpath.data.contiguous(),
28
+ 'octlevel': self.octlevel.data.contiguous(),
29
+ '_geo_grid_pts': self._geo_grid_pts.data.contiguous(),
30
+ '_sh0': self._sh0.data.contiguous(),
31
+ '_shs': self._shs.data.contiguous(),
32
+ }
33
+
34
+ if quantize:
35
+ quantize_state_dict(state_dict)
36
+ state_dict['quantized'] = True
37
+ else:
38
+ state_dict['quantized'] = False
39
+
40
+ for k, v in state_dict.items():
41
+ if torch.is_tensor(v):
42
+ state_dict[k] = v.cpu()
43
+ torch.save(state_dict, path)
44
+ self.latest_save_path = path
45
+
46
+ def load(self, path):
47
+ '''
48
+ Load the saved models.
49
+ '''
50
+ self.loaded_path = path
51
+ state_dict = torch.load(path, map_location="cpu", weights_only=False)
52
+
53
+ if state_dict.get('quantized', False):
54
+ dequantize_state_dict(state_dict)
55
+
56
+ self.active_sh_degree = state_dict['active_sh_degree']
57
+
58
+ self.scene_center = state_dict['scene_center'].cuda()
59
+ self.inside_extent = state_dict['inside_extent'].cuda()
60
+ self.scene_extent = state_dict['scene_extent'].cuda()
61
+
62
+ self.octpath = state_dict['octpath'].cuda()
63
+ self.octlevel = state_dict['octlevel'].cuda().to(torch.int8)
64
+
65
+ self._geo_grid_pts = state_dict['_geo_grid_pts'].cuda().requires_grad_()
66
+ self._sh0 = state_dict['_sh0'].cuda().requires_grad_()
67
+ self._shs = state_dict['_shs'].cuda().requires_grad_()
68
+
69
+ # Subdivision priority tracker
70
+ self._subdiv_p = torch.ones(
71
+ [self.num_voxels, 1],
72
+ dtype=torch.float32, device="cuda").requires_grad_()
73
+
74
+ def save_iteration(self, model_path, iteration, quantize=False):
75
+ path = os.path.join(model_path, "checkpoints", f"iter{iteration:06d}_model.pt")
76
+ self.save(path, quantize=quantize)
77
+ self.latest_save_iter = iteration
78
+
79
+ def load_iteration(self, model_path, iteration=-1):
80
+ if iteration == -1:
81
+ # Find the maximum iteration if it is -1.
82
+ fnames = os.listdir(os.path.join(model_path, "checkpoints"))
83
+ loaded_iter = max(int(re.sub("[^0-9]", "", fname)) for fname in fnames)
84
+ else:
85
+ loaded_iter = iteration
86
+
87
+ path = os.path.join(model_path, "checkpoints", f"iter{loaded_iter:06d}_model.pt")
88
+ self.load(path)
89
+
90
+ self.loaded_iter = loaded_iter
91
+
92
+ return loaded_iter
93
+
94
+
95
+ # Quantization utilities to reduce the model size when saving.
96
+ # They can reduce the model size by ~70% with a minor PSNR drop.
97
+ def quantize_state_dict(state_dict):
98
+ state_dict['_geo_grid_pts'] = quantization(state_dict['_geo_grid_pts'])
99
+ state_dict['_sh0'] = [quantization(v) for v in state_dict['_sh0'].split(1, dim=1)]
100
+ state_dict['_shs'] = [quantization(v) for v in state_dict['_shs'].split(1, dim=1)]
101
+
102
+ def dequantize_state_dict(state_dict):
103
+ state_dict['_geo_grid_pts'] = dequantization(state_dict['_geo_grid_pts'])
104
+ state_dict['_sh0'] = torch.cat(
105
+ [dequantization(v) for v in state_dict['_sh0']], dim=1)
106
+ state_dict['_shs'] = torch.cat(
107
+ [dequantization(v) for v in state_dict['_shs']], dim=1)
108
+
109
+ def quantization(src_tensor, max_iter=10):
110
+ src_shape = src_tensor.shape
111
+ src_vals = src_tensor.flatten().contiguous()
112
+ order = src_vals.argsort()
113
+ quantile_ind = (torch.linspace(0,1,257) * (len(order) - 1)).long().clamp_(0, len(order)-1)
114
+ codebook = src_vals[order[quantile_ind]].contiguous()
115
+ codebook[0] = -torch.inf
116
+ ind = torch.searchsorted(codebook, src_vals)
117
+
118
+ codebook = codebook[1:]
119
+ ind = (ind - 1).clamp_(0, 255)
120
+
121
+ diff_l = (src_vals - codebook[ind-1]).abs()
122
+ diff_m = (src_vals - codebook[ind]).abs()
123
+ ind = ind - 1 + (diff_m < diff_l)
124
+ ind.clamp_(0, 255)
125
+
126
+ for _ in range(max_iter):
127
+ codebook = torch.zeros_like(codebook).index_reduce_(
128
+ dim=0,
129
+ index=ind,
130
+ source=src_vals,
131
+ reduce='mean',
132
+ include_self=False)
133
+ diff_l = (src_vals - codebook[ind-1]).abs()
134
+ diff_r = (src_vals - codebook[(ind+1).clamp_max_(255)]).abs()
135
+ diff_m = (src_vals - codebook[ind]).abs()
136
+ upd_mask = torch.minimum(diff_l, diff_r) < diff_m
137
+ if upd_mask.sum() == 0:
138
+ break
139
+ shift = (diff_r < diff_l) * 2 - 1
140
+ ind[upd_mask] += shift[upd_mask]
141
+ ind.clamp_(0, 255)
142
+
143
+ codebook = torch.zeros_like(codebook).index_reduce_(
144
+ dim=0,
145
+ index=ind,
146
+ source=src_vals,
147
+ reduce='mean',
148
+ include_self=False)
149
+
150
+ return dict(
151
+ index=ind.reshape(src_shape).to(torch.uint8),
152
+ codebook=codebook,
153
+ )
154
+
155
+ def dequantization(quant_dict):
156
+ return quant_dict['codebook'][quant_dict['index'].long()]
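For reference, a minimal sketch of exercising the 8-bit codebook quantization above on a standalone tensor, assuming the helpers in this file are importable (the tensor shape and variable names are illustrative):

import torch
from src.sparse_voxel_gears.io import quantization, dequantization

x = torch.randn(10000, 4)         # stand-in for a per-voxel parameter such as _sh0
q = quantization(x)               # dict with a uint8 'index' tensor and a 256-entry 'codebook'
x_hat = dequantization(q)         # reconstruction restricted to at most 256 distinct values

orig_bytes = x.numel() * 4                                    # float32 storage
quant_bytes = q['index'].numel() + q['codebook'].numel() * 4  # 1 byte per element + codebook
print(orig_bytes, quant_bytes, (x - x_hat).abs().mean().item())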
src/sparse_voxel_gears/pooling.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import svraster_cuda
11
+
12
+ from src.utils import octree_utils
13
+
14
+
15
+ class SVPooling:
16
+
17
+ def pooling_to_level(self, max_level, octpath=None, octlevel=None):
18
+ octpath = self.octpath if octpath is None else octpath
19
+ octlevel = self.octlevel if octlevel is None else octlevel
20
+
21
+ num_bit_to_mask = 3 * max(0, svraster_cuda.meta.MAX_NUM_LEVELS - max_level)
22
+ octpath = (octpath >> num_bit_to_mask) << num_bit_to_mask
23
+ octlevel = octlevel.clamp_max(max_level)
24
+ octpack, invmap = torch.stack([octpath, octlevel]).unique(sorted=True, dim=1, return_inverse=True)
25
+ octpath, octlevel = octpack
26
+ octlevel = octlevel.to(torch.int8)
27
+
28
+ vox_center, vox_size = octree_utils.octpath_decoding(
29
+ octpath, octlevel, self.scene_center, self.scene_extent)
30
+
31
+ return dict(
32
+ invmap=invmap,
33
+ octpath=octpath,
34
+ octlevel=octlevel,
35
+ vox_center=vox_center,
36
+ vox_size=vox_size,
37
+ )
38
+
39
+ def pooling_to_rate(self, cameras, max_rate, octpath=None, octlevel=None):
40
+ octpath = self.octpath.clone() if octpath is None else octpath
41
+ octlevel = self.octlevel.clone() if octlevel is None else octlevel
42
+ invmap = torch.arange(len(octpath), device="cuda")
43
+
44
+ for _ in range(svraster_cuda.meta.MAX_NUM_LEVELS):
45
+ vox_center, vox_size = octree_utils.octpath_decoding(octpath, octlevel, self.scene_center, self.scene_extent)
46
+ samp_rate = svraster_cuda.renderer.mark_max_samp_rate(cameras, octpath, vox_center, vox_size)
47
+ pool_mask = (samp_rate < max_rate) & (octlevel.squeeze(1) > 1)
48
+ if pool_mask.sum() == 0:
49
+ break
50
+ octlevel[pool_mask] = octlevel[pool_mask] - 1
51
+ num_bit_to_mask = 3 * (svraster_cuda.meta.MAX_NUM_LEVELS - octlevel[pool_mask])
52
+ octpath[pool_mask] = octpath[pool_mask] >> num_bit_to_mask << num_bit_to_mask
53
+
54
+ octpack, cur_invmap = torch.stack([octpath, octlevel]).unique(sorted=True, dim=1, return_inverse=True)
55
+ octpath, octlevel = octpack
56
+ octlevel = octlevel.to(torch.int8)
57
+ invmap = cur_invmap[invmap]
58
+
59
+ vox_center, vox_size = octree_utils.octpath_decoding(
60
+ octpath, octlevel, self.scene_center, self.scene_extent)
61
+
62
+ return dict(
63
+ invmap=invmap,
64
+ octpath=octpath,
65
+ octlevel=octlevel,
66
+ vox_center=vox_center,
67
+ vox_size=vox_size,
68
+ )
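As a usage sketch of the pooling helpers above (assuming `model` is an already-populated `SparseVoxelModel` and `cameras` a list of training cameras; the level and rate values are illustrative): `pooling_to_level` caps the octree depth, `pooling_to_rate` merges voxels whose on-screen sampling rate falls below a target, and both return an `invmap` from each original voxel to its merged voxel:

import torch

pooled = model.pooling_to_level(max_level=10)
# pooled['invmap'][i] is the index of the coarse voxel that original voxel i maps to,
# so per-voxel quantities can be aggregated onto the pooled layout:
counts = torch.zeros(len(pooled['octpath']), dtype=torch.long, device="cuda")
counts.index_add_(0, pooled['invmap'], torch.ones_like(pooled['invmap']))

pooled_lowres = model.pooling_to_rate(cameras, max_rate=1.0)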
src/sparse_voxel_gears/properties.py ADDED
@@ -0,0 +1,146 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+
11
+ from src.utils import octree_utils
12
+ from src.utils.fuser_utils import rgb_fusion
13
+ from src.utils.activation_utils import rgb2shzero
14
+
15
+ import svraster_cuda
16
+
17
+
18
+ class SVProperties:
19
+
20
+ @property
21
+ def num_voxels(self):
22
+ return len(self.octpath)
23
+
24
+ @property
25
+ def num_grid_pts(self):
26
+ return len(self.grid_pts_key)
27
+
28
+ @property
29
+ def scene_min(self):
30
+ return self.scene_center - 0.5 * self.scene_extent
31
+
32
+ @property
33
+ def scene_max(self):
34
+ return self.scene_center + 0.5 * self.scene_extent
35
+
36
+ @property
37
+ def inside_min(self):
38
+ return self.scene_center - 0.5 * self.inside_extent
39
+
40
+ @property
41
+ def inside_max(self):
42
+ return self.scene_center + 0.5 * self.inside_extent
43
+
44
+ @property
45
+ def outside_level(self):
46
+ return (self.scene_extent / self.inside_extent).log2().round().long().item()
47
+
48
+ @property
49
+ def bounding(self):
50
+ return torch.stack([self.scene_min, self.scene_max])
51
+
52
+ @property
53
+ def inside_mask(self):
54
+ isin = ((self.inside_min < self.vox_center) & (self.vox_center < self.inside_max)).all(1)
55
+ return isin
56
+
57
+ @property
58
+ def sh0(self):
59
+ return self._sh0
60
+
61
+ @property
62
+ def shs(self):
63
+ return self._shs
64
+
65
+ @property
66
+ def subdivision_priority(self):
67
+ return self._subdiv_p.grad
68
+
69
+ def reset_subdivision_priority(self):
70
+ self._subdiv_p.grad = None
71
+
72
+ @property
73
+ def signature(self):
74
+ # Signature to check if the voxel grid layout is updated
75
+ return (self.num_voxels, id(self.octpath), id(self.octlevel))
76
+
77
+ def _check_derived_voxel_attr(self):
78
+ # Lazy computation of derived voxel attributes (centers, sizes, grid-point links)
79
+ signature = self.signature
80
+ need_recompute = not hasattr(self, '_check_derived_voxel_attr_signature') or \
81
+ self._check_derived_voxel_attr_signature != signature
82
+ if need_recompute:
83
+ self._vox_center, self._vox_size = octree_utils.octpath_decoding(
84
+ self.octpath, self.octlevel, self.scene_center, self.scene_extent)
85
+ self._grid_pts_key, self._vox_key = octree_utils.build_grid_pts_link(self.octpath, self.octlevel)
86
+ self._check_derived_voxel_attr_signature = signature
87
+
88
+ @property
89
+ def vox_center(self):
90
+ self._check_derived_voxel_attr()
91
+ return self._vox_center
92
+
93
+ @property
94
+ def vox_size(self):
95
+ self._check_derived_voxel_attr()
96
+ return self._vox_size
97
+
98
+ @property
99
+ def grid_pts_key(self):
100
+ self._check_derived_voxel_attr()
101
+ return self._grid_pts_key
102
+
103
+ @property
104
+ def vox_key(self):
105
+ self._check_derived_voxel_attr()
106
+ return self._vox_key
107
+
108
+ @property
109
+ def vox_size_inv(self):
110
+ # Lazy computation of inverse voxel sizes
111
+ signature = self.signature
112
+ need_recompute = not hasattr(self, '_vox_size_inv_signature') or \
113
+ self._vox_size_inv_signature != signature
114
+ if need_recompute:
115
+ self._vox_size_inv = 1 / self.vox_size
116
+ self._vox_size_inv_signature = signature
117
+ return self._vox_size_inv
118
+
119
+ @property
120
+ def grid_pts_xyz(self):
121
+ # Lazy computation of grid points xyz
122
+ signature = self.signature
123
+ need_recompute = not hasattr(self, '_grid_pts_xyz_signature') or \
124
+ self._grid_pts_xyz_signature != signature
125
+ if need_recompute:
126
+ self._grid_pts_xyz = octree_utils.compute_gridpoints_xyz(
127
+ self.grid_pts_key, self.scene_center, self.scene_extent)
128
+ self._grid_pts_xyz_signature = signature
129
+ return self._grid_pts_xyz
130
+
131
+ @torch.no_grad()
132
+ def reset_sh_from_cameras(self, cameras):
133
+ self._sh0.data.copy_(rgb2shzero(rgb_fusion(self, cameras)))
134
+ self._shs.data.zero_()
135
+
136
+ def apply_tv_on_density_field(self, lambda_tv_density):
137
+ if self._geo_grid_pts.grad is None:
138
+ self._geo_grid_pts.grad = torch.zeros_like(self._geo_grid_pts.data)
139
+ svraster_cuda.grid_loss_bw.total_variation(
140
+ grid_pts=self._geo_grid_pts,
141
+ vox_key=self.vox_key,
142
+ weight=lambda_tv_density,
143
+ vox_size_inv=self.vox_size_inv,
144
+ no_tv_s=True,
145
+ tv_sparse=False,
146
+ grid_pts_grad=self._geo_grid_pts.grad)
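The derived-attribute properties above (`vox_center`, `vox_size`, `grid_pts_key`, `vox_key`, `vox_size_inv`, `grid_pts_xyz`) all use the same signature-based lazy caching: the cached value is recomputed only when `(num_voxels, id(octpath), id(octlevel))` changes. A small self-contained illustration of that pattern (not part of the codebase):

class LazyDerivedExample:
    def __init__(self, data):
        self.data = data

    @property
    def signature(self):
        # Cheap fingerprint of everything the derived value depends on.
        return (len(self.data), id(self.data))

    @property
    def doubled(self):
        sig = self.signature
        if getattr(self, '_doubled_sig', None) != sig:
            self._doubled = [2 * v for v in self.data]   # stand-in for an expensive derivation
            self._doubled_sig = sig
        return self._doubled

obj = LazyDerivedExample([1, 2, 3])
print(obj.doubled)   # computed once
print(obj.doubled)   # served from the cache until obj.data is rebound to a new object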
src/sparse_voxel_gears/renderer.py ADDED
@@ -0,0 +1,178 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import svraster_cuda
11
+
12
+ from src.utils.image_utils import resize_rendering
13
+
14
+ class SVRenderer:
15
+
16
+ def freeze_vox_geo(self):
17
+ '''
18
+ Freeze grid points parameter and pre-gather them to each voxel.
19
+ '''
20
+ with torch.no_grad():
21
+ self.frozen_vox_geo = svraster_cuda.renderer.GatherGeoParams.apply(
22
+ self.vox_key,
23
+ torch.arange(self.num_voxels, device="cuda"),
24
+ self._geo_grid_pts
25
+ )
26
+ self._geo_grid_pts.requires_grad = False
27
+
28
+ def unfreeze_vox_geo(self):
29
+ '''
30
+ Unfreeze grid points parameter.
31
+ '''
32
+ del self.frozen_vox_geo
33
+ self._geo_grid_pts.requires_grad = True
34
+
35
+ def vox_fn(self, idx, cam_pos, color_mode=None, viewdir=None):
36
+ '''
37
+ Per-frame voxel property processing. Two important operations:
38
+ 1. Gather grid points parameter into each voxel.
39
+ 2. Compute view-dependent color of each voxel.
40
+
41
+ Input:
42
+ @idx Indices for active voxel for current frame.
43
+ @cam_pos Camera position.
44
+ Output:
45
+ @vox_params A dictionary of the pre-processed voxel properties.
46
+ '''
47
+
48
+ # Gather the density values at the eight corners of each voxel.
49
+ # It defines a trilinear density field.
50
+ # The final tensor has shape [#vox, 8]
51
+ if hasattr(self, 'frozen_vox_geo'):
52
+ geos = self.frozen_vox_geo
53
+ else:
54
+ geos = svraster_cuda.renderer.GatherGeoParams.apply(
55
+ self.vox_key,
56
+ idx,
57
+ self._geo_grid_pts
58
+ )
59
+
60
+ # Compute voxel colors
61
+ if color_mode is None or color_mode == "sh":
62
+ active_sh_degree = self.active_sh_degree
63
+ color_mode = "sh"
64
+ elif color_mode.startswith("sh"):
65
+ active_sh_degree = int(color_mode[2])
66
+ color_mode = "sh"
67
+
68
+ if color_mode == "sh":
69
+ rgbs = svraster_cuda.renderer.SH_eval.apply(
70
+ active_sh_degree,
71
+ idx,
72
+ self.vox_center,
73
+ cam_pos,
74
+ viewdir, # Ignore above two when viewdir is not None
75
+ self.sh0,
76
+ self.shs,
77
+ )
78
+ elif color_mode == "rand":
79
+ rgbs = torch.rand([self.num_voxels, 3], dtype=torch.float32, device="cuda")
80
+ elif color_mode == "dontcare":
81
+ rgbs = torch.empty([self.num_voxels, 3], dtype=torch.float32, device="cuda")
82
+ else:
83
+ raise NotImplementedError
84
+
85
+ # Pack everything
86
+ vox_params = {
87
+ 'geos': geos,
88
+ 'rgbs': rgbs,
89
+ 'subdiv_p': self._subdiv_p, # Dummy param to record subdivision priority
90
+ }
91
+ if vox_params['subdiv_p'] is None:
92
+ vox_params['subdiv_p'] = torch.ones([self.num_voxels, 1], device="cuda")
93
+
94
+ return vox_params
95
+
96
+ def render(
97
+ self,
98
+ camera,
99
+ color_mode=None,
100
+ track_max_w=False,
101
+ ss=None,
102
+ output_depth=False,
103
+ output_normal=False,
104
+ output_T=False,
105
+ rand_bg=False,
106
+ use_auto_exposure=False,
107
+ **other_opt):
108
+
109
+ ###################################
110
+ # Pre-processing
111
+ ###################################
112
+ if ss is None:
113
+ ss = self.ss
114
+ w_src, h_src = camera.image_width, camera.image_height
115
+ w, h = round(w_src * ss), round(h_src * ss)
116
+ w_ss, h_ss = w / w_src, h / h_src
117
+ if ss != 1.0 and 'gt_color' in other_opt:
118
+ other_opt['gt_color'] = resize_rendering(other_opt['gt_color'], size=(h, w))
119
+
120
+ n_samp_per_vox = other_opt.pop('n_samp_per_vox', self.n_samp_per_vox)
121
+
122
+ ###################################
123
+ # Call low-level rasterization API
124
+ ###################################
125
+ raster_settings = svraster_cuda.renderer.RasterSettings(
126
+ color_mode=color_mode,
127
+ n_samp_per_vox=n_samp_per_vox,
128
+ image_width=w,
129
+ image_height=h,
130
+ tanfovx=camera.tanfovx,
131
+ tanfovy=camera.tanfovy,
132
+ cx=camera.cx * w_ss,
133
+ cy=camera.cy * h_ss,
134
+ w2c_matrix=camera.w2c,
135
+ c2w_matrix=camera.c2w,
136
+ bg_color=float(self.white_background),
137
+ near=camera.near,
138
+ need_depth=output_depth,
139
+ need_normal=output_normal,
140
+ track_max_w=track_max_w,
141
+ **other_opt)
142
+ color, depth, normal, T, max_w = svraster_cuda.renderer.rasterize_voxels(
143
+ raster_settings,
144
+ self.octpath,
145
+ self.vox_center,
146
+ self.vox_size,
147
+ self.vox_fn)
148
+
149
+ ###################################
150
+ # Post-processing and pack output
151
+ ###################################
152
+ if rand_bg:
153
+ color = color + T * torch.rand_like(color, requires_grad=False)
154
+ elif not self.white_background and not self.black_background:
155
+ color = color + T * color.mean((1,2), keepdim=True)
156
+
157
+ if use_auto_exposure:
158
+ color = camera.auto_exposure_apply(color)
159
+
160
+ render_pkg = {
161
+ 'color': color,
162
+ 'depth': depth if output_depth else None,
163
+ 'normal': normal if output_normal else None,
164
+ 'T': T if output_T else None,
165
+ 'max_w': max_w,
166
+ }
167
+
168
+ for k in ['color', 'depth', 'normal', 'T']:
169
+ render_pkg[f'raw_{k}'] = render_pkg[k]
170
+
171
+ # Post process super-sampling
172
+ if render_pkg[k] is not None and render_pkg[k].shape[-2:] != (h_src, w_src):
173
+ render_pkg[k] = resize_rendering(render_pkg[k], size=(h_src, w_src))
174
+
175
+ # Clip intensity
176
+ render_pkg['color'] = render_pkg['color'].clamp(0, 1)
177
+
178
+ return render_pkg
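A hedged usage sketch of `render` above (it assumes `model` is an already-populated `SparseVoxelModel` and `cam` a camera object such as those produced by `src/cameras.py`; argument values are illustrative):

render_pkg = model.render(cam, output_depth=True, output_normal=True, ss=1.0)

color = render_pkg['color']      # [3, H, W] at the camera resolution, clamped to [0, 1]
depth = render_pkg['depth']      # present because output_depth=True
normal = render_pkg['normal']    # present because output_normal=True
raw = render_pkg['raw_color']    # un-clamped output before resizing back to the source resolution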
src/sparse_voxel_gears/renderer_copy.py ADDED
@@ -0,0 +1,178 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ import svraster_cuda
11
+
12
+ from src.utils.image_utils import resize_rendering
13
+
14
+ class SVRenderer:
15
+
16
+ def freeze_vox_geo(self):
17
+ '''
18
+ Freeze grid points parameter and pre-gather them to each voxel.
19
+ '''
20
+ with torch.no_grad():
21
+ self.frozen_vox_geo = svraster_cuda.renderer.GatherGeoParams.apply(
22
+ self.vox_key,
23
+ torch.arange(self.num_voxels, device="cuda"),
24
+ self._geo_grid_pts
25
+ )
26
+ self._geo_grid_pts.requires_grad = False
27
+
28
+ def unfreeze_vox_geo(self):
29
+ '''
30
+ Unfreeze grid points parameter.
31
+ '''
32
+ del self.frozen_vox_geo
33
+ self._geo_grid_pts.requires_grad = True
34
+
35
+ def vox_fn(self, idx, cam_pos, color_mode=None, viewdir=None):
36
+ '''
37
+ Per-frame voxel property processing. Two important operations:
38
+ 1. Gather grid points parameter into each voxel.
39
+ 2. Compute view-dependent color of each voxel.
40
+
41
+ Input:
42
+ @idx Indices for active voxel for current frame.
43
+ @cam_pos Camera position.
44
+ Output:
45
+ @vox_params A dictionary of the pre-processed voxel properties.
46
+ '''
47
+
48
+ # Gather the density values at the eight corners of each voxel.
49
+ # It defines a trilinear density field.
50
+ # The final tensor has shape [#vox, 8]
51
+ if hasattr(self, 'frozen_vox_geo'):
52
+ geos = self.frozen_vox_geo
53
+ else:
54
+ geos = svraster_cuda.renderer.GatherGeoParams.apply(
55
+ self.vox_key,
56
+ idx,
57
+ self._geo_grid_pts
58
+ )
59
+
60
+ # Compute voxel colors
61
+ if color_mode is None or color_mode == "sh":
62
+ active_sh_degree = self.active_sh_degree
63
+ color_mode = "sh"
64
+ elif color_mode.startswith("sh"):
65
+ active_sh_degree = int(color_mode[2])
66
+ color_mode = "sh"
67
+
68
+ if color_mode == "sh":
69
+ rgbs = svraster_cuda.renderer.SH_eval.apply(
70
+ active_sh_degree,
71
+ idx,
72
+ self.vox_center,
73
+ cam_pos,
74
+ viewdir, # Ignore above two when viewdir is not None
75
+ self.sh0,
76
+ self.shs,
77
+ )
78
+ elif color_mode == "rand":
79
+ rgbs = torch.rand([self.num_voxels, 3], dtype=torch.float32, device="cuda")
80
+ elif color_mode == "dontcare":
81
+ rgbs = torch.empty([self.num_voxels, 3], dtype=torch.float32, device="cuda")
82
+ else:
83
+ raise NotImplementedError
84
+
85
+ # Pack everything
86
+ vox_params = {
87
+ 'geos': geos,
88
+ 'rgbs': rgbs,
89
+ 'subdiv_p': self._subdiv_p, # Dummy param to record subdivision priority
90
+ }
91
+ if vox_params['subdiv_p'] is None:
92
+ vox_params['subdiv_p'] = torch.ones([self.num_voxels, 1], device="cuda")
93
+
94
+ return vox_params
95
+
96
+ def render(
97
+ self,
98
+ camera,
99
+ color_mode=None,
100
+ track_max_w=False,
101
+ ss=None,
102
+ output_depth=False,
103
+ output_normal=False,
104
+ output_T=False,
105
+ rand_bg=False,
106
+ use_auto_exposure=False,
107
+ **other_opt):
108
+
109
+ ###################################
110
+ # Pre-processing
111
+ ###################################
112
+ if ss is None:
113
+ ss = self.ss
114
+ w_src, h_src = camera.image_width, camera.image_height
115
+ w, h = round(w_src * ss), round(h_src * ss)
116
+ w_ss, h_ss = w / w_src, h / h_src
117
+ if ss != 1.0 and 'gt_color' in other_opt:
118
+ other_opt['gt_color'] = resize_rendering(other_opt['gt_color'], size=(h, w))
119
+
120
+ n_samp_per_vox = other_opt.pop('n_samp_per_vox', self.n_samp_per_vox)
121
+
122
+ ###################################
123
+ # Call low-level rasterization API
124
+ ###################################
125
+ raster_settings = svraster_cuda.renderer.RasterSettings(
126
+ color_mode=color_mode,
127
+ n_samp_per_vox=n_samp_per_vox,
128
+ image_width=w,
129
+ image_height=h,
130
+ tanfovx=camera.tanfovx,
131
+ tanfovy=camera.tanfovy,
132
+ cx=camera.cx * w_ss,
133
+ cy=camera.cy * h_ss,
134
+ w2c_matrix=camera.w2c,
135
+ c2w_matrix=camera.c2w,
136
+ bg_color=float(self.white_background),
137
+ near=camera.near,
138
+ need_depth=output_depth,
139
+ need_normal=output_normal,
140
+ track_max_w=track_max_w,
141
+ **other_opt)
142
+ color, depth, normal, T, max_w = svraster_cuda.renderer.rasterize_voxels(
143
+ raster_settings,
144
+ self.octpath,
145
+ self.vox_center,
146
+ self.vox_size,
147
+ self.vox_fn)
148
+
149
+ ###################################
150
+ # Post-processing and pack output
151
+ ###################################
152
+ if rand_bg:
153
+ color = color + T * torch.rand_like(color, requires_grad=False)
154
+ elif not self.white_background and not self.black_background:
155
+ color = color + T * color.mean((1,2), keepdim=True)
156
+
157
+ if use_auto_exposure:
158
+ color = camera.auto_exposure_apply(color)
159
+
160
+ render_pkg = {
161
+ 'color': color,
162
+ 'depth': depth if output_depth else None,
163
+ 'normal': normal if output_normal else None,
164
+ 'T': T if output_T else None,
165
+ 'max_w': max_w,
166
+ }
167
+
168
+ for k in ['color', 'depth', 'normal', 'T']:
169
+ render_pkg[f'raw_{k}'] = render_pkg[k]
170
+
171
+ # Post process super-sampling
172
+ if render_pkg[k] is not None and render_pkg[k].shape[-2:] != (h_src, w_src):
173
+ render_pkg[k] = resize_rendering(render_pkg[k], size=(h_src, w_src))
174
+
175
+ # Clip intensity
176
+ render_pkg['color'] = render_pkg['color'].clamp(0, 1)
177
+
178
+ return render_pkg
src/sparse_voxel_model.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from src.sparse_voxel_gears.constructor import SVConstructor
10
+ from src.sparse_voxel_gears.properties import SVProperties
11
+ from src.sparse_voxel_gears.renderer import SVRenderer
12
+ from src.sparse_voxel_gears.adaptive import SVAdaptive
13
+ from src.sparse_voxel_gears.io import SVInOut
14
+ from src.sparse_voxel_gears.pooling import SVPooling
15
+
16
+
17
+ class SparseVoxelModel(SVConstructor, SVProperties, SVRenderer, SVAdaptive, SVInOut, SVPooling):
18
+
19
+ def __init__(self,
20
+ n_samp_per_vox=1, # Number of sampled points per visited voxel
21
+ sh_degree=3, # Use 3 * (k+1)^2 params per voxel for view-dependent colors
22
+ ss=1.5, # Super-sampling rate for anti-aliasing
23
+ white_background=False, # Assume a white background
24
+ black_background=False, # Assume a black background
25
+ ):
26
+ '''
27
+ Setup of the model meta. At this point, no voxel is allocated.
28
+ Use the following methods to allocate voxels and parameters.
29
+
30
+ 1. `model_load` defined in `src/sparse_voxel_gears/io.py`.
31
+ Load the saved models from a given path.
32
+
33
+ 2. `model_init` defined in `src/sparse_voxel_gears/constructor.py`.
34
+ Heuristically initialize the sparse grid layout and parameters from the training data.
35
+ '''
36
+ super().__init__()
37
+
38
+ self.n_samp_per_vox = n_samp_per_vox
39
+ self.max_sh_degree = sh_degree
40
+ self.ss = ss
41
+ self.white_background = white_background
42
+ self.black_background = black_background
43
+
44
+ # List the variable names
45
+ self.per_voxel_attr_lst = [
46
+ 'octpath', 'octlevel',
47
+ '_subdiv_p',
48
+ ]
49
+ self.per_voxel_param_lst = [
50
+ '_sh0', '_shs',
51
+ ]
52
+ self.grid_pts_param_lst = [
53
+ '_geo_grid_pts',
54
+ ]
55
+
56
+ # To be init from model_init
57
+ self.scene_center = None
58
+ self.scene_extent = None
59
+ self.inside_extent = None
60
+ self.octpath = None
61
+ self.octlevel = None
62
+ self.active_sh_degree = sh_degree
63
+
64
+ self._geo_grid_pts = None
65
+ self._sh0 = None
66
+ self._shs = None
67
+ self._subdiv_p = None
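As a construction sketch (the keyword values shown are the defaults documented above; the checkpoint path is illustrative, and voxel allocation is assumed to happen through the loader in `io.py` or the constructor in `constructor.py`):

model = SparseVoxelModel(
    n_samp_per_vox=1,        # samples per visited voxel along each ray
    sh_degree=3,             # view-dependent color capacity
    ss=1.5,                  # super-sampling rate for anti-aliasing
    white_background=False,
    black_background=False,
)
# No voxels are allocated yet; populate the model from a saved checkpoint ...
model.load_iteration("output/scene", iteration=-1)   # -1 loads the latest checkpoint
# ... or initialize it heuristically from the training data (see constructor.py).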
src/sparse_voxel_model_copy.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from src.sparse_voxel_gears.constructor import SVConstructor
10
+ from src.sparse_voxel_gears.properties import SVProperties
11
+ from src.sparse_voxel_gears.renderer import SVRenderer
12
+ from src.sparse_voxel_gears.adaptive import SVAdaptive
13
+ from src.sparse_voxel_gears.io import SVInOut
14
+ from src.sparse_voxel_gears.pooling import SVPooling
15
+
16
+
17
+ class SparseVoxelModel(SVConstructor, SVProperties, SVRenderer, SVAdaptive, SVInOut, SVPooling):
18
+
19
+ def __init__(self,
20
+ n_samp_per_vox=1, # Number of sampled points per visited voxel
21
+ sh_degree=3, # Use 3 * (k+1)^2 params per voxel for view-dependent colors
22
+ ss=1.5, # Super-sampling rate for anti-aliasing
23
+ white_background=False, # Assume a white background
24
+ black_background=False, # Assume a black background
25
+ ):
26
+ '''
27
+ Setup of the model meta. At this point, no voxel is allocated.
28
+ Use the following methods to allocate voxels and parameters.
29
+
30
+ 1. `model_load` defined in `src/sparse_voxel_gears/io.py`.
31
+ Load the saved models from a given path.
32
+
33
+ 2. `model_init` defined in `src/sparse_voxel_gears/constructor.py`.
34
+ Heuristically initialize the sparse grid layout and parameters from the training data.
35
+ '''
36
+ super().__init__()
37
+
38
+ self.n_samp_per_vox = n_samp_per_vox
39
+ self.max_sh_degree = sh_degree
40
+ self.ss = ss
41
+ self.white_background = white_background
42
+ self.black_background = black_background
43
+
44
+ # List the variable names
45
+ self.per_voxel_attr_lst = [
46
+ 'octpath', 'octlevel',
47
+ '_subdiv_p',
48
+ ]
49
+ self.per_voxel_param_lst = [
50
+ '_sh0', '_shs',
51
+ ]
52
+ self.grid_pts_param_lst = [
53
+ '_geo_grid_pts',
54
+ ]
55
+
56
+ # To be init from model_init
57
+ self.scene_center = None
58
+ self.scene_extent = None
59
+ self.inside_extent = None
60
+ self.octpath = None
61
+ self.octlevel = None
62
+ self.active_sh_degree = sh_degree
63
+
64
+ self._geo_grid_pts = None
65
+ self._sh0 = None
66
+ self._shs = None
67
+ self._subdiv_p = None
src/utils/__pycache__/activation_utils.cpython-39.pyc ADDED
Binary file (2.16 kB). View file
 
src/utils/__pycache__/bounding_utils.cpython-39.pyc ADDED
Binary file (3.05 kB). View file
 
src/utils/__pycache__/camera_utils.cpython-39.pyc ADDED
Binary file (2.38 kB). View file
 
src/utils/__pycache__/colmap_utils.cpython-39.pyc ADDED
Binary file (1.77 kB). View file
 
src/utils/__pycache__/fuser_utils.cpython-39.pyc ADDED
Binary file (3.87 kB). View file
 
src/utils/__pycache__/image_utils.cpython-39.pyc ADDED
Binary file (2.51 kB). View file
 
src/utils/__pycache__/loss_utils.cpython-39.pyc ADDED
Binary file (8.78 kB). View file
 
src/utils/__pycache__/marching_cubes_utils.cpython-39.pyc ADDED
Binary file (25.1 kB). View file
 
src/utils/__pycache__/mono_utils.cpython-39.pyc ADDED
Binary file (4.88 kB). View file
 
src/utils/__pycache__/octree_utils.cpython-39.pyc ADDED
Binary file (7.49 kB). View file
 
src/utils/__pycache__/system_utils.cpython-39.pyc ADDED
Binary file (372 Bytes). View file
 
src/utils/activation_utils.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import torch
10
+ from svraster_cuda.meta import STEP_SZ_SCALE
11
+
12
+ def softplus(x):
13
+ return torch.nn.functional.softplus(x)
14
+
15
+ def exp_linear_10(x):
16
+ return torch.where(x > 1, x, torch.exp(x - 1))
17
+
18
+ def exp_linear_11(x):
19
+ return torch.where(x > 1.1, x, torch.exp(0.909090909091 * x - 0.904689820196))
20
+
21
+ def exp_linear_20(x):
22
+ return torch.where(x > 2.0, x, torch.exp(0.5 * x - 0.30685281944))
23
+
24
+ def softplus_inverse(y):
25
+ return y + torch.log(-torch.expm1(-y))
26
+
27
+ def exp_linear_10_inverse(y):
28
+ return torch.where(y > 1, y, torch.log(y) + 1)
29
+
30
+ def exp_linear_11_inverse(y):
31
+ return torch.where(y > 1.1, y, (torch.log(y) + 0.904689820196) / 0.909090909091)
32
+
33
+ def exp_linear_20_inverse(y):
34
+ return torch.where(y > 2.0, y, (torch.log(y) + 0.30685281944) / 0.5)
35
+
36
+ def smooth_clamp_max(x, max_val):
37
+ return max_val - torch.nn.functional.softplus(max_val - x)
38
+
39
+ def density2alpha(density, interval):
40
+ return 1 - torch.exp(-STEP_SZ_SCALE * interval * density)
41
+
42
+ def alpha2density(alpha, interval):
43
+ return torch.log(1 - alpha) / (-STEP_SZ_SCALE * interval)
44
+
45
+ def rgb2shzero(x):
46
+ return (x - 0.5) / 0.28209479177387814
47
+
48
+ def shzero2rgb(x):
49
+ return x * 0.28209479177387814 + 0.5
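The paired activations above are meant to be exact inverses on their shared domain; a quick sanity-check sketch (assuming `svraster_cuda` is installed so the module imports cleanly):

import torch
from src.utils.activation_utils import (
    exp_linear_11, exp_linear_11_inverse, rgb2shzero, shzero2rgb)

x = torch.linspace(-3.0, 3.0, steps=13)
assert torch.allclose(exp_linear_11_inverse(exp_linear_11(x)), x, atol=1e-5)

rgb = torch.rand(4, 3)
assert torch.allclose(shzero2rgb(rgb2shzero(rgb)), rgb, atol=1e-6)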
src/utils/bounding_utils.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import numpy as np
10
+
11
+
12
+ def decide_main_bounding(bound_mode="default",
13
+ forward_dist_scale=1.0, # For "forward" mode
14
+ pcd_density_rate=0.1, # For "pcd" mode
15
+ bound_scale=1.0, # Scaling of the bounding
16
+ tr_cams=None, # Cameras
17
+ pcd=None, # Point cloud
18
+ suggested_bounding=None):
19
+ if bound_mode == "default" and suggested_bounding is not None:
20
+ print("Use suggested bounding")
21
+ center = suggested_bounding.mean(0)
22
+ radius = (suggested_bounding[1] - suggested_bounding[0]) * 0.5
23
+ elif bound_mode in ["camera_max", "camera_median"]:
24
+ center, radius = main_scene_bound_camera_heuristic(
25
+ cams=tr_cams, bound_mode=bound_mode)
26
+ elif bound_mode == "forward":
27
+ center, radius = main_scene_bound_forward_heuristic(
28
+ cams=tr_cams, forward_dist_scale=forward_dist_scale)
29
+ elif bound_mode == "pcd":
30
+ center, radius = main_scene_bound_pcd_heuristic(
31
+ pcd=pcd, pcd_density_rate=pcd_density_rate)
32
+ elif bound_mode == "default":
33
+ cam_lookats = np.stack([cam.lookat.tolist() for cam in tr_cams])
34
+ lookat_dots = (cam_lookats[:,None] * cam_lookats).sum(-1)
35
+ is_forward_facing = lookat_dots.min() > 0
36
+
37
+ if is_forward_facing:
38
+ center, radius = main_scene_bound_forward_heuristic(
39
+ cams=tr_cams, forward_dist_scale=forward_dist_scale)
40
+ else:
41
+ center, radius = main_scene_bound_camera_heuristic(
42
+ cams=tr_cams, bound_mode="camera_median")
43
+ else:
44
+ raise NotImplementedError
45
+
46
+ radius = radius * bound_scale
47
+
48
+ bounding = np.array([
49
+ center - radius,
50
+ center + radius,
51
+ ], dtype=np.float32)
52
+ return bounding
53
+
54
+
55
+ def main_scene_bound_camera_heuristic(cams, bound_mode):
56
+ print("Heuristic bounding:", bound_mode)
57
+ cam_positions = np.stack([cam.position.tolist() for cam in cams])
58
+ center = cam_positions.mean(0)
59
+ dists = np.linalg.norm(cam_positions - center, axis=1)
60
+ if bound_mode == "camera_max":
61
+ radius = np.max(dists)
62
+ elif bound_mode == "camera_median":
63
+ radius = np.median(dists)
64
+ else:
65
+ raise NotImplementedError
66
+ return center, radius
67
+
68
+
69
+ def main_scene_bound_forward_heuristic(cams, forward_dist_scale):
70
+ print("Heuristic bounding: forward")
71
+ positions = np.stack([cam.position.tolist() for cam in cams])
72
+ cam_center = positions.mean(0)
73
+ cam_lookat = np.stack([cam.lookat.tolist() for cam in cams]).mean(0)
74
+ cam_lookat /= np.linalg.norm(cam_lookat)
75
+ cam_extent = 2 * np.linalg.norm(positions - cam_center, axis=1).max()
76
+
77
+ center = cam_center + forward_dist_scale * cam_extent * cam_lookat
78
+ radius = 0.8 * forward_dist_scale * cam_extent
79
+
80
+ return center, radius
81
+
82
+
83
+ def main_scene_bound_pcd_heuristic(pcd, pcd_density_rate):
84
+ print("Heuristic bounding: pcd")
85
+ center = np.median(pcd.points, axis=0)
86
+ dist = np.abs(pcd.points - center).max(axis=1)
87
+ dist = np.sort(dist)
88
+ density = (1 + np.arange(len(dist))) * (dist > 0) / ((2 * dist) ** 3 + 1e-6)
89
+
90
+ # Should cover at least 5% of the points
91
+ begin_idx = round(len(density) * 0.05)
92
+
93
+ # Find the radius with maximum point density
94
+ max_idx = begin_idx + density[begin_idx:].argmax()
95
+
96
+ # Find the smallest radius with point density equal to pcd_density_rate of maximum
97
+ target_density = pcd_density_rate * density[max_idx]
98
+ target_idx = max_idx + np.where(density[max_idx:] < target_density)[0][0]
99
+
100
+ radius = dist[target_idx]
101
+
102
+ return center, radius
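A brief usage sketch of the bounding heuristics above (the training cameras `tr_cams` and the optional point cloud `pcd` are assumed to come from the data loader; `bound_mode="default"` picks the forward or camera_median heuristic automatically):

import numpy as np

bounding = decide_main_bounding(
    bound_mode="default",
    bound_scale=1.0,
    tr_cams=tr_cams,
    pcd=pcd,              # only consulted when bound_mode="pcd"
)
center = bounding.mean(0)                     # [3] center of the main region
extent = (bounding[1] - bounding[0]).max()    # scalar extent of the main region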
src/utils/camera_utils.py ADDED
@@ -0,0 +1,79 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import numpy as np
10
+ from scipy.interpolate import make_interp_spline
11
+
12
+
13
+ def fov2focal(fov, pixels):
14
+ return pixels / (2 * np.tan(0.5 * fov))
15
+
16
+ def focal2fov(focal, pixels):
17
+ return 2 * np.arctan(pixels / (2 * focal))
18
+
19
+
20
+ def interpolate_poses(poses, n_frame, periodic=True):
21
+
22
+ assert len(poses) > 1
23
+
24
+ poses = list(poses)
25
+ bc_type = None
26
+
27
+ if periodic:
28
+ poses.append(poses[0])
29
+ bc_type = "periodic"
30
+
31
+ pos_lst = np.stack([pose[:3, 3] for pose in poses])
32
+ lookat_lst = np.stack([pose[:3, 2] for pose in poses])
33
+ right_lst = np.stack([pose[:3, 0] for pose in poses])
34
+
35
+ ts = np.linspace(0, 1, len(poses))
36
+ pos_interp_f = make_interp_spline(ts, pos_lst, bc_type=bc_type)
37
+ lookat_interp_f = make_interp_spline(ts, lookat_lst, bc_type=bc_type)
38
+ right_interp_f = make_interp_spline(ts, right_lst, bc_type=bc_type)
39
+
40
+ samps = np.linspace(0, 1, n_frame+1)[:n_frame]
41
+ pos_video = pos_interp_f(samps)
42
+ lookat_video = lookat_interp_f(samps)
43
+ right_video = right_interp_f(samps)
44
+ interp_poses = []
45
+ for i in range(n_frame):
46
+ pos = pos_video[i]
47
+ lookat = lookat_video[i] / np.linalg.norm(lookat_video[i])
48
+ right_ = right_video[i] / np.linalg.norm(right_video[i])
49
+ down = np.cross(lookat, right_)
50
+ right = np.cross(down, lookat)
51
+ c2w = np.eye(4, dtype=np.float32)
52
+ c2w[:3, 0] = right
53
+ c2w[:3, 1] = down
54
+ c2w[:3, 2] = lookat
55
+ c2w[:3, 3] = pos
56
+ interp_poses.append(c2w)
57
+
58
+ return interp_poses
59
+
60
+
61
+ def gen_circular_poses(radius,
62
+ n_frame,
63
+ starting=1.5 * np.pi, # Starting from -z
64
+ ):
65
+ poses = []
66
+ for rad in np.linspace(starting, starting + 2 * np.pi, n_frame):
67
+ pos = radius * np.array([np.cos(rad), 0, np.sin(rad)])
68
+ lookat = -pos / np.linalg.norm(pos)
69
+ down = np.array([0, 1, 0])
70
+ right = np.cross(down, lookat)
71
+ right = right / np.linalg.norm(right)
72
+ down = np.cross(lookat, right)
73
+ c2w = np.eye(4, dtype=np.float32)
74
+ c2w[:3, 0] = right
75
+ c2w[:3, 1] = down
76
+ c2w[:3, 2] = lookat
77
+ c2w[:3, 3] = pos
78
+ poses.append(c2w)
79
+ return poses
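For reference, `fov2focal` and `focal2fov` above are inverses of each other; a quick check with illustrative numbers:

import numpy as np
from src.utils.camera_utils import fov2focal, focal2fov

width = 800
fov_x = np.deg2rad(60.0)
focal = fov2focal(fov_x, width)              # ~692.8 pixels for a 60-degree horizontal FoV
assert np.isclose(focal2fov(focal, width), fov_x)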
src/utils/colmap_utils.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import pycolmap
10
+ import numpy as np
11
+
12
+ from typing import NamedTuple
13
+
14
+
15
+ class PointCloud(NamedTuple):
16
+ points: np.array
17
+ colors: np.array
18
+ errors: np.array
19
+ corr: dict
20
+
21
+
22
+ def parse_colmap_pts(sfm: pycolmap.Reconstruction, transform: np.ndarray = None):
23
+ """
24
+ Parse COLMAP points and their correspondences.
25
+
26
+ Input:
27
+ @sfm Reconstruction from COLMAP.
28
+ @transform 3x3 matrix to transform xyz.
29
+ Output:
30
+ @xyz Nx3 point positions.
31
+ @rgb Nx3 point colors.
32
+ @err N errors.
33
+ @corr Dictionary from file name to point indices.
34
+ """
35
+
36
+ xyz = []
37
+ rgb = []
38
+ err = []
39
+ points_id = []
40
+ for k, v in sfm.points3D.items():
41
+ points_id.append(k)
42
+ xyz.append(v.xyz)
43
+ rgb.append(v.color)
44
+ err.append(v.error)
45
+ if transform is not None:
46
+ xyz[-1] = transform @ xyz[-1]
47
+
48
+ xyz = np.array(xyz)
49
+ rgb = np.array(rgb)
50
+ err = np.array(err)
51
+ points_id = np.array(points_id)
52
+
53
+ points_idmap = np.full([points_id.max()+2], -1, dtype=np.int64)
54
+ points_idmap[points_id] = np.arange(len(xyz))
55
+
56
+ corr = {}
57
+ for image in sfm.images.values():
58
+ idx = np.array([p.point3D_id for p in image.points2D if p.has_point3D()])
59
+ corr[image.name] = points_idmap[idx]
60
+ assert corr[image.name].min() >= 0 and corr[image.name].max() < len(xyz)
61
+
62
+ return PointCloud(points=xyz, colors=rgb, errors=err, corr=corr)
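A sketch of calling `parse_colmap_pts` on a COLMAP reconstruction (the path is illustrative; `pycolmap` must be installed):

import pycolmap
from src.utils.colmap_utils import parse_colmap_pts

sfm = pycolmap.Reconstruction("data/scene/sparse/0")
pcd = parse_colmap_pts(sfm)

print(pcd.points.shape, pcd.colors.shape, pcd.errors.shape)  # (N, 3), (N, 3), (N,)
print(len(pcd.corr))   # one array of 3D point indices per registered image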
src/utils/fuser_utils.py ADDED
@@ -0,0 +1,185 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ '''
10
+ Reference: KinectFusion algorithm.
11
+ '''
12
+
13
+ import numpy as np
14
+
15
+ import torch
16
+
17
+
18
+ class Fuser:
19
+ def __init__(self,
20
+ xyz,
21
+ bandwidth,
22
+ use_trunc=True,
23
+ fuse_tsdf=True,
24
+ feat_dim=0,
25
+ alpha_thres=0.5,
26
+ crop_border=0.0,
27
+ normal_weight=False,
28
+ depth_weight=False,
29
+ border_weight=False,
30
+ max_norm_dist=10.,
31
+ use_half=False):
32
+ assert len(xyz.shape) == 2
33
+ assert xyz.shape[1] == 3
34
+ self.xyz = xyz
35
+ self.bandwidth = bandwidth
36
+ self.use_trunc = use_trunc
37
+ self.fuse_tsdf = fuse_tsdf
38
+ self.feat_dim = feat_dim
39
+ self.alpha_thres = alpha_thres
40
+ self.crop_border = crop_border
41
+ self.normal_weight = normal_weight
42
+ self.depth_weight = depth_weight
43
+ self.border_weight = border_weight
44
+ self.max_norm_dist = max_norm_dist
45
+
46
+ self.dtype = torch.float16 if use_half else torch.float32
47
+ self.weight = torch.zeros([len(xyz), 1], dtype=self.dtype, device="cuda")
48
+ self.feat = torch.zeros([len(xyz), feat_dim], dtype=self.dtype, device="cuda")
49
+ if self.fuse_tsdf:
50
+ self.sd_val = torch.zeros([len(xyz), 1], dtype=self.dtype, device="cuda")
51
+ else:
52
+ self.sd_val = None
53
+
54
+ def integrate(self, cam, depth, feat=None, alpha=None):
55
+ # Project grid points to image
56
+ xyz_uv = cam.project(self.xyz)
57
+ xyz_front = ((self.xyz - cam.position) @ cam.lookat) > cam.near
58
+
59
+ # Filter points projected outside
60
+ filter_idx = torch.where((xyz_uv.abs() <= 1-self.crop_border).all(-1) & xyz_front)[0]
61
+ valid_idx = filter_idx
62
+ valid_xyz = self.xyz[filter_idx]
63
+ valid_uv = xyz_uv[filter_idx]
64
+
65
+ # Compute projective sdf
66
+ valid_frame_depth = torch.nn.functional.grid_sample(
67
+ depth.view(1,1,*depth.shape[-2:]),
68
+ valid_uv.view(1,1,-1,2),
69
+ mode='bilinear',
70
+ align_corners=False).flatten()
71
+ valid_xyz_depth = (valid_xyz - cam.position) @ cam.lookat
72
+ valid_sdf = valid_frame_depth - valid_xyz_depth
73
+
74
+ if torch.is_tensor(self.bandwidth):
75
+ bandwidth = self.bandwidth[valid_idx]
76
+ else:
77
+ bandwidth = self.bandwidth
78
+
79
+ valid_sdf *= (1 / bandwidth)
80
+
81
+ if self.use_trunc:
82
+ # Filter occluded
83
+ filter_idx = torch.where(valid_sdf >= -1)[0]
84
+ valid_idx = valid_idx[filter_idx]
85
+ valid_uv = valid_uv[filter_idx]
86
+ valid_frame_depth = valid_frame_depth[filter_idx]
87
+ valid_sdf = valid_sdf[filter_idx]
88
+ valid_sdf = valid_sdf.clamp_(-1, 1)
89
+
90
+ # Init weighting
91
+ w = torch.ones_like(valid_frame_depth)
92
+ else:
93
+ norm_dist = valid_sdf.abs()
94
+ w = torch.exp(-norm_dist.clamp_max(self.max_norm_dist))
95
+
96
+ # Alpha filtering
97
+ if alpha is not None:
98
+ valid_alpha = torch.nn.functional.grid_sample(
99
+ alpha.view(1,1,*alpha.shape[-2:]),
100
+ valid_uv.view(1,1,-1,2),
101
+ mode='bilinear',
102
+ align_corners=False).flatten()
103
+ w *= valid_alpha
104
+
105
+ filter_idx = torch.where(valid_alpha >= self.alpha_thres)[0]
106
+ valid_idx = valid_idx[filter_idx]
107
+ valid_uv = valid_uv[filter_idx]
108
+ valid_frame_depth = valid_frame_depth[filter_idx]
109
+ valid_sdf = valid_sdf[filter_idx]
110
+ w = w[filter_idx]
111
+
112
+ # Compute geometric weighting
113
+ if self.depth_weight:
114
+ w *= 1 / valid_frame_depth.clamp_min(0.1)
115
+
116
+ if self.normal_weight:
117
+ normal = cam.depth2normal(depth)
118
+ rd = torch.nn.functional.normalize(cam.depth2pts(depth) - cam.position.view(3,1,1), dim=0)
119
+ cos_theta = (normal * rd).sum(0).clamp_min(0)
120
+ valid_cos_theta = torch.nn.functional.grid_sample(
121
+ cos_theta.view(1,1,*cos_theta.shape[-2:]),
122
+ valid_uv.view(1,1,-1,2),
123
+ mode='bilinear',
124
+ align_corners=False).flatten()
125
+ w *= valid_cos_theta
126
+
127
+ if self.border_weight:
128
+ # The image center gets weight 1.0; the corners get 0.1
129
+ w *= 1 / (1 + 9/np.sqrt(2) * valid_uv.square().sum(1).sqrt())
130
+
131
+ # Reshape integration weight
132
+ w = w.unsqueeze(-1).to(self.dtype)
133
+
134
+ # Integrate weight
135
+ self.weight[valid_idx] += w
136
+
137
+ # Integrate tsdf
138
+ if self.fuse_tsdf:
139
+ valid_sdf = valid_sdf.unsqueeze(-1).to(self.dtype)
140
+ self.sd_val[valid_idx] += w * valid_sdf
141
+
142
+ # Sample feature
143
+ if self.feat_dim > 0:
144
+ valid_feat = torch.nn.functional.grid_sample(
145
+ feat.view(1,self.feat_dim,*feat.shape[-2:]).to(self.dtype),
146
+ valid_uv.view(1,1,-1,2).to(self.dtype),
147
+ mode='bilinear',
148
+ align_corners=False)[0,:,0].T
149
+ self.feat[valid_idx] += w * valid_feat
150
+
151
+ @property
152
+ def feature(self):
153
+ return self.feat / self.weight
154
+
155
+ @property
156
+ def tsdf(self):
157
+ return self.sd_val / self.weight
158
+
159
+
160
+ @torch.no_grad()
161
+ def rgb_fusion(voxel_model, cameras):
162
+
163
+ from .octree_utils import level_2_vox_size
164
+
165
+ # Define volume integrator
166
+ finest_vox_size = level_2_vox_size(voxel_model.scene_extent, voxel_model.octlevel.max()).item()
167
+ feat_volume = Fuser(
168
+ xyz=voxel_model.vox_center,
169
+ bandwidth=10 * finest_vox_size,
170
+ use_trunc=False,
171
+ fuse_tsdf=False,
172
+ feat_dim=3,
173
+ crop_border=0.,
174
+ normal_weight=False,
175
+ depth_weight=False,
176
+ border_weight=False,
177
+ use_half=True)
178
+
179
+ # Run RGB image fusion over the training views
180
+ for cam in cameras:
181
+ render_pkg = voxel_model.render(cam, color_mode="dontcare", output_depth=True)
182
+ depth = render_pkg['depth'][2]
183
+ feat_volume.integrate(cam=cam, feat=cam.image.cuda(), depth=depth)
184
+
185
+ return feat_volume.feature.nan_to_num_(0.5).float()
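Beyond the RGB fusion helper above, the `Fuser` can also be driven directly for TSDF-style integration. A hedged sketch (a populated `voxel_model` and a list of `cameras` are assumed; the bandwidth value is illustrative):

import torch

fuser = Fuser(
    xyz=voxel_model.vox_center,   # one query point per voxel center
    bandwidth=0.05,               # truncation band in world units (illustrative)
    use_trunc=True,
    fuse_tsdf=True,
    feat_dim=3,                   # also fuse per-pixel RGB
)

for cam in cameras:
    render_pkg = voxel_model.render(cam, color_mode="dontcare", output_depth=True)
    fuser.integrate(cam=cam, depth=render_pkg['depth'][2], feat=cam.image.cuda())

tsdf = fuser.tsdf.nan_to_num_(1.0)       # NaN where no observation contributed
rgb = fuser.feature.nan_to_num_(0.5)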