Commit 2098a77 · 0 parent(s)
Reinitialize clean repo without large files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- .gitattributes +37 -0
- README.md +13 -0
- app.py +497 -0
- requirements.txt +14 -0
- src/__pycache__/cameras.cpython-39.pyc +0 -0
- src/__pycache__/config.cpython-38.pyc +0 -0
- src/__pycache__/config.cpython-39.pyc +0 -0
- src/__pycache__/sparse_voxel_model.cpython-39.pyc +0 -0
- src/cameras.py +287 -0
- src/config.py +230 -0
- src/config_old.py +230 -0
- src/dataloader/__pycache__/data_pack.cpython-39.pyc +0 -0
- src/dataloader/__pycache__/reader_colmap_dataset.cpython-39.pyc +0 -0
- src/dataloader/__pycache__/reader_nerf_dataset.cpython-39.pyc +0 -0
- src/dataloader/data_pack.py +232 -0
- src/dataloader/reader_colmap_dataset.py +162 -0
- src/dataloader/reader_colmap_dataset_or.py +148 -0
- src/dataloader/reader_nerf_dataset.py +180 -0
- src/dataloader/reader_nerf_dataset_copy.py +170 -0
- src/sparse_voxel_gears/__pycache__/adaptive.cpython-39.pyc +0 -0
- src/sparse_voxel_gears/__pycache__/constructor.cpython-39.pyc +0 -0
- src/sparse_voxel_gears/__pycache__/io.cpython-39.pyc +0 -0
- src/sparse_voxel_gears/__pycache__/pooling.cpython-39.pyc +0 -0
- src/sparse_voxel_gears/__pycache__/properties.cpython-39.pyc +0 -0
- src/sparse_voxel_gears/__pycache__/renderer.cpython-39.pyc +0 -0
- src/sparse_voxel_gears/adaptive.py +296 -0
- src/sparse_voxel_gears/constructor.py +425 -0
- src/sparse_voxel_gears/io.py +156 -0
- src/sparse_voxel_gears/pooling.py +68 -0
- src/sparse_voxel_gears/properties.py +146 -0
- src/sparse_voxel_gears/renderer.py +178 -0
- src/sparse_voxel_gears/renderer_copy.py +178 -0
- src/sparse_voxel_model.py +67 -0
- src/sparse_voxel_model_copy.py +67 -0
- src/utils/__pycache__/activation_utils.cpython-39.pyc +0 -0
- src/utils/__pycache__/bounding_utils.cpython-39.pyc +0 -0
- src/utils/__pycache__/camera_utils.cpython-39.pyc +0 -0
- src/utils/__pycache__/colmap_utils.cpython-39.pyc +0 -0
- src/utils/__pycache__/fuser_utils.cpython-39.pyc +0 -0
- src/utils/__pycache__/image_utils.cpython-39.pyc +0 -0
- src/utils/__pycache__/loss_utils.cpython-39.pyc +0 -0
- src/utils/__pycache__/marching_cubes_utils.cpython-39.pyc +0 -0
- src/utils/__pycache__/mono_utils.cpython-39.pyc +0 -0
- src/utils/__pycache__/octree_utils.cpython-39.pyc +0 -0
- src/utils/__pycache__/system_utils.cpython-39.pyc +0 -0
- src/utils/activation_utils.py +49 -0
- src/utils/bounding_utils.py +102 -0
- src/utils/camera_utils.py +79 -0
- src/utils/colmap_utils.py +62 -0
- src/utils/fuser_utils.py +185 -0
.gitattributes
ADDED
@@ -0,0 +1,37 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.ply filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Beetle Viz
emoji: 😻
colorFrom: red
colorTo: gray
sdk: gradio
sdk_version: 5.49.1
app_file: app.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,497 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 6 10:16:31 2025

@author: nibio
"""

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import os, time
import numpy as np
import imageio.v3 as iio
from scipy.spatial.transform import Rotation
from typing import Optional

import torch

from src.config import cfg, update_config
from src.dataloader.data_pack import DataPack
from src.sparse_voxel_model import SparseVoxelModel
from src.utils.image_utils import im_tensor2np, viz_tensordepth
from src.cameras import MiniCam

import viser
import viser.transforms as tf


def matrix2wxyz(R: np.ndarray) -> np.ndarray:
    return Rotation.from_matrix(R).as_quat()[[3, 0, 1, 2]]


def wxyz2matrix(wxyz: np.ndarray) -> np.ndarray:
    return Rotation.from_quat(wxyz[[1, 2, 3, 0]]).as_matrix()


class SVRasterViewer:
    def __init__(self, cfg):

        # ---------- Data & model ----------
        data_pack = DataPack(
            source_path=cfg.data.source_path,
            image_dir_name=cfg.data.image_dir_name,
            res_downscale=cfg.data.res_downscale,
            res_width=cfg.data.res_width,
            skip_blend_alpha=cfg.data.skip_blend_alpha,
            alpha_is_white=cfg.model.white_background,
            data_device=cfg.data.data_device,
            use_test=cfg.data.eval,
            test_every=cfg.data.test_every,
            camera_params_only=True,
        )
        self.tr_cam_lst = data_pack.get_train_cameras()
        self.te_cam_lst = data_pack.get_test_cameras()

        self.scene_center = (
            np.mean([c.c2w[:3, 3].cpu().numpy() for c in self.tr_cam_lst], axis=0)
            if len(self.tr_cam_lst)
            else np.zeros(3, dtype=np.float32)
        )

        self.voxel_model = SparseVoxelModel(
            n_samp_per_vox=cfg.model.n_samp_per_vox,
            sh_degree=cfg.model.sh_degree,
            ss=cfg.model.ss,
            white_background=cfg.model.white_background,
            black_background=cfg.model.black_background,
        )
        self.voxel_model.load_iteration(args.model_path, args.iteration)  # args from __main__
        self.voxel_model.freeze_vox_geo()

        # ---------- UI ----------
        self.server = viser.ViserServer(port=cfg.port)
        self.is_connected = False

        self.server.gui.set_panel_label("SVRaster viser")
        self.server.gui.add_markdown(
            "**View control:**\n- Mouse drag + scroll\n- WASD + QE keys"
        )
        self.fps = self.server.gui.add_text("Rending FPS", initial_value="-1", disabled=True)

        self.active_sh_degree_slider = self.server.gui.add_slider(
            "active_sh_degree", min=0, max=self.voxel_model.max_sh_degree, step=1,
            initial_value=self.voxel_model.active_sh_degree
        )
        self.ss_slider = self.server.gui.add_slider("ss", min=0.5, max=2.0, step=0.05, initial_value=self.voxel_model.ss)
        self.width_slider = self.server.gui.add_slider("width", min=64, max=2048, step=8, initial_value=1024)
        self.fovx_slider = self.server.gui.add_slider("fovx", min=10, max=150, step=1, initial_value=70)
        self.near_slider = self.server.gui.add_slider("near", min=0.02, max=10, step=0.01, initial_value=0.2)

        self.render_dropdown = self.server.gui.add_dropdown(
            "render mod", options=["all", "rgb only", "depth only", "normal only"], initial_value="all"
        )
        self.output_dropdown = self.server.gui.add_dropdown(
            "output", options=["rgb", "alpha", "dmean", "dmed", "dmean2n", "dmed2n", "n"], initial_value="rgb"
        )

        # ---- Focus & crop controls ----
        self.alpha_thr_slider = self.server.gui.add_slider(
            "alpha_threshold", min=0.0, max=0.95, step=0.01, initial_value=0.35
        )
        self.keep_closest_slider = self.server.gui.add_slider(
            "keep_closest_pct", min=0.2, max=1.0, step=0.05, initial_value=0.6
        )
        self.hide_outside_checkbox = self.server.gui.add_checkbox(
            "hide_outside_focus", initial_value=False
        )

        self.center_btn = self.server.gui.add_button("Center on object")
        self.reset_btn = self.server.gui.add_button("Reset to first view")
        self.autoframe_btn = self.server.gui.add_button("Auto-frame (depth)")
        self.focus_btn = self.server.gui.add_button("Focus foreground")
        self.rebase_btn = self.server.gui.add_button("Recenter world to focus")

        # ---- state for world rebase / focus mask ----
        self.world_offset = np.zeros(3, dtype=np.float32)  # world translation applied during render
        self.focus_center: Optional[np.ndarray] = None

        # ---------- Camera frusta ----------
        self.tr_frust, self.te_frust = [], []

        def add_frustum(name, cam, color):
            c2w = cam.c2w.cpu().numpy()
            frame = self.server.scene.add_camera_frustum(
                name,
                fov=cam.fovy,
                aspect=cam.image_width / cam.image_height,
                scale=0.10,
                wxyz=matrix2wxyz(c2w[:3, :3]),
                position=c2w[:3, 3],
                color=color,
                visible=False,
            )

            @frame.on_click
            def _(event: viser.SceneNodePointerEvent):
                client = event.client
                with client.atomic():
                    client.camera.wxyz = event.target.wxyz
                    client.camera.position = event.target.position
                self._camera_lookat(client, self.scene_center)

            return frame

        for i, cam in enumerate(self.tr_cam_lst):
            self.tr_frust.append(add_frustum(f"/frustum/train/{i:04d}", cam, [0.0, 1.0, 0.0]))
        for i, cam in enumerate(self.te_cam_lst):
            self.te_frust.append(add_frustum(f"/frustum/test/{i:04d}", cam, [1.0, 0.0, 0.0]))

        self.show_cam_dropdown = self.server.gui.add_dropdown(
            "show cameras", options=["none", "train", "test", "all"], initial_value="none"
        )

        @self.show_cam_dropdown.on_update
        def _(_):
            for f in self.tr_frust: f.visible = self.show_cam_dropdown.value in ["train", "all"]
            for f in self.te_frust: f.visible = self.show_cam_dropdown.value in ["test", "all"]

        # ---------- Button handlers ----------
        @self.center_btn.on_click
        def _(event: viser.GuiEvent):
            if event.client: self._camera_lookat(event.client, self.scene_center)

        @self.reset_btn.on_click
        def _(event: viser.GuiEvent):
            client = event.client
            if not client: return
            init = self.tr_cam_lst[0].c2w.cpu().numpy()
            with client.atomic():
                client.camera.wxyz = matrix2wxyz(init[:3, :3])
                client.camera.position = init[:3, 3]
            self._camera_lookat(client, self.scene_center)

        @self.autoframe_btn.on_click
        def _(event: viser.GuiEvent):
            if event.client: self._auto_frame_by_depth(event.client)

        @self.focus_btn.on_click
        def _(event: viser.GuiEvent):
            if event.client: self._focus_foreground(event.client)

        @self.rebase_btn.on_click
        def _(event: viser.GuiEvent):
            client = event.client
            if not client or self.focus_center is None:
                print("[rebase] Run 'Focus foreground' first.")
                return
            delta = self.focus_center.astype(np.float32)
            self.world_offset = self.world_offset + delta  # accumulate translation
            with client.atomic():
                client.camera.position = (np.asarray(client.camera.position) - delta).astype(np.float32)
            self.scene_center = np.zeros(3, dtype=np.float32)
            print("[rebase] World recentered; new world_offset:", self.world_offset)

        # ---------- On connect ----------
        @self.server.on_client_connect
        def _(client: viser.ClientHandle):
            init = self.tr_cam_lst[0].c2w.cpu().numpy()
            with client.atomic():
                client.camera.wxyz = matrix2wxyz(init[:3, :3])
                client.camera.position = init[:3, 3]
            ok = self._auto_frame_by_depth(client, quiet=True)
            if not ok:
                self._camera_lookat(client, self.scene_center)
            self.is_connected = True

        # ---------- Download ----------
        self.download_button = self.server.gui.add_button("Download view")

        @self.download_button.on_click
        def _(event: viser.GuiEvent):
            im, _ = self.render_viser_camera(event.client.camera)
            event.client.send_file_download(
                "svraster_viser.png", iio.imwrite("<bytes>", im, extension=".png")
            )

    # ---------------- camera utils ----------------
    def _camera_lookat(
        self,
        client: viser.ClientHandle,
        target: np.ndarray,
        distance: Optional[float] = None,
    ):
        """
        Point the camera at `target` by writing orientation (wxyz) and position directly.
        Compatible with Viser builds where camera.look_at is not callable.
        """
        target = np.asarray(target, dtype=np.float32)
        eye = np.asarray(client.camera.position, dtype=np.float32)

        vec = eye - target  # target -> eye
        norm = np.linalg.norm(vec)
        if not np.isfinite(norm) or norm < 1e-6:
            vec = np.array([0, 0, 1.0], dtype=np.float32)
            norm = 0.5

        d = float(norm if distance is None else distance)

        # Orthonormal basis that looks at target.
        fwd = -(vec / max(norm, 1e-6))  # camera forward (eye->target)
        up_guess = np.array([0, 1, 0], dtype=np.float32)
        if abs(np.dot(fwd, up_guess)) > 0.99:
            up_guess = np.array([1, 0, 0], dtype=np.float32)
        right = np.cross(up_guess, fwd)
        right /= max(np.linalg.norm(right), 1e-6)
        up = np.cross(fwd, right)
        up /= max(np.linalg.norm(up), 1e-6)

        R = np.stack([right, up, fwd], axis=1).astype(np.float32)
        new_pos = target - fwd * d

        with client.atomic():
            client.camera.wxyz = matrix2wxyz(R)
            client.camera.position = new_pos

    def _auto_frame_by_depth(self, client: viser.ClientHandle, quiet: bool = False) -> bool:
        """Render once, use center-pixel median depth to determine a good pivot."""
        try:
            _, _, depth_med = self.render_viser_camera(client.camera, return_depth=True)
        except Exception as e:
            if not quiet: print("[auto-frame] render error:", e)
            return False

        H, W = depth_med.shape
        d = float(depth_med[H // 2, W // 2])
        if not np.isfinite(d) or d <= 0:
            if not quiet: print("[auto-frame] invalid depth; falling back")
            return False

        R = wxyz2matrix(client.camera.wxyz)
        fwd = R @ np.array([0, 0, 1], dtype=np.float32)
        target = np.asarray(client.camera.position, dtype=np.float32) + fwd * d
        self._camera_lookat(client, target, distance=d)
        if not quiet: print("[auto-frame] success; depth =", d)
        return True

    # ----------- Focus only the foreground object -----------
    def _focus_foreground(self, client: viser.ClientHandle):
        """
        Use alpha (1 - T) to mask foreground, keep closest depths,
        back-project to world, compute tight AABB, center & fit view.
        Stores self.focus_center so you can 'Recenter world to focus'.
        """
        try:
            _, _, depth_med, T = self.render_viser_camera(client.camera, return_depth=True, return_T=True)
        except Exception as e:
            print("[focus] render error:", e)
            return

        alpha = 1.0 - T
        thr = float(self.alpha_thr_slider.value)
        mask = (alpha > thr) & np.isfinite(depth_med) & (depth_med > 0)

        if mask.sum() < 50:
            print("[focus] Not enough foreground; lower alpha_threshold or change view.")
            return

        # Keep only the closest K% pixels to drop the outer ring
        K = float(self.keep_closest_slider.value)
        dvals = depth_med[mask]
        q = np.quantile(dvals, K)
        mask &= depth_med <= q
        if mask.sum() < 50:
            print("[focus] Too few pixels after depth filtering; raise keep_closest_pct.")
            return

        # Back-project masked pixels to world
        width = int(self.width_slider.value)
        aspect = max(1e-6, float(client.camera.aspect))
        height = max(1, int(round(width / aspect)))
        fovx = np.deg2rad(float(self.fovx_slider.value))
        fovy = fovx * height / max(width, 1)

        fx = width / (2.0 * np.tan(fovx * 0.5))
        fy = height / (2.0 * np.tan(fovy * 0.5))
        cx, cy = (width - 1) / 2.0, (height - 1) / 2.0

        ys, xs = np.where(mask)
        zs = depth_med[ys, xs].astype(np.float32)

        x_cam = (xs - cx) / fx * zs
        y_cam = (ys - cy) / fy * zs
        z_cam = zs
        P_cam = np.stack([x_cam, y_cam, z_cam], axis=0)  # (3, N)

        R = wxyz2matrix(client.camera.wxyz)
        t = np.asarray(client.camera.position, dtype=np.float32)[:, None]
        # Apply current world rebase so P_world matches what we render
        t = (t - self.world_offset[:, None]).astype(np.float32)

        P_world = (R @ P_cam) + t  # (3, N)

        pmin = np.min(P_world, axis=1)
        pmax = np.max(P_world, axis=1)
        center = (pmin + pmax) * 0.5
        extent = (pmax - pmin) * 0.5

        # Save for rebase
        self.focus_center = center.astype(np.float32)

        # Choose distance that fits bbox into the view (larger FOV dimension)
        fovx_deg = float(self.fovx_slider.value)
        fovy_deg = fovx_deg * height / max(width, 1)
        fov_rad = np.deg2rad(max(fovx_deg, fovy_deg))
        radius = float(np.linalg.norm(extent, ord=np.inf))
        dist = radius / np.tan(max(1e-4, fov_rad * 0.5)) * 1.25  # padding

        # Update logical scene center for orbiting & go there
        self.scene_center = center.astype(np.float32)
        self._camera_lookat(client, self.scene_center, distance=dist)

        print(f"[focus] bbox half-extent ~{extent}, distance {dist:.3f}")

    # ---------------- rendering ----------------
    @torch.no_grad()
    def render_viser_camera(
        self,
        camera: viser.CameraHandle,
        return_depth: bool = False,
        return_T: bool = False,
    ):
        width = int(self.width_slider.value)
        aspect = max(1e-6, float(camera.aspect))
        height = max(1, int(round(width / aspect)))

        fovx_deg = float(self.fovx_slider.value)
        fovy_deg = fovx_deg * height / max(width, 1)
        near = float(self.near_slider.value)

        c2w = np.eye(4, dtype=np.float32)
        c2w[:3, :3] = wxyz2matrix(camera.wxyz)
        c2w[:3, 3] = camera.position
        # Apply world rebase: move the *world* by -offset equivalently by moving camera by -offset in world coords.
        c2w[:3, 3] = c2w[:3, 3] - self.world_offset

        minicam = MiniCam(
            c2w, fovx=np.deg2rad(fovx_deg), fovy=np.deg2rad(fovy_deg),
            width=width, height=height, near=near
        )

        self.voxel_model.active_sh_degree = int(self.active_sh_degree_slider.value)

        render_opt = {
            "ss": self.ss_slider.value,
            "output_T": True,
            "output_depth": True,
            "output_normal": True,
        }
        if self.render_dropdown.value == "rgb only":
            render_opt["output_depth"] = False; render_opt["output_normal"] = False
        elif self.render_dropdown.value == "depth only":
            render_opt["color_mode"] = "dontcare"; render_opt["output_normal"] = False
        elif self.render_dropdown.value == "normal only":
            render_opt["color_mode"] = "dontcare"; render_opt["output_depth"] = False

        t0 = time.time()
        try:
            render_pkg = self.voxel_model.render(minicam, **render_opt)
        except RuntimeError as e:
            print("[render] RuntimeError:", e)
            im = np.ones((height, width, 3), dtype=np.uint8) * 255
            if return_depth and return_T:
                depth_med = np.full((height, width), np.nan, dtype=np.float32)
                T = np.ones((height, width), dtype=np.float32)
                return im, 0.0, depth_med, T
            if return_depth:
                depth_med = np.full((height, width), np.nan, dtype=np.float32)
                return im, 0.0, depth_med
            if return_T:
                T = np.ones((height, width), dtype=np.float32)
                return im, 0.0, T
            return im, 0.0
        torch.cuda.synchronize()
        eps = time.time() - t0

        # choose output image
        if self.output_dropdown.value == "dmean":
            im = viz_tensordepth(render_pkg["depth"][0])
        elif self.output_dropdown.value == "dmed":
            im = viz_tensordepth(render_pkg["depth"][2])
        elif self.output_dropdown.value == "dmean2n":
            im = im_tensor2np(minicam.depth2normal(render_pkg["depth"][0]) * 0.5 + 0.5)
        elif self.output_dropdown.value == "dmed2n":
            im = im_tensor2np(minicam.depth2normal(render_pkg["depth"][2]) * 0.5 + 0.5)
        elif self.output_dropdown.value == "n":
            im = im_tensor2np(render_pkg["normal"] * 0.5 + 0.5)
        elif self.output_dropdown.value == "alpha":
            im = im_tensor2np(1 - render_pkg["T"].repeat(3, 1, 1))
        else:
            im = im_tensor2np(render_pkg["color"])

        depth_med = render_pkg["depth"][2].detach().cpu().numpy()
        T = render_pkg["T"].detach().cpu().numpy()  # (H, W)

        # Optional image-level masking to hide outside the focused object
        if self.hide_outside_checkbox.value:
            alpha = 1.0 - T
            thr = float(self.alpha_thr_slider.value)
            mask = (alpha > thr) & np.isfinite(depth_med) & (depth_med > 0)
            if mask.any():
                K = float(self.keep_closest_slider.value)
                dvals = depth_med[mask]
                q = np.quantile(dvals, K)
                mask &= depth_med <= q
            mask3 = np.repeat(mask[..., None], 3, axis=2)
            bg = np.zeros_like(im)  # black background
            im = np.where(mask3, im, bg)

        del render_pkg

        if return_depth and return_T:
            return im, eps, depth_med, T
        if return_depth:
            return im, eps, depth_med
        if return_T:
            return im, eps, T
        return im, eps

    # ---------------- server tick ----------------
    def update(self):
        if not self.is_connected:
            return
        times = []
        for client in self.server.get_clients().values():
            im, eps = self.render_viser_camera(client.camera)
            times.append(eps)
            client.scene.set_background_image(im, format="jpeg")
        if times:
            self.fps.value = f"{round(1 / np.mean(times)):4d}"


if __name__ == "__main__":
    import os, time

    class Args:
        model_path = "Entimus_imperialis_out_model/2025-1008-1320-c3c8c5"
        iteration = -1
        port = 7860  # Hugging Face default port

    args = Args()
    print(f"[INFO] Launching SVRaster viewer on Hugging Face...")
    print(f"[INFO] Model path: {args.model_path}")

    update_config(os.path.join(args.model_path, "config.yaml"))
    cfg.port = args.port

    svraster_viewer = SVRasterViewer(cfg)

    # Keep process alive so Hugging Face doesn't stop it
    while True:
        svraster_viewer.update()
        time.sleep(0.01)
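Editor's note, not part of the commit: SVRasterViewer reads the module-level args object defined in the __main__ block above, so pointing the viewer at a different trained SVRaster run means editing that block in app.py. A minimal sketch of the fields one would change (the directory name is hypothetical; it is assumed to contain the config.yaml written at training time):

# Hypothetical values; edit the Args class inside app.py's __main__ block.
class Args:
    model_path = "outputs/my_scene_run"  # hypothetical run directory with config.yaml
    iteration = -1                       # passed to voxel_model.load_iteration
    port = 7860                          # port the viser server listens on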
requirements.txt
ADDED
@@ -0,0 +1,14 @@
torch
torchvision
torchaudio
numpy
scipy
imageio
open3d
trimesh
matplotlib
Pillow
tqdm
huggingface_hub
viser==0.1.30
gradio==5.2.0
src/__pycache__/cameras.cpython-39.pyc
ADDED
Binary file (8.98 kB)

src/__pycache__/config.cpython-38.pyc
ADDED
Binary file (3.64 kB)

src/__pycache__/config.cpython-39.pyc
ADDED
Binary file (3.67 kB)

src/__pycache__/sparse_voxel_model.cpython-39.pyc
ADDED
Binary file (1.85 kB)
src/cameras.py
ADDED
@@ -0,0 +1,287 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np

import torch
import svraster_cuda


class CameraBase:

    '''
    Base class of perspective cameras.
    '''

    def __repr__(self):
        clsname = self.__class__.__name__
        fname = f"image_name='{self.image_name}'"
        res = f"HW=({self.image_height}x{self.image_width})"
        fov = f"fovx={np.rad2deg(self.fovx):.1f}deg"
        return f"{clsname}({fname}, {res}, {fov})"

    @property
    def lookat(self):
        return self.c2w[:3, 2]

    @property
    def position(self):
        return self.c2w[:3, 3]

    @property
    def down(self):
        return self.c2w[:3, 1]

    @property
    def right(self):
        return self.c2w[:3, 0]

    @property
    def cx(self):
        return self.image_width * self.cx_p

    @property
    def cy(self):
        return self.image_height * self.cy_p

    @property
    def pix_size(self):
        return 2 * self.tanfovx / self.image_width

    @property
    def tanfovx(self):
        return np.tan(self.fovx * 0.5)

    @property
    def tanfovy(self):
        return np.tan(self.fovy * 0.5)

    def compute_rd(self, wh=None, cxcy=None, device=None):
        '''Ray directions in world space.'''
        if wh is None:
            wh = (self.image_width, self.image_height)
        if cxcy is None:
            cxcy = (self.cx * wh[0] / self.image_width, self.cy * wh[1] / self.image_height)
        rd = svraster_cuda.utils.compute_rd(
            width=wh[0], height=wh[1],
            cx=cxcy[0], cy=cxcy[1],
            tanfovx=self.tanfovx, tanfovy=self.tanfovy,
            c2w_matrix=self.c2w.cuda())
        rd = rd.to(device if device is None else self.c2w.device)
        return rd

    def project(self, pts, return_depth=False):
        # Return normalized image coordinate in [-1, 1]
        cam_pts = pts @ self.w2c[:3, :3].T + self.w2c[:3, 3]
        depth = cam_pts[:, [2]]
        cam_uv = cam_pts[:, :2] / depth
        scale_x = 1 / self.tanfovx
        scale_y = 1 / self.tanfovy
        shift_x = 2 * self.cx_p - 1
        shift_y = 2 * self.cy_p - 1
        cam_uv[:, 0] = cam_uv[:, 0] * scale_x + shift_x
        cam_uv[:, 1] = cam_uv[:, 1] * scale_y + shift_y
        if return_depth:
            return cam_uv, depth
        return cam_uv

    def depth2pts(self, depth):
        device = depth.device
        h, w = depth.shape[-2:]
        rd = self.compute_rd(wh=(w, h), device=device)
        return self.position.view(3,1,1).to(device) + rd * depth

    def depth2normal(self, depth, ks=3, tol_cos=-1):
        assert ks % 2 == 1
        pad = ks // 2
        ks_1 = ks - 1
        pts = self.depth2pts(depth)
        normal_pseudo = torch.zeros_like(pts)
        dx = pts[:, pad:-pad, ks_1:] - pts[:, pad:-pad, :-ks_1]
        dy = pts[:, ks_1:, pad:-pad] - pts[:, :-ks_1, pad:-pad]
        normal_pseudo[:, pad:-pad, pad:-pad] = torch.nn.functional.normalize(torch.cross(dx, dy, dim=0), dim=0)

        if tol_cos > 0:
            with torch.no_grad():
                pts_dir = torch.nn.functional.normalize(pts - self.position.view(3,1,1), dim=0)
                dot = (normal_pseudo * pts_dir).sum(0)
                mask = (dot > tol_cos)
            normal_pseudo = normal_pseudo * mask

        return normal_pseudo


class Camera(CameraBase):
    def __init__(
        self, image_name,
        w2c, fovx, fovy, cx_p, cy_p,
        near=0.02,
        image=None, mask=None, depth=None,
        sparse_pt=None):

        self.image_name = image_name

        # Camera parameters
        self.w2c = torch.tensor(w2c, dtype=torch.float32, device="cuda")
        self.c2w = self.w2c.inverse().contiguous()

        self.fovx = fovx
        self.fovy = fovy

        # Load frame
        self.image = image.cpu()

        # Other camera parameters
        self.image_width = self.image.shape[2]
        self.image_height = self.image.shape[1]
        self.cx_p = (0.5 if cx_p is None else cx_p)
        self.cy_p = (0.5 if cy_p is None else cy_p)
        self.near = near

        # Load mask and depth if there are
        self.mask = mask.cpu() if mask is not None else None
        self.depth = depth.cpu() if depth is not None else None

        # Load sparse depth
        if sparse_pt is not None:
            self.sparse_pt = torch.tensor(sparse_pt, dtype=torch.float32, device="cpu")
        else:
            self.sparse_pt = None

    def to(self, device):
        self.image = self.image.to(device)
        if self.mask is not None:
            self.mask = self.mask.to(device)
        if self.depth is not None:
            self.depth = self.depth.to(device)
        return self

    def auto_exposure_init(self):
        self._exposure_A = torch.eye(3, dtype=torch.float32, device="cuda")
        self._exposure_t = torch.zeros([3,1,1], dtype=torch.float32, device="cuda")
        self.exposure_updated = False

    def auto_exposure_apply(self, image):
        if self.exposure_updated:
            image = torch.einsum('ij,jhw->ihw', self._exposure_A, image) + self._exposure_t
        return image

    def auto_exposure_update(self, ren, ref):
        self.exposure_updated = True
        self._exposure_A.requires_grad_()
        self._exposure_t.requires_grad_()
        optim = torch.optim.Adam([self._exposure_A, self._exposure_t], lr=1e-3)
        for _ in range(100):
            loss = (self.auto_exposure_apply(ren).clamp(0, 1) - ref).abs().mean()
            loss.backward()
            optim.step()
            optim.zero_grad(set_to_none=True)
        self._exposure_A.requires_grad_(False)
        self._exposure_t.requires_grad_(False)

    def clone_mini(self):
        return MiniCam(
            c2w=self.c2w.clone(),
            fovx=self.fovx, fovy=self.fovy,
            width=self.image_width, height=self.image_height,
            near=self.near,
            cx_p=self.cx_p, cy_p=self.cy_p)


class MiniCam(CameraBase):
    def __init__(self,
                 c2w, fovx, fovy,
                 width, height,
                 near=0.02,
                 cx_p=None, cy_p=None,
                 image_name="minicam"):

        self.image_name = image_name
        self.c2w = torch.tensor(c2w).clone().cuda()
        self.w2c = self.c2w.inverse()

        self.fovx = fovx
        self.fovy = fovy
        self.image_width = width
        self.image_height = height
        self.cx_p = (0.5 if cx_p is None else cx_p)
        self.cy_p = (0.5 if cy_p is None else cy_p)
        self.near = near

        self.depth = None
        self.mask = None

    def clone_mini(self):
        return MiniCam(
            c2w=self.c2w.clone(),
            fovx=self.fovx, fovy=self.fovy,
            width=self.image_width, height=self.image_height,
            near=self.near,
            cx_p=self.cx_p, cy_p=self.cy_p)

    def move_forward(self, dist):
        new_position = self.position + dist * self.lookat
        self.c2w[:3, 3] = new_position
        self.w2c = self.c2w.inverse()
        return self

    def move_up(self, dist):
        return self.move_down(-dist)

    def move_down(self, dist):
        new_position = self.position + dist * self.down
        self.c2w[:3, 3] = new_position
        self.w2c = self.c2w.inverse()
        return self

    def move_right(self, dist):
        new_position = self.position + dist * self.right
        self.c2w[:3, 3] = new_position
        self.w2c = self.c2w.inverse()
        return self

    def move_left(self, dist):
        return self.move_right(-dist)

    def rotate(self, R):
        self.c2w[:3, :3] = (R @ self.w2c[:3, :3]).T
        self.w2c = self.c2w.inverse()
        return self

    def rotate_x(self, rad=None, deg=None):
        assert rad is None or deg is None, "Can only specify rotation by either rad or deg."
        if rad is None:
            rad = np.deg2rad(deg)
        R = torch.tensor([
            [1, 0, 0],
            [0, np.cos(rad), -np.sin(rad)],
            [0, np.sin(rad), np.cos(rad)],
        ], dtype=torch.float32, device="cuda")
        return self.rotate(R)

    def rotate_y(self, rad=None, deg=None):
        assert rad is None or deg is None, "Can only specify rotation by either rad or deg."
        if rad is None:
            rad = np.deg2rad(deg)
        R = torch.tensor([
            [np.cos(rad), 0, -np.sin(rad)],
            [0, 1, 0],
            [np.sin(rad), 0, np.cos(rad)],
        ], dtype=torch.float32, device="cuda")
        return self.rotate(R)

    def rotate_z(self, rad=None, deg=None):
        assert rad is None or deg is None, "Can only specify rotation by either rad or deg."
        if rad is None:
            rad = np.deg2rad(deg)
        R = torch.tensor([
            [np.cos(rad), -np.sin(rad), 0],
            [np.sin(rad), np.cos(rad), 0],
            [0, 0, 1],
        ], dtype=torch.float32, device="cuda")
        return self.rotate(R)
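Editor's note, not part of the commit: a minimal sketch of driving MiniCam the way app.py does. It assumes a CUDA-capable torch install (MiniCam moves its pose matrix to the GPU), and the pose and FOV values below are made up for illustration.

import numpy as np
from src.cameras import MiniCam

c2w = np.eye(4, dtype=np.float32)        # camera axes aligned with world axes
c2w[:3, 3] = [0.0, 0.0, -2.0]            # made-up camera position
cam = MiniCam(c2w, fovx=np.deg2rad(70), fovy=np.deg2rad(50), width=800, height=600)
cam.rotate_y(deg=15).move_forward(0.25)  # pose-editing methods return self, so calls chain
print(cam)                               # CameraBase.__repr__ reports name, resolution and fovx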
src/config.py
ADDED
@@ -0,0 +1,230 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import argparse
from yacs.config import CfgNode


cfg = CfgNode()

cfg.model = CfgNode(dict(
    n_samp_per_vox = 1,        # Number of sampled points per visited voxel
    sh_degree = 3,             # Use 3 * (k+1)^2 params per voxels for view-dependent colors
    ss = 1.5,                  # Super-sampling rates for anti-aliasing
    white_background = False,  # Assum white background
    black_background = False,  # Assum black background
))

cfg.data = CfgNode(dict(
    source_path = "",
    image_dir_name = "images",
    mask_dir_name = "masks",
    res_downscale = 0.,
    res_width = 0,
    skip_blend_alpha = False,
    data_device = "cpu",
    eval = False,
    test_every = 8,
))

cfg.bounding = CfgNode(dict(
    # Define the main (inside) region bounding box
    # The default use the suggested bounding if given by dataset.
    # Otherwise, it automatically chose from forward or camera_median modes.
    # See src/utils/bounding_utils.py for details.

    # default | camera_median | camera_max | forward | pcd
    bound_mode = "default",
    bound_scale = 1.0,         # Scaling factor of the bound
    forward_dist_scale = 1.0,  # For forward mode
    pcd_density_rate = 0.1,    # For pcd mode

    # Number of Octree level outside the main foreground region
    outside_level = 5,
))

cfg.optimizer = CfgNode(dict(
    geo_lr = 0.025,
    sh0_lr = 0.010,
    shs_lr = 0.00025,

    optim_beta1 = 0.1,
    optim_beta2 = 0.99,
    optim_eps = 1e-15,

    lr_decay_ckpt = [19000],
    lr_decay_mult = 0.1,
))

cfg.regularizer = CfgNode(dict(
    # Main photometric loss
    lambda_photo = 1.0,
    use_l1 = False,
    use_huber = False,
    huber_thres = 0.03,

    # SSIM loss
    lambda_ssim = 0.02,

    # Sparse depth loss
    lambda_sparse_depth = 0.0,
    sparse_depth_until = 10_000,

    # Mask loss
    lambda_mask = 0.0,

    # Depthanything loss
    lambda_depthanythingv2 = 0.0,
    depthanythingv2_from = 3000,
    depthanythingv2_end = 20000,
    depthanythingv2_end_mult = 0.1,

    # Mast3r metrid loss
    lambda_mast3r_metric_depth = 0.0,
    mast3r_repo_path = '',
    mast3r_metric_depth_from = 0,
    mast3r_metric_depth_end = 20000,
    mast3r_metric_depth_end_mult = 0.01,

    # Final transmittance should concentrate to either 0 or 1
    lambda_T_concen = 0.0,

    # Final transmittance should be 0
    lambda_T_inside = 0.0,

    # Per-point rgb loss
    lambda_R_concen = 0.01,

    # Geometric regularization
    lambda_ascending = 0.0,
    ascending_from = 0,

    # Distortion loss (encourage distribution concentration on ray)
    lambda_dist = 0.1,
    dist_from = 10000,

    # Consistency loss of rendered normal and derived normal from expected depth
    lambda_normal_dmean = 0.0,
    n_dmean_from = 10_000,
    n_dmean_end = 20_000,
    n_dmean_ks = 3,
    n_dmean_tol_deg = 90.0,

    # Consistency loss of rendered normal and derived normal from median depth
    lambda_normal_dmed = 0.0,
    n_dmed_from=3000,
    n_dmed_end=20_000,

    # Total variation loss of density grid
    lambda_tv_density = 1e-10,
    tv_from = 0,
    tv_until = 10000,

    # Data augmentation
    ss_aug_max = 1.5,
    rand_bg = False,
))

cfg.init = CfgNode(dict(
    # Voxel property initialization
    geo_init = -10.0,
    sh0_init = 0.5,
    shs_init = 0.0,

    sh_degree_init = 3,

    # Init main inside region by dense voxels
    init_n_level = 6,  # (2^6)^3 voxels

    # Number of voxel ratio for outside (background region)
    init_out_ratio = 2.0,
))

cfg.procedure = CfgNode(dict(
    # Schedule
    n_iter = 20_000,
    sche_mult = 1.0,
    seed=3721,

    # Reset sh
    reset_sh_ckpt = [-1],

    # Adaptive general setup
    adapt_from = 1000,
    adapt_every = 1000,

    # Adaptive voxel pruning
    prune_until = 18000,
    prune_thres_init = 0.0001,
    prune_thres_final = 0.05,

    # Adaptive voxel pruning
    subdivide_until = 15000,
    subdivide_all_until = 0,
    subdivide_samp_thres = 1.0,  # A voxel max sampling rate should larger than this.
    subdivide_prop = 0.05,
    subdivide_max_num = 10_000_000,
))

cfg.auto_exposure = CfgNode(dict(
    enable = False,
    auto_exposure_upd_ckpt = [5000, 10000, 15000]
))

for i_cfg in cfg.values():
    i_cfg.set_new_allowed(True)


def everytype2bool(v):
    if v.isnumeric():
        return bool(int(v))
    v = v.lower()
    if v in ['n', 'no', 'none', 'false']:
        return False
    return True


def update_argparser(parser):
    for name in cfg.keys():
        group = parser.add_argument_group(name)
        for key, value in getattr(cfg, name).items():
            t = type(value)

            if t == bool:
                group.add_argument(f"--{key}", action='store_true' if t else 'store_false')
            elif t == list:
                group.add_argument(f"--{key}", default=value, type=type(value[0]), nargs="*")
            elif t == tuple:
                group.add_argument(f"--{key}", default=value, type=type(value[0]), nargs=len(value))
            else:
                group.add_argument(f"--{key}", default=value, type=t)


def update_config(cfg_files, cmd_lst=[]):
    # Update from config files
    if isinstance(cfg_files, str):
        cfg_files = [cfg_files]
    for cfg_path in cfg_files:
        cfg.merge_from_file(cfg_path)

    if len(cmd_lst) == 0:
        return

    # Parse the arguments from command line
    internal_parser = argparse.ArgumentParser()
    update_argparser(internal_parser)
    internal_args = internal_parser.parse_args(cmd_lst)

    # Update from command line args
    for name in cfg.keys():
        cfg_subgroup = getattr(cfg, name)
        for key in cfg_subgroup.keys():
            arg_val = getattr(internal_args, key)
            # Check if the default values is updated
            if internal_parser.get_default(key) != arg_val:
                cfg_subgroup[key] = arg_val
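Editor's note, not part of the commit: a small sketch of how update_config above layers command-line overrides on top of a config file. The file path below is hypothetical; the flag names come from the config keys defined in this module.

from src.config import cfg, update_config

# Merge a training config, then override two keys the way a CLI front-end would.
update_config("outputs/my_scene/config.yaml",                # hypothetical path
              cmd_lst=["--sh_degree", "2", "--n_iter", "10000"])
print(cfg.model.sh_degree, cfg.procedure.n_iter)             # -> 2 10000, given defaults above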
src/config_old.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
from yacs.config import CfgNode
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
cfg = CfgNode()
|
| 14 |
+
|
| 15 |
+
cfg.model = CfgNode(dict(
|
| 16 |
+
n_samp_per_vox = 1, # Number of sampled points per visited voxel
|
| 17 |
+
sh_degree = 3, # Use 3 * (k+1)^2 params per voxels for view-dependent colors
|
| 18 |
+
ss = 1.5, # Super-sampling rates for anti-aliasing
|
| 19 |
+
white_background = False, # Assum white background
|
| 20 |
+
black_background = False, # Assum black background
|
| 21 |
+
))
|
| 22 |
+
|
| 23 |
+
cfg.data = CfgNode(dict(
|
| 24 |
+
source_path = "",
|
| 25 |
+
image_dir_name = "images",
|
| 26 |
+
res_downscale = 0.,
|
| 27 |
+
res_width = 0,
|
| 28 |
+
skip_blend_alpha = False,
|
| 29 |
+
data_device = "cpu",
|
| 30 |
+
eval = False,
|
| 31 |
+
test_every = 8,
|
| 32 |
+
alpha_is_white = True,
|
| 33 |
+
))
|
| 34 |
+
|
| 35 |
+
cfg.bounding = CfgNode(dict(
|
| 36 |
+
# Define the main (inside) region bounding box
|
| 37 |
+
# The default use the suggested bounding if given by dataset.
|
| 38 |
+
# Otherwise, it automatically chose from forward or camera_median modes.
|
| 39 |
+
# See src/utils/bounding_utils.py for details.
|
| 40 |
+
|
| 41 |
+
# default | camera_median | camera_max | forward | pcd
|
| 42 |
+
bound_mode = "default",
|
| 43 |
+
bound_scale = 1.0, # Scaling factor of the bound
|
| 44 |
+
forward_dist_scale = 1.0, # For forward mode
|
| 45 |
+
pcd_density_rate = 0.1, # For pcd mode
|
| 46 |
+
|
| 47 |
+
# Number of Octree level outside the main foreground region
|
| 48 |
+
outside_level = 5,
|
| 49 |
+
))
|
| 50 |
+
|
| 51 |
+
cfg.optimizer = CfgNode(dict(
|
| 52 |
+
geo_lr = 0.025,
|
| 53 |
+
sh0_lr = 0.010,
|
| 54 |
+
shs_lr = 0.00025,
|
| 55 |
+
|
| 56 |
+
optim_beta1 = 0.1,
|
| 57 |
+
optim_beta2 = 0.99,
|
| 58 |
+
optim_eps = 1e-15,
|
| 59 |
+
|
| 60 |
+
lr_decay_ckpt = [19000],
|
| 61 |
+
lr_decay_mult = 0.1,
|
| 62 |
+
))
|
| 63 |
+
|
| 64 |
+
cfg.regularizer = CfgNode(dict(
|
| 65 |
+
# Main photometric loss
|
| 66 |
+
lambda_photo = 1.0,
|
| 67 |
+
use_l1 = False,
|
| 68 |
+
use_huber = False,
|
| 69 |
+
huber_thres = 0.03,
|
| 70 |
+
|
| 71 |
+
# SSIM loss
|
| 72 |
+
lambda_ssim = 0.02,
|
| 73 |
+
|
| 74 |
+
# Sparse depth loss
|
| 75 |
+
lambda_sparse_depth = 0.0,
|
| 76 |
+
sparse_depth_until = 10_000,
|
| 77 |
+
|
| 78 |
+
# Mask loss
|
| 79 |
+
lambda_mask = 0.0,
|
| 80 |
+
|
| 81 |
+
# Depth Anything V2 loss
|
| 82 |
+
lambda_depthanythingv2 = 0.0,
|
| 83 |
+
depthanythingv2_from = 3000,
|
| 84 |
+
depthanythingv2_end = 20000,
|
| 85 |
+
depthanythingv2_end_mult = 0.1,
|
| 86 |
+
|
| 87 |
+
# MASt3R metric depth loss
|
| 88 |
+
lambda_mast3r_metric_depth = 0.0,
|
| 89 |
+
mast3r_repo_path = '',
|
| 90 |
+
mast3r_metric_depth_from = 0,
|
| 91 |
+
mast3r_metric_depth_end = 20000,
|
| 92 |
+
mast3r_metric_depth_end_mult = 0.01,
|
| 93 |
+
|
| 94 |
+
# Final transmittance should concentrate to either 0 or 1
|
| 95 |
+
lambda_T_concen = 0.0,
|
| 96 |
+
|
| 97 |
+
# Final transmittance should be 0
|
| 98 |
+
lambda_T_inside = 0.0,
|
| 99 |
+
|
| 100 |
+
# Per-point rgb loss
|
| 101 |
+
lambda_R_concen = 0.01,
|
| 102 |
+
|
| 103 |
+
# Geometric regularization
|
| 104 |
+
lambda_ascending = 0.0,
|
| 105 |
+
ascending_from = 0,
|
| 106 |
+
|
| 107 |
+
# Distortion loss (encourage distribution concentration on ray)
|
| 108 |
+
lambda_dist = 0.1,
|
| 109 |
+
dist_from = 10000,
|
| 110 |
+
|
| 111 |
+
# Consistency loss of rendered normal and derived normal from expected depth
|
| 112 |
+
lambda_normal_dmean = 0.0,
|
| 113 |
+
n_dmean_from = 10_000,
|
| 114 |
+
n_dmean_end = 20_000,
|
| 115 |
+
n_dmean_ks = 3,
|
| 116 |
+
n_dmean_tol_deg = 90.0,
|
| 117 |
+
|
| 118 |
+
# Consistency loss of rendered normal and derived normal from median depth
|
| 119 |
+
lambda_normal_dmed = 0.0,
|
| 120 |
+
n_dmed_from=3000,
|
| 121 |
+
n_dmed_end=20_000,
|
| 122 |
+
|
| 123 |
+
# Total variation loss of density grid
|
| 124 |
+
lambda_tv_density = 1e-10,
|
| 125 |
+
tv_from = 0,
|
| 126 |
+
tv_until = 10000,
|
| 127 |
+
|
| 128 |
+
# Data augmentation
|
| 129 |
+
ss_aug_max = 1.5,
|
| 130 |
+
rand_bg = False,
|
| 131 |
+
))
|
| 132 |
+
|
| 133 |
+
cfg.init = CfgNode(dict(
|
| 134 |
+
# Voxel property initialization
|
| 135 |
+
geo_init = -10.0,
|
| 136 |
+
sh0_init = 0.5,
|
| 137 |
+
shs_init = 0.0,
|
| 138 |
+
|
| 139 |
+
sh_degree_init = 3,
|
| 140 |
+
|
| 141 |
+
# Init main inside region by dense voxels
|
| 142 |
+
init_n_level = 6, # (2^6)^3 voxels
|
| 143 |
+
|
| 144 |
+
# Ratio of the number of voxels for the outside (background) region
|
| 145 |
+
init_out_ratio = 2.0,
|
| 146 |
+
))
|
| 147 |
+
|
| 148 |
+
cfg.procedure = CfgNode(dict(
|
| 149 |
+
# Schedule
|
| 150 |
+
n_iter = 20_000,
|
| 151 |
+
sche_mult = 1.0,
|
| 152 |
+
seed=3721,
|
| 153 |
+
|
| 154 |
+
# Reset sh
|
| 155 |
+
reset_sh_ckpt = [-1],
|
| 156 |
+
|
| 157 |
+
# Adaptive general setup
|
| 158 |
+
adapt_from = 1000,
|
| 159 |
+
adapt_every = 1000,
|
| 160 |
+
|
| 161 |
+
# Adaptive voxel pruning
|
| 162 |
+
prune_until = 18000,
|
| 163 |
+
prune_thres_init = 0.0001,
|
| 164 |
+
prune_thres_final = 0.05,
|
| 165 |
+
|
| 166 |
+
# Adaptive voxel subdivision
|
| 167 |
+
subdivide_until = 15000,
|
| 168 |
+
subdivide_all_until = 0,
|
| 169 |
+
subdivide_samp_thres = 1.0, # A voxel's max sampling rate should be larger than this.
|
| 170 |
+
subdivide_prop = 0.05,
|
| 171 |
+
subdivide_max_num = 10_000_000,
|
| 172 |
+
))
|
| 173 |
+
|
| 174 |
+
cfg.auto_exposure = CfgNode(dict(
|
| 175 |
+
enable = False,
|
| 176 |
+
auto_exposure_upd_ckpt = [5000, 10000, 15000]
|
| 177 |
+
))
|
| 178 |
+
|
| 179 |
+
for i_cfg in cfg.values():
|
| 180 |
+
i_cfg.set_new_allowed(True)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def everytype2bool(v):
|
| 184 |
+
if v.isnumeric():
|
| 185 |
+
return bool(int(v))
|
| 186 |
+
v = v.lower()
|
| 187 |
+
if v in ['n', 'no', 'none', 'false']:
|
| 188 |
+
return False
|
| 189 |
+
return True
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def update_argparser(parser):
|
| 193 |
+
for name in cfg.keys():
|
| 194 |
+
group = parser.add_argument_group(name)
|
| 195 |
+
for key, value in getattr(cfg, name).items():
|
| 196 |
+
t = type(value)
|
| 197 |
+
|
| 198 |
+
if t == bool:
|
| 199 |
+
group.add_argument(f"--{key}", action='store_true' if t else 'store_false')
|
| 200 |
+
elif t == list:
|
| 201 |
+
group.add_argument(f"--{key}", default=value, type=type(value[0]), nargs="*")
|
| 202 |
+
elif t == tuple:
|
| 203 |
+
group.add_argument(f"--{key}", default=value, type=type(value[0]), nargs=len(value))
|
| 204 |
+
else:
|
| 205 |
+
group.add_argument(f"--{key}", default=value, type=t)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def update_config(cfg_files, cmd_lst=[]):
|
| 209 |
+
# Update from config files
|
| 210 |
+
if isinstance(cfg_files, str):
|
| 211 |
+
cfg_files = [cfg_files]
|
| 212 |
+
for cfg_path in cfg_files:
|
| 213 |
+
cfg.merge_from_file(cfg_path)
|
| 214 |
+
|
| 215 |
+
if len(cmd_lst) == 0:
|
| 216 |
+
return
|
| 217 |
+
|
| 218 |
+
# Parse the arguments from command line
|
| 219 |
+
internal_parser = argparse.ArgumentParser()
|
| 220 |
+
update_argparser(internal_parser)
|
| 221 |
+
internal_args = internal_parser.parse_args(cmd_lst)
|
| 222 |
+
|
| 223 |
+
# Update from command line args
|
| 224 |
+
for name in cfg.keys():
|
| 225 |
+
cfg_subgroup = getattr(cfg, name)
|
| 226 |
+
for key in cfg_subgroup.keys():
|
| 227 |
+
arg_val = getattr(internal_args, key)
|
| 228 |
+
# Check if the default value was overridden on the command line
|
| 229 |
+
if internal_parser.get_default(key) != arg_val:
|
| 230 |
+
cfg_subgroup[key] = arg_val
|
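A minimal sketch of how the config module above is typically driven; the YAML path and the override flags here are hypothetical. update_config merges the config file(s) first, then applies only those command-line values that differ from their argparse defaults:

from src.config import cfg, update_config

# Merge a scene config, then override two defaults from the command line.
update_config(
    cfg_files="cfg/my_scene.yaml",                 # hypothetical config file
    cmd_lst=["--source_path", "data/my_scene",     # goes to cfg.data.source_path
             "--sh_degree", "2"],                  # goes to cfg.model.sh_degree
)
print(cfg.model.sh_degree, cfg.data.source_path)   # 2 data/my_scene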
src/dataloader/__pycache__/data_pack.cpython-39.pyc
ADDED
|
Binary file (5.62 kB). View file
|
|
|
src/dataloader/__pycache__/reader_colmap_dataset.cpython-39.pyc
ADDED
|
Binary file (4.04 kB). View file
|
|
|
src/dataloader/__pycache__/reader_nerf_dataset.cpython-39.pyc
ADDED
|
Binary file (3.97 kB). View file
|
|
|
src/dataloader/data_pack.py
ADDED
|
@@ -0,0 +1,232 @@
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
import random
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from src.dataloader.reader_colmap_dataset import read_colmap_dataset
|
| 17 |
+
from src.dataloader.reader_nerf_dataset import read_nerf_dataset
|
| 18 |
+
from src.utils.camera_utils import interpolate_poses
|
| 19 |
+
|
| 20 |
+
from src.cameras import Camera, MiniCam
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DataPack:
|
| 24 |
+
|
| 25 |
+
def __init__(self,
|
| 26 |
+
source_path,
|
| 27 |
+
image_dir_name="images",
|
| 28 |
+
mask_dir_name="masks",
|
| 29 |
+
res_downscale=0.,
|
| 30 |
+
res_width=0,
|
| 31 |
+
skip_blend_alpha=False,
|
| 32 |
+
alpha_is_white=False,
|
| 33 |
+
data_device="cpu",
|
| 34 |
+
use_test=False,
|
| 35 |
+
test_every=8,
|
| 36 |
+
camera_params_only=False):
|
| 37 |
+
|
| 38 |
+
camera_creator = CameraCreator(
|
| 39 |
+
res_downscale=res_downscale,
|
| 40 |
+
res_width=res_width,
|
| 41 |
+
skip_blend_alpha=skip_blend_alpha,
|
| 42 |
+
alpha_is_white=alpha_is_white,
|
| 43 |
+
data_device=data_device,
|
| 44 |
+
camera_params_only=camera_params_only,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
sparse_path = os.path.join(source_path, "sparse")
|
| 48 |
+
colmap_path = os.path.join(source_path, "colmap", "sparse")
|
| 49 |
+
meta_path1 = os.path.join(source_path, "transforms_train.json")
|
| 50 |
+
meta_path2 = os.path.join(source_path, "transforms.json")
|
| 51 |
+
|
| 52 |
+
# Read images concurrently
|
| 53 |
+
s_time = time.perf_counter()
|
| 54 |
+
|
| 55 |
+
if os.path.exists(sparse_path) or os.path.exists(colmap_path):
|
| 56 |
+
print("Read dataset in COLMAP format.")
|
| 57 |
+
dataset = read_colmap_dataset(
|
| 58 |
+
source_path=source_path,
|
| 59 |
+
image_dir_name=image_dir_name,
|
| 60 |
+
mask_dir_name=mask_dir_name,
|
| 61 |
+
use_test=use_test,
|
| 62 |
+
test_every=test_every,
|
| 63 |
+
camera_creator=camera_creator)
|
| 64 |
+
elif os.path.exists(meta_path1) or os.path.exists(meta_path2):
|
| 65 |
+
print("Read dataset in NeRF format.")
|
| 66 |
+
dataset = read_nerf_dataset(
|
| 67 |
+
source_path=source_path,
|
| 68 |
+
use_test=use_test,
|
| 69 |
+
test_every=test_every,
|
| 70 |
+
camera_creator=camera_creator)
|
| 71 |
+
else:
|
| 72 |
+
raise Exception("Unknown scene type!")
|
| 73 |
+
|
| 74 |
+
e_time = time.perf_counter()
|
| 75 |
+
print(f"Read dataset in {e_time - s_time:.3f} seconds.")
|
| 76 |
+
|
| 77 |
+
self._cameras = {
|
| 78 |
+
'train': dataset['train_cam_lst'],
|
| 79 |
+
'test': dataset['test_cam_lst'],
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
##############################
|
| 83 |
+
# Read additional dataset info
|
| 84 |
+
##############################
|
| 85 |
+
# If the dataset suggested a scene bound
|
| 86 |
+
self.suggested_bounding = dataset.get('suggested_bounding', None)
|
| 87 |
+
|
| 88 |
+
# If the dataset provides a transformation to another coordinate frame
|
| 89 |
+
self.to_world_matrix = None
|
| 90 |
+
to_world_path = os.path.join(source_path, 'to_world_matrix.txt')
|
| 91 |
+
if os.path.isfile(to_world_path):
|
| 92 |
+
self.to_world_matrix = np.loadtxt(to_world_path)
|
| 93 |
+
|
| 94 |
+
# If the dataset has a point cloud
|
| 95 |
+
self.point_cloud = dataset.get('point_cloud', None)
|
| 96 |
+
|
| 97 |
+
def get_train_cameras(self):
|
| 98 |
+
return self._cameras['train']
|
| 99 |
+
|
| 100 |
+
def get_test_cameras(self):
|
| 101 |
+
return self._cameras['test']
|
| 102 |
+
|
| 103 |
+
def interpolate_cameras(self, n_frames, starting_id=0, ids=[], step_forward=0):
|
| 104 |
+
cams = self.get_train_cameras()
|
| 105 |
+
if len(ids):
|
| 106 |
+
key_poses = [cams[i].c2w.cpu().numpy() for i in ids]
|
| 107 |
+
else:
|
| 108 |
+
assert starting_id >= 0
|
| 109 |
+
assert starting_id < len(cams)
|
| 110 |
+
cam_pos = torch.stack([cam.position for cam in cams])
|
| 111 |
+
ids = [starting_id]
|
| 112 |
+
for _ in range(3):
|
| 113 |
+
farthest_id = torch.cdist(cam_pos[ids], cam_pos).amin(0).argmax().item()
|
| 114 |
+
ids.append(farthest_id)
|
| 115 |
+
ids[1], ids[2] = ids[2], ids[1]
|
| 116 |
+
key_poses = [cams[i].c2w.cpu().numpy() for i in ids]
|
| 117 |
+
|
| 118 |
+
if step_forward != 0:
|
| 119 |
+
for i in range(len(key_poses)):
|
| 120 |
+
lookat = key_poses[i][:3, 2]
|
| 121 |
+
key_poses[i][:3, 3] += step_forward * lookat
|
| 122 |
+
|
| 123 |
+
interp_poses = interpolate_poses(key_poses, n_frame=n_frames, periodic=True)
|
| 124 |
+
|
| 125 |
+
base_cam = cams[ids[0]]
|
| 126 |
+
interp_cams = [
|
| 127 |
+
MiniCam(
|
| 128 |
+
c2w=pose,
|
| 129 |
+
fovx=base_cam.fovx, fovy=base_cam.fovy,
|
| 130 |
+
width=base_cam.image_width, height=base_cam.image_height)
|
| 131 |
+
for pose in interp_poses]
|
| 132 |
+
return interp_cams
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Create a random sequence of image indices
|
| 136 |
+
def compute_iter_idx(num_data, num_iter):
|
| 137 |
+
tr_iter_idx = []
|
| 138 |
+
while len(tr_iter_idx) < num_iter:
|
| 139 |
+
lst = list(range(num_data))
|
| 140 |
+
random.shuffle(lst)
|
| 141 |
+
tr_iter_idx.extend(lst)
|
| 142 |
+
return tr_iter_idx[:num_iter]
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# Callable that creates Camera instances while parsing the dataset
|
| 146 |
+
class CameraCreator:
|
| 147 |
+
|
| 148 |
+
warned = False
|
| 149 |
+
|
| 150 |
+
def __init__(self,
|
| 151 |
+
res_downscale=0.,
|
| 152 |
+
res_width=0,
|
| 153 |
+
skip_blend_alpha=False,
|
| 154 |
+
alpha_is_white=False,
|
| 155 |
+
data_device="cpu",
|
| 156 |
+
camera_params_only=False):
|
| 157 |
+
|
| 158 |
+
self.res_downscale = res_downscale
|
| 159 |
+
self.res_width = res_width
|
| 160 |
+
self.skip_blend_alpha = skip_blend_alpha
|
| 161 |
+
self.alpha_is_white = alpha_is_white
|
| 162 |
+
self.data_device = data_device
|
| 163 |
+
self.camera_params_only = camera_params_only
|
| 164 |
+
|
| 165 |
+
def __call__(self,
|
| 166 |
+
image,
|
| 167 |
+
w2c,
|
| 168 |
+
fovx,
|
| 169 |
+
fovy,
|
| 170 |
+
cx_p=0.5,
|
| 171 |
+
cy_p=0.5,
|
| 172 |
+
sparse_pt=None,
|
| 173 |
+
image_name="",
|
| 174 |
+
mask=None):
|
| 175 |
+
|
| 176 |
+
# Determine target resolution
|
| 177 |
+
if self.res_downscale > 0:
|
| 178 |
+
downscale = self.res_downscale
|
| 179 |
+
elif self.res_width > 0:
|
| 180 |
+
downscale = image.size[0] / self.res_width
|
| 181 |
+
else:
|
| 182 |
+
downscale = 1
|
| 183 |
+
|
| 184 |
+
total_pix = image.size[0] * image.size[1]
|
| 185 |
+
if total_pix > 1200 ** 2 and not self.warned:
|
| 186 |
+
self.warned = True
|
| 187 |
+
suggest_ds = (total_pix ** 0.5) / 1200
|
| 188 |
+
print(f"###################################################################")
|
| 189 |
+
print(f"Image too large. Suggest to use `--res_downscale {suggest_ds:.1f}`.")
|
| 190 |
+
print(f"###################################################################")
|
| 191 |
+
|
| 192 |
+
# Load camera parameters only
|
| 193 |
+
if self.camera_params_only:
|
| 194 |
+
return MiniCam(
|
| 195 |
+
c2w=np.linalg.inv(w2c),
|
| 196 |
+
fovx=fovx, fovy=fovy,
|
| 197 |
+
cx_p=cx_p, cy_p=cy_p,
|
| 198 |
+
width=round(image.size[0] / downscale),
|
| 199 |
+
height=round(image.size[1] / downscale),
|
| 200 |
+
image_name=image_name)
|
| 201 |
+
|
| 202 |
+
# Resize image if needed
|
| 203 |
+
if downscale != 1:
|
| 204 |
+
size = (round(image.size[0] / downscale), round(image.size[1] / downscale))
|
| 205 |
+
image = image.resize(size)
|
| 206 |
+
|
| 207 |
+
# Convert image to tensor
|
| 208 |
+
tensor = torch.tensor(np.array(image), dtype=torch.float32).moveaxis(-1, 0) / 255.0
|
| 209 |
+
if tensor.shape[0] == 4:
|
| 210 |
+
# Blend alpha channel
|
| 211 |
+
tensor, mask = tensor.split([3, 1], dim=0)
|
| 212 |
+
if not self.skip_blend_alpha:
|
| 213 |
+
tensor = tensor * mask + int(self.alpha_is_white) * (1 - mask)
|
| 214 |
+
|
| 215 |
+
# Convert mask to tensor if there is one
|
| 216 |
+
if mask is not None:
|
| 217 |
+
size = tensor.shape[-2:][::-1]
|
| 218 |
+
if mask.size != size:
|
| 219 |
+
mask = mask.resize(size)
|
| 220 |
+
mask = torch.tensor(np.array(mask), dtype=torch.float32) / 255.0
|
| 221 |
+
if len(mask.shape) == 3:
|
| 222 |
+
mask = mask.mean(-1)
|
| 223 |
+
mask = mask[None]
|
| 224 |
+
|
| 225 |
+
return Camera(
|
| 226 |
+
w2c=w2c,
|
| 227 |
+
fovx=fovx, fovy=fovy,
|
| 228 |
+
cx_p=cx_p, cy_p=cy_p,
|
| 229 |
+
image=tensor,
|
| 230 |
+
mask=mask,
|
| 231 |
+
sparse_pt=sparse_pt,
|
| 232 |
+
image_name=image_name)
|
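A minimal sketch of consuming DataPack and compute_iter_idx in a training loop; the dataset path and the iteration count are illustrative:

from src.dataloader.data_pack import DataPack, compute_iter_idx

data_pack = DataPack(
    source_path="data/my_scene",   # hypothetical dataset root
    image_dir_name="images",
    res_downscale=2.0,
    use_test=True,
    test_every=8,
)
train_cams = data_pack.get_train_cameras()

# A shuffled index stream covering all iterations; the camera list is
# re-shuffled every time it has been fully visited once.
tr_iter_idx = compute_iter_idx(len(train_cams), 20_000)
for it, cam_idx in enumerate(tr_iter_idx):
    cam = train_cams[cam_idx]
    # ... render `cam` and accumulate losses ...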
src/dataloader/reader_colmap_dataset.py
ADDED
|
@@ -0,0 +1,162 @@
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import natsort
|
| 12 |
+
import pycolmap
|
| 13 |
+
import numpy as np
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import concurrent.futures
|
| 17 |
+
|
| 18 |
+
from src.utils.colmap_utils import parse_colmap_pts
|
| 19 |
+
from src.utils.camera_utils import focal2fov
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def read_colmap_dataset(source_path, image_dir_name, mask_dir_name, use_test, test_every, camera_creator):
|
| 23 |
+
"""
|
| 24 |
+
Read a COLMAP dataset and return cameras, intrinsics, extrinsics, and optional masks.
|
| 25 |
+
|
| 26 |
+
Fixes:
|
| 27 |
+
- Safe image/mask opening using `with Image.open(...)` (no file leaks).
|
| 28 |
+
- Compatible with both old/new pycolmap APIs.
|
| 29 |
+
- Returns PIL.Image objects (for backward compatibility with DataPack).
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
source_path = Path(source_path)
|
| 33 |
+
|
| 34 |
+
# ---------------- Parse COLMAP reconstruction ----------------
|
| 35 |
+
sparse_path = source_path / "sparse" / "0"
|
| 36 |
+
if not sparse_path.exists():
|
| 37 |
+
sparse_path = source_path / "colmap" / "sparse" / "0"
|
| 38 |
+
if not sparse_path.exists():
|
| 39 |
+
raise Exception("Cannot find COLMAP reconstruction (expected sparse/0 or colmap/sparse/0).")
|
| 40 |
+
|
| 41 |
+
sfm = pycolmap.Reconstruction(sparse_path)
|
| 42 |
+
point_cloud = parse_colmap_pts(sfm)
|
| 43 |
+
correspondent = point_cloud.corr
|
| 44 |
+
|
| 45 |
+
# ---------------- Sort key by filename ----------------
|
| 46 |
+
keys = natsort.natsorted(sfm.images.keys(), key=lambda k: sfm.images[k].name)
|
| 47 |
+
|
| 48 |
+
# ---------------- Load all frames ----------------
|
| 49 |
+
todo_lst = []
|
| 50 |
+
for key in keys:
|
| 51 |
+
frame = sfm.images[key]
|
| 52 |
+
|
| 53 |
+
# ---- Load RGB image safely ----
|
| 54 |
+
image_path = source_path / image_dir_name / frame.name
|
| 55 |
+
if not image_path.exists():
|
| 56 |
+
image_path = image_path.with_suffix(".png")
|
| 57 |
+
if not image_path.exists():
|
| 58 |
+
image_path = image_path.with_suffix(".jpg")
|
| 59 |
+
if not image_path.exists():
|
| 60 |
+
image_path = image_path.with_suffix(".JPG")
|
| 61 |
+
if not image_path.exists():
|
| 62 |
+
raise Exception(f"File not found: {str(image_path)}")
|
| 63 |
+
|
| 64 |
+
# safely open and immediately copy to new PIL object (closed after copy)
|
| 65 |
+
with Image.open(image_path) as img:
|
| 66 |
+
image = img.copy() # copy keeps data in memory, closes file handle
|
| 67 |
+
|
| 68 |
+
# ---- Load intrinsics ----
|
| 69 |
+
if frame.camera.model.name == "SIMPLE_PINHOLE":
|
| 70 |
+
focal_x, cx, cy = frame.camera.params
|
| 71 |
+
fovx = focal2fov(focal_x, frame.camera.width)
|
| 72 |
+
fovy = focal2fov(focal_x, frame.camera.height)
|
| 73 |
+
cx_p = cx / frame.camera.width
|
| 74 |
+
cy_p = cy / frame.camera.height
|
| 75 |
+
elif frame.camera.model.name == "PINHOLE":
|
| 76 |
+
focal_x, focal_y, cx, cy = frame.camera.params
|
| 77 |
+
fovx = focal2fov(focal_x, frame.camera.width)
|
| 78 |
+
fovy = focal2fov(focal_y, frame.camera.height)
|
| 79 |
+
cx_p = cx / frame.camera.width
|
| 80 |
+
cy_p = cy / frame.camera.height
|
| 81 |
+
else:
|
| 82 |
+
raise ValueError(
|
| 83 |
+
f"Unsupported COLMAP camera model: {frame.camera.model.name}. "
|
| 84 |
+
"Only undistorted SIMPLE_PINHOLE and PINHOLE are supported."
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# ---- Load extrinsics (support both pycolmap APIs) ----
|
| 88 |
+
w2c = np.eye(4, dtype=np.float32)
|
| 89 |
+
cam_from_world = getattr(frame, "cam_from_world", None)
|
| 90 |
+
if cam_from_world is not None:
|
| 91 |
+
if callable(cam_from_world):
|
| 92 |
+
# Old pycolmap API
|
| 93 |
+
w2c[:3] = cam_from_world().matrix()
|
| 94 |
+
else:
|
| 95 |
+
# New pycolmap API (Rigid3d object)
|
| 96 |
+
w2c[:3] = cam_from_world.matrix()
|
| 97 |
+
else:
|
| 98 |
+
raise RuntimeError("Cannot find cam_from_world attribute in COLMAP frame.")
|
| 99 |
+
|
| 100 |
+
# ---- Sparse point correspondence ----
|
| 101 |
+
sparse_pt = point_cloud.points[correspondent[frame.name]]
|
| 102 |
+
|
| 103 |
+
# ---- Optional mask ----
|
| 104 |
+
mask = None
|
| 105 |
+
if mask_dir_name is not None:
|
| 106 |
+
mask_path = (source_path / mask_dir_name / frame.name).with_suffix(".png")
|
| 107 |
+
if mask_path.exists():
|
| 108 |
+
with Image.open(mask_path) as m:
|
| 109 |
+
mask = m.copy() # keep PIL.Image for DataPack
|
| 110 |
+
|
| 111 |
+
# ---- Store frame data ----
|
| 112 |
+
todo_lst.append(dict(
|
| 113 |
+
image=image,
|
| 114 |
+
w2c=w2c,
|
| 115 |
+
fovx=fovx,
|
| 116 |
+
fovy=fovy,
|
| 117 |
+
cx_p=cx_p,
|
| 118 |
+
cy_p=cy_p,
|
| 119 |
+
sparse_pt=sparse_pt,
|
| 120 |
+
image_name=image_path.name,
|
| 121 |
+
mask=mask,
|
| 122 |
+
))
|
| 123 |
+
|
| 124 |
+
# ---------------- Create cameras concurrently ----------------
|
| 125 |
+
import torch
|
| 126 |
+
torch.inverse(torch.eye(3, device="cuda")) # fix PyTorch lazy init bug
|
| 127 |
+
|
| 128 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 129 |
+
futures = [executor.submit(camera_creator, **todo) for todo in todo_lst]
|
| 130 |
+
cam_lst = [f.result() for f in futures]
|
| 131 |
+
|
| 132 |
+
# ---------------- Split train/test ----------------
|
| 133 |
+
if use_test:
|
| 134 |
+
train_cam_lst = [cam for i, cam in enumerate(cam_lst) if i % test_every != 0]
|
| 135 |
+
test_cam_lst = [cam for i, cam in enumerate(cam_lst) if i % test_every == 0]
|
| 136 |
+
else:
|
| 137 |
+
train_cam_lst = cam_lst
|
| 138 |
+
test_cam_lst = []
|
| 139 |
+
|
| 140 |
+
# ---------------- Optional bounding box ----------------
|
| 141 |
+
nerf_normalization_path = source_path / "nerf_normalization.json"
|
| 142 |
+
if nerf_normalization_path.is_file():
|
| 143 |
+
with open(nerf_normalization_path) as f:
|
| 144 |
+
nerf_norm = json.load(f)
|
| 145 |
+
suggested_center = np.array(nerf_norm["center"], dtype=np.float32)
|
| 146 |
+
suggested_radius = np.array(nerf_norm["radius"], dtype=np.float32)
|
| 147 |
+
suggested_bounding = np.stack([
|
| 148 |
+
suggested_center - suggested_radius,
|
| 149 |
+
suggested_center + suggested_radius,
|
| 150 |
+
])
|
| 151 |
+
else:
|
| 152 |
+
suggested_bounding = None
|
| 153 |
+
|
| 154 |
+
# ---------------- Return dataset ----------------
|
| 155 |
+
dataset = {
|
| 156 |
+
"train_cam_lst": train_cam_lst,
|
| 157 |
+
"test_cam_lst": test_cam_lst,
|
| 158 |
+
"suggested_bounding": suggested_bounding,
|
| 159 |
+
"point_cloud": point_cloud,
|
| 160 |
+
}
|
| 161 |
+
return dataset
|
| 162 |
+
|
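The intrinsics above are stored as fields of view derived from COLMAP focal lengths. A quick self-contained check, assuming focal2fov in src/utils/camera_utils follows the usual pinhole relation fov = 2 * atan(dim / (2 * focal)); the resolution and focal length below are illustrative:

import math

def focal2fov_ref(focal, dim):
    # Standard pinhole relation between focal length (pixels) and field of view.
    return 2.0 * math.atan(dim / (2.0 * focal))

width, height, focal_x = 1920, 1080, 1600.0       # illustrative PINHOLE values
fovx = focal2fov_ref(focal_x, width)
fovy = focal2fov_ref(focal_x, height)
print(round(math.degrees(fovx), 1), round(math.degrees(fovy), 1))  # 61.9 37.3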
src/dataloader/reader_colmap_dataset_or.py
ADDED
|
@@ -0,0 +1,148 @@
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import natsort
|
| 12 |
+
import pycolmap
|
| 13 |
+
import numpy as np
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import concurrent.futures
|
| 17 |
+
|
| 18 |
+
from src.utils.colmap_utils import parse_colmap_pts
|
| 19 |
+
from src.utils.camera_utils import focal2fov
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def read_colmap_dataset(source_path, image_dir_name, mask_dir_name, use_test, test_every, camera_creator):
|
| 23 |
+
|
| 24 |
+
source_path = Path(source_path)
|
| 25 |
+
|
| 26 |
+
# Parse colmap meta data
|
| 27 |
+
sparse_path = source_path / "sparse" / "0"
|
| 28 |
+
if not sparse_path.exists():
|
| 29 |
+
sparse_path = source_path / "colmap" / "sparse" / "0"
|
| 30 |
+
if not sparse_path.exists():
|
| 31 |
+
raise Exception("Can not find COLMAP reconstruction.")
|
| 32 |
+
|
| 33 |
+
sfm = pycolmap.Reconstruction(sparse_path)
|
| 34 |
+
point_cloud = parse_colmap_pts(sfm)
|
| 35 |
+
correspondent = point_cloud.corr
|
| 36 |
+
|
| 37 |
+
# Sort key by filename
|
| 38 |
+
keys = natsort.natsorted(
|
| 39 |
+
sfm.images.keys(),
|
| 40 |
+
key = lambda k : sfm.images[k].name)
|
| 41 |
+
|
| 42 |
+
# Load all images and cameras
|
| 43 |
+
todo_lst = []
|
| 44 |
+
for key in keys:
|
| 45 |
+
|
| 46 |
+
frame = sfm.images[key]
|
| 47 |
+
|
| 48 |
+
# Load image
|
| 49 |
+
image_path = source_path / image_dir_name / frame.name
|
| 50 |
+
if not image_path.exists():
|
| 51 |
+
image_path = image_path.with_suffix('.png')
|
| 52 |
+
if not image_path.exists():
|
| 53 |
+
image_path = image_path.with_suffix('.jpg')
|
| 54 |
+
if not image_path.exists():
|
| 55 |
+
image_path = image_path.with_suffix('.JPG')
|
| 56 |
+
if not image_path.exists():
|
| 57 |
+
raise Exception(f"File not found: {str(image_path)}")
|
| 58 |
+
image = Image.open(image_path)
|
| 59 |
+
|
| 60 |
+
# Load camera intrinsic
|
| 61 |
+
if frame.camera.model.name == "SIMPLE_PINHOLE":
|
| 62 |
+
focal_x, cx, cy = frame.camera.params
|
| 63 |
+
fovx = focal2fov(focal_x, frame.camera.width)
|
| 64 |
+
fovy = focal2fov(focal_x, frame.camera.height)
|
| 65 |
+
cx_p = cx / frame.camera.width
|
| 66 |
+
cy_p = cy / frame.camera.height
|
| 67 |
+
elif frame.camera.model.name == "PINHOLE":
|
| 68 |
+
focal_x, focal_y, cx, cy = frame.camera.params
|
| 69 |
+
fovx = focal2fov(focal_x, frame.camera.width)
|
| 70 |
+
fovy = focal2fov(focal_y, frame.camera.height)
|
| 71 |
+
cx_p = cx / frame.camera.width
|
| 72 |
+
cy_p = cy / frame.camera.height
|
| 73 |
+
else:
|
| 74 |
+
assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
|
| 75 |
+
|
| 76 |
+
# Load camera extrinsic
|
| 77 |
+
w2c = np.eye(4, dtype=np.float32)
|
| 78 |
+
try:
|
| 79 |
+
w2c[:3] = frame.cam_from_world().matrix()
|
| 80 |
+
except:
|
| 81 |
+
# Older version of pycolmap
|
| 82 |
+
w2c[:3] = frame.cam_from_world.matrix()
|
| 83 |
+
|
| 84 |
+
# Load sparse point
|
| 85 |
+
sparse_pt = point_cloud.points[correspondent[frame.name]]
|
| 86 |
+
|
| 87 |
+
# Load mask if there is one
|
| 88 |
+
mask_path = (source_path / mask_dir_name / frame.name).with_suffix('.png')
|
| 89 |
+
if mask_path.exists():
|
| 90 |
+
mask = Image.open(mask_path)
|
| 91 |
+
else:
|
| 92 |
+
mask = None
|
| 93 |
+
|
| 94 |
+
todo_lst.append(dict(
|
| 95 |
+
image=image,
|
| 96 |
+
w2c=w2c,
|
| 97 |
+
fovx=fovx,
|
| 98 |
+
fovy=fovy,
|
| 99 |
+
cx_p=cx_p,
|
| 100 |
+
cy_p=cy_p,
|
| 101 |
+
sparse_pt=sparse_pt,
|
| 102 |
+
image_name=image_path.name,
|
| 103 |
+
mask=mask,
|
| 104 |
+
))
|
| 105 |
+
|
| 106 |
+
# Load all cameras concurrently
|
| 107 |
+
import torch
|
| 108 |
+
torch.inverse(torch.eye(3, device="cuda")) # Fix module lazy loading bug:
|
| 109 |
+
# https://github.com/pytorch/pytorch/issues/90613
|
| 110 |
+
|
| 111 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 112 |
+
futures = [executor.submit(camera_creator, **todo) for todo in todo_lst]
|
| 113 |
+
cam_lst = [f.result() for f in futures]
|
| 114 |
+
|
| 115 |
+
# Split train/test
|
| 116 |
+
if use_test:
|
| 117 |
+
train_cam_lst = [
|
| 118 |
+
cam for i, cam in enumerate(cam_lst)
|
| 119 |
+
if i % test_every != 0]
|
| 120 |
+
test_cam_lst = [
|
| 121 |
+
cam for i, cam in enumerate(cam_lst)
|
| 122 |
+
if i % test_every == 0]
|
| 123 |
+
else:
|
| 124 |
+
train_cam_lst = cam_lst
|
| 125 |
+
test_cam_lst = []
|
| 126 |
+
|
| 127 |
+
# Parse main scene bound if there is
|
| 128 |
+
nerf_normalization_path = os.path.join(source_path, "nerf_normalization.json")
|
| 129 |
+
if os.path.isfile(nerf_normalization_path):
|
| 130 |
+
with open(nerf_normalization_path) as f:
|
| 131 |
+
nerf_normalization = json.load(f)
|
| 132 |
+
suggested_center = np.array(nerf_normalization["center"], dtype=np.float32)
|
| 133 |
+
suggested_radius = np.array(nerf_normalization["radius"], dtype=np.float32)
|
| 134 |
+
suggested_bounding = np.stack([
|
| 135 |
+
suggested_center - suggested_radius,
|
| 136 |
+
suggested_center + suggested_radius,
|
| 137 |
+
])
|
| 138 |
+
else:
|
| 139 |
+
suggested_bounding = None
|
| 140 |
+
|
| 141 |
+
# Pack dataset
|
| 142 |
+
dataset = {
|
| 143 |
+
'train_cam_lst': train_cam_lst,
|
| 144 |
+
'test_cam_lst': test_cam_lst,
|
| 145 |
+
'suggested_bounding': suggested_bounding,
|
| 146 |
+
'point_cloud': point_cloud,
|
| 147 |
+
}
|
| 148 |
+
return dataset
|
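Both COLMAP readers share the same every-Nth holdout convention when use_test is set: with test_every = 8, images 0, 8, 16, ... become test views and the rest stay in the training split. A tiny illustration with stand-in values:

cam_lst = list(range(20))    # stand-ins for Camera objects
test_every = 8
train = [c for i, c in enumerate(cam_lst) if i % test_every != 0]
test = [c for i, c in enumerate(cam_lst) if i % test_every == 0]
print(test)                   # [0, 8, 16]
print(len(train), len(test))  # 17 3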
src/dataloader/reader_nerf_dataset.py
ADDED
|
@@ -0,0 +1,180 @@
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import pycolmap
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import concurrent.futures
|
| 16 |
+
|
| 17 |
+
from src.utils.colmap_utils import parse_colmap_pts
|
| 18 |
+
from src.utils.camera_utils import fov2focal, focal2fov
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def read_nerf_dataset(source_path, test_every, use_test, camera_creator):
|
| 22 |
+
|
| 23 |
+
source_path = Path(source_path)
|
| 24 |
+
|
| 25 |
+
# Load training cameras
|
| 26 |
+
if (source_path / "transforms_train.json").exists():
|
| 27 |
+
train_cam_lst, point_cloud = read_cameras_from_json(
|
| 28 |
+
source_path=source_path,
|
| 29 |
+
meta_fname="transforms_train.json",
|
| 30 |
+
camera_creator=camera_creator)
|
| 31 |
+
else:
|
| 32 |
+
train_cam_lst, point_cloud = read_cameras_from_json(
|
| 33 |
+
source_path=source_path,
|
| 34 |
+
meta_fname="transforms.json",
|
| 35 |
+
camera_creator=camera_creator)
|
| 36 |
+
|
| 37 |
+
# Load testing cameras
|
| 38 |
+
if (source_path / "transforms_test.json").exists():
|
| 39 |
+
test_cam_lst, _ = read_cameras_from_json(
|
| 40 |
+
source_path=source_path,
|
| 41 |
+
meta_fname="transforms_test.json",
|
| 42 |
+
camera_creator=camera_creator)
|
| 43 |
+
elif use_test:
|
| 44 |
+
test_cam_lst = [
|
| 45 |
+
cam for i, cam in enumerate(train_cam_lst)
|
| 46 |
+
if i % test_every == 0]
|
| 47 |
+
train_cam_lst = [
|
| 48 |
+
cam for i, cam in enumerate(train_cam_lst)
|
| 49 |
+
if i % test_every != 0]
|
| 50 |
+
else:
|
| 51 |
+
test_cam_lst = []
|
| 52 |
+
|
| 53 |
+
# Parse the main scene bound if there is one
|
| 54 |
+
nerf_normalization_path = os.path.join(source_path, "nerf_normalization.json")
|
| 55 |
+
if os.path.isfile(nerf_normalization_path):
|
| 56 |
+
with open(nerf_normalization_path) as f:
|
| 57 |
+
nerf_normalization = json.load(f)
|
| 58 |
+
suggested_center = np.array(nerf_normalization["center"], dtype=np.float32)
|
| 59 |
+
suggested_radius = np.array(nerf_normalization["radius"], dtype=np.float32)
|
| 60 |
+
suggested_bounding = np.stack([
|
| 61 |
+
suggested_center - suggested_radius,
|
| 62 |
+
suggested_center + suggested_radius,
|
| 63 |
+
])
|
| 64 |
+
else:
|
| 65 |
+
# Assume synthetic blender scene bound
|
| 66 |
+
suggested_bounding = np.array([
|
| 67 |
+
[-1.5, -1.5, -1.5],
|
| 68 |
+
[1.5, 1.5, 1.5],
|
| 69 |
+
], dtype=np.float32)
|
| 70 |
+
|
| 71 |
+
# Pack dataset
|
| 72 |
+
dataset = {
|
| 73 |
+
'train_cam_lst': train_cam_lst,
|
| 74 |
+
'test_cam_lst': test_cam_lst,
|
| 75 |
+
'suggested_bounding': suggested_bounding,
|
| 76 |
+
'point_cloud': point_cloud,
|
| 77 |
+
}
|
| 78 |
+
return dataset
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def read_cameras_from_json(source_path, meta_fname, camera_creator):
|
| 82 |
+
|
| 83 |
+
with open(source_path / meta_fname) as f:
|
| 84 |
+
meta = json.load(f)
|
| 85 |
+
|
| 86 |
+
# Load COLMAP points if there is
|
| 87 |
+
if "colmap" in meta:
|
| 88 |
+
sfm = pycolmap.Reconstruction(source_path / meta["colmap"]["path"])
|
| 89 |
+
if "transform" in meta["colmap"]:
|
| 90 |
+
transform = np.array(meta["colmap"]["transform"])
|
| 91 |
+
else:
|
| 92 |
+
transform = None
|
| 93 |
+
point_cloud = parse_colmap_pts(sfm, transform)
|
| 94 |
+
correspondent = point_cloud.corr
|
| 95 |
+
else:
|
| 96 |
+
point_cloud = None
|
| 97 |
+
correspondent = None
|
| 98 |
+
|
| 99 |
+
# Load global setup
|
| 100 |
+
global_fovx = meta.get("camera_angle_x", 0)
|
| 101 |
+
global_fovy = meta.get("camera_angle_y", 0)
|
| 102 |
+
global_cx_p = parse_principle_point(meta, is_cx=True)
|
| 103 |
+
global_cy_p = parse_principle_point(meta, is_cx=False)
|
| 104 |
+
|
| 105 |
+
# Load all images and cameras
|
| 106 |
+
todo_lst = []
|
| 107 |
+
for frame in meta["frames"]:
|
| 108 |
+
|
| 109 |
+
# Guess the rgb image path and load image
|
| 110 |
+
path_candidates = [
|
| 111 |
+
source_path / frame["file_path"],
|
| 112 |
+
source_path / (frame["file_path"] + '.png'),
|
| 113 |
+
source_path / (frame["file_path"] + '.jpg'),
|
| 114 |
+
source_path / (frame["file_path"] + '.JPG'),
|
| 115 |
+
]
|
| 116 |
+
for image_path in path_candidates:
|
| 117 |
+
if image_path.exists():
|
| 118 |
+
break
|
| 119 |
+
|
| 120 |
+
if frame.get('heldout', False):
|
| 121 |
+
image = Image.new('RGB', (frame['w'], frame['h']))
|
| 122 |
+
elif image_path.exists():
|
| 123 |
+
image = Image.open(image_path)
|
| 124 |
+
else:
|
| 125 |
+
raise Exception(f"File not found: {str(image_path)}")
|
| 126 |
+
|
| 127 |
+
# Load camera intrinsic
|
| 128 |
+
fovx = frame.get('camera_angle_x', global_fovx)
|
| 129 |
+
cx_p = frame.get('cx_p', global_cx_p)
|
| 130 |
+
cy_p = frame.get('cy_p', global_cy_p)
|
| 131 |
+
|
| 132 |
+
if 'camera_angle_y' in frame:
|
| 133 |
+
fovy = frame['camera_angle_y']
|
| 134 |
+
elif global_fovy > 0:
|
| 135 |
+
fovy = global_fovy
|
| 136 |
+
else:
|
| 137 |
+
fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
|
| 138 |
+
|
| 139 |
+
# Load camera pose
|
| 140 |
+
c2w = np.array(frame["transform_matrix"])
|
| 141 |
+
c2w[:3, 1:3] *= -1 # from opengl y-up-z-back to colmap y-down-z-forward
|
| 142 |
+
w2c = np.linalg.inv(c2w).astype(np.float32)
|
| 143 |
+
|
| 144 |
+
# Load sparse point
|
| 145 |
+
if point_cloud is not None:
|
| 146 |
+
sparse_pt = point_cloud.points[correspondent[image_path.name]]
|
| 147 |
+
else:
|
| 148 |
+
sparse_pt = None
|
| 149 |
+
|
| 150 |
+
todo_lst.append(dict(
|
| 151 |
+
image=image,
|
| 152 |
+
w2c=w2c,
|
| 153 |
+
fovx=fovx,
|
| 154 |
+
fovy=fovy,
|
| 155 |
+
cx_p=cx_p,
|
| 156 |
+
cy_p=cy_p,
|
| 157 |
+
sparse_pt=sparse_pt,
|
| 158 |
+
image_name=image_path.name,
|
| 159 |
+
))
|
| 160 |
+
|
| 161 |
+
# Load all cameras concurrently
|
| 162 |
+
import torch
|
| 163 |
+
torch.inverse(torch.eye(3, device="cuda")) # Fix module lazy loading bug:
|
| 164 |
+
# https://github.com/pytorch/pytorch/issues/90613
|
| 165 |
+
|
| 166 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 167 |
+
futures = [executor.submit(camera_creator, **todo) for todo in todo_lst]
|
| 168 |
+
cam_lst = [f.result() for f in futures]
|
| 169 |
+
|
| 170 |
+
return cam_lst, point_cloud
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def parse_principle_point(info, is_cx):
|
| 174 |
+
key = "cx" if is_cx else "cy"
|
| 175 |
+
key_res = "w" if is_cx else "h"
|
| 176 |
+
if f"{key}_p" in info:
|
| 177 |
+
return info[f"{key}_p"]
|
| 178 |
+
if key in info and key_res in info:
|
| 179 |
+
return info[key] / info[key_res]
|
| 180 |
+
return None
|
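The NeRF reader flips the camera axes before inverting: transforms.json stores an OpenGL-style camera-to-world matrix (x right, y up, z backward), while the rest of the pipeline expects COLMAP-style axes (x right, y down, z forward). A small NumPy sketch of the same conversion on a hypothetical pose:

import numpy as np

c2w = np.array([
    [1.0, 0.0, 0.0, 0.5],
    [0.0, 1.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 2.0],
    [0.0, 0.0, 0.0, 1.0],
])
c2w[:3, 1:3] *= -1                        # negate the y and z camera axes
w2c = np.linalg.inv(c2w).astype(np.float32)
assert np.allclose(w2c @ c2w, np.eye(4), atol=1e-5)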
src/dataloader/reader_nerf_dataset_copy.py
ADDED
|
@@ -0,0 +1,170 @@
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import pycolmap
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
from src.utils.colmap_utils import parse_colmap_pts
|
| 17 |
+
from src.utils.camera_utils import fov2focal, focal2fov
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def read_nerf_dataset(source_path, test_every, use_test, camera_creator):
|
| 21 |
+
|
| 22 |
+
source_path = Path(source_path)
|
| 23 |
+
|
| 24 |
+
# Load training cameras
|
| 25 |
+
if (source_path / "transforms_train.json").exists():
|
| 26 |
+
train_cam_lst, point_cloud = read_cameras_from_json(
|
| 27 |
+
source_path=source_path,
|
| 28 |
+
meta_fname="transforms_train.json",
|
| 29 |
+
camera_creator=camera_creator)
|
| 30 |
+
else:
|
| 31 |
+
train_cam_lst, point_cloud = read_cameras_from_json(
|
| 32 |
+
source_path=source_path,
|
| 33 |
+
meta_fname="transforms.json",
|
| 34 |
+
camera_creator=camera_creator)
|
| 35 |
+
|
| 36 |
+
# Load testing cameras
|
| 37 |
+
if (source_path / "transforms_test.json").exists():
|
| 38 |
+
test_cam_lst, _ = read_cameras_from_json(
|
| 39 |
+
source_path=source_path,
|
| 40 |
+
meta_fname="transforms_test.json",
|
| 41 |
+
camera_creator=camera_creator)
|
| 42 |
+
elif use_test:
|
| 43 |
+
test_cam_lst = [
|
| 44 |
+
cam for i, cam in enumerate(train_cam_lst)
|
| 45 |
+
if i % test_every == 0]
|
| 46 |
+
train_cam_lst = [
|
| 47 |
+
cam for i, cam in enumerate(train_cam_lst)
|
| 48 |
+
if i % test_every != 0]
|
| 49 |
+
else:
|
| 50 |
+
test_cam_lst = []
|
| 51 |
+
|
| 52 |
+
# Parse main scene bound if there is
|
| 53 |
+
nerf_normalization_path = os.path.join(source_path, "nerf_normalization.json")
|
| 54 |
+
if os.path.isfile(nerf_normalization_path):
|
| 55 |
+
with open(nerf_normalization_path) as f:
|
| 56 |
+
nerf_normalization = json.load(f)
|
| 57 |
+
suggested_center = np.array(nerf_normalization["center"], dtype=np.float32)
|
| 58 |
+
suggested_radius = np.array(nerf_normalization["radius"], dtype=np.float32)
|
| 59 |
+
suggested_bounding = np.stack([
|
| 60 |
+
suggested_center - suggested_radius,
|
| 61 |
+
suggested_center + suggested_radius,
|
| 62 |
+
])
|
| 63 |
+
else:
|
| 64 |
+
# Assume synthetic blender scene bound
|
| 65 |
+
suggested_bounding = np.array([
|
| 66 |
+
[-1.5, -1.5, -1.5],
|
| 67 |
+
[1.5, 1.5, 1.5],
|
| 68 |
+
], dtype=np.float32)
|
| 69 |
+
|
| 70 |
+
# Pack dataset
|
| 71 |
+
dataset = {
|
| 72 |
+
'train_cam_lst': train_cam_lst,
|
| 73 |
+
'test_cam_lst': test_cam_lst,
|
| 74 |
+
'suggested_bounding': suggested_bounding,
|
| 75 |
+
'point_cloud': point_cloud,
|
| 76 |
+
}
|
| 77 |
+
return dataset
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def read_cameras_from_json(source_path, meta_fname, camera_creator):
|
| 81 |
+
|
| 82 |
+
with open(source_path / meta_fname) as f:
|
| 83 |
+
meta = json.load(f)
|
| 84 |
+
|
| 85 |
+
# Load COLMAP points if there is
|
| 86 |
+
if "colmap" in meta:
|
| 87 |
+
sfm = pycolmap.Reconstruction(source_path / meta["colmap"]["path"])
|
| 88 |
+
if "transform" in meta["colmap"]:
|
| 89 |
+
transform = np.array(meta["colmap"]["transform"])
|
| 90 |
+
else:
|
| 91 |
+
transform = None
|
| 92 |
+
point_cloud = parse_colmap_pts(sfm, transform)
|
| 93 |
+
correspondent = point_cloud.corr
|
| 94 |
+
else:
|
| 95 |
+
point_cloud = None
|
| 96 |
+
correspondent = None
|
| 97 |
+
|
| 98 |
+
# Load global setup
|
| 99 |
+
global_fovx = meta.get("camera_angle_x", 0)
|
| 100 |
+
global_fovy = meta.get("camera_angle_y", 0)
|
| 101 |
+
global_cx_p = parse_principle_point(meta, is_cx=True)
|
| 102 |
+
global_cy_p = parse_principle_point(meta, is_cx=False)
|
| 103 |
+
|
| 104 |
+
# Load all images and cameras
|
| 105 |
+
cam_lst = []
|
| 106 |
+
for frame in meta["frames"]:
|
| 107 |
+
|
| 108 |
+
# Guess the rgb image path and load image
|
| 109 |
+
path_candidates = [
|
| 110 |
+
source_path / frame["file_path"],
|
| 111 |
+
source_path / (frame["file_path"] + '.png'),
|
| 112 |
+
source_path / (frame["file_path"] + '.jpg'),
|
| 113 |
+
source_path / (frame["file_path"] + '.JPG'),
|
| 114 |
+
]
|
| 115 |
+
for image_path in path_candidates:
|
| 116 |
+
if image_path.exists():
|
| 117 |
+
break
|
| 118 |
+
|
| 119 |
+
if frame.get('heldout', False):
|
| 120 |
+
image = Image.new('RGB', (frame['w'], frame['h']))
|
| 121 |
+
elif image_path.exists():
|
| 122 |
+
image = Image.open(image_path)
|
| 123 |
+
else:
|
| 124 |
+
raise Exception(f"File not found: {str(image_path)}")
|
| 125 |
+
|
| 126 |
+
# Load camera intrinsic
|
| 127 |
+
fovx = frame.get('camera_angle_x', global_fovx)
|
| 128 |
+
cx_p = frame.get('cx_p', global_cx_p)
|
| 129 |
+
cy_p = frame.get('cy_p', global_cy_p)
|
| 130 |
+
|
| 131 |
+
if 'camera_angle_y' in frame:
|
| 132 |
+
fovy = frame['camera_angle_y']
|
| 133 |
+
elif global_fovy > 0:
|
| 134 |
+
fovy = global_fovy
|
| 135 |
+
else:
|
| 136 |
+
fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
|
| 137 |
+
|
| 138 |
+
# Load camera pose
|
| 139 |
+
c2w = np.array(frame["transform_matrix"])
|
| 140 |
+
c2w[:3, 1:3] *= -1 # from opengl y-up-z-back to colmap y-down-z-forward
|
| 141 |
+
w2c = np.linalg.inv(c2w).astype(np.float32)
|
| 142 |
+
|
| 143 |
+
# Load sparse point
|
| 144 |
+
if point_cloud is not None:
|
| 145 |
+
sparse_pt = point_cloud.points[correspondent[image_path.name]]
|
| 146 |
+
else:
|
| 147 |
+
sparse_pt = None
|
| 148 |
+
|
| 149 |
+
cam_lst.append(camera_creator(
|
| 150 |
+
image=image,
|
| 151 |
+
w2c=w2c,
|
| 152 |
+
fovx=fovx,
|
| 153 |
+
fovy=fovy,
|
| 154 |
+
cx_p=cx_p,
|
| 155 |
+
cy_p=cy_p,
|
| 156 |
+
sparse_pt=sparse_pt,
|
| 157 |
+
image_name=image_path.name,
|
| 158 |
+
))
|
| 159 |
+
|
| 160 |
+
return cam_lst, point_cloud
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def parse_principle_point(info, is_cx):
|
| 164 |
+
key = "cx" if is_cx else "cy"
|
| 165 |
+
key_res = "w" if is_cx else "h"
|
| 166 |
+
if f"{key}_p" in info:
|
| 167 |
+
return info[f"{key}_p"]
|
| 168 |
+
if key in info and key_res in info:
|
| 169 |
+
return info[key] / info[key_res]
|
| 170 |
+
return None
|
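parse_principle_point (i.e. the principal point parser) accepts either a pre-normalized value (cx_p / cy_p) or a pixel-space value plus the image resolution (cx with w, cy with h). A short illustration with made-up numbers:

from src.dataloader.reader_nerf_dataset import parse_principle_point

print(parse_principle_point({"cx_p": 0.52}, is_cx=True))             # 0.52
print(parse_principle_point({"cx": 998.4, "w": 1920}, is_cx=True))   # 0.52
print(parse_principle_point({}, is_cx=True))                         # None when not specified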
src/sparse_voxel_gears/__pycache__/adaptive.cpython-39.pyc
ADDED
|
Binary file (6.77 kB). View file
|
|
|
src/sparse_voxel_gears/__pycache__/constructor.cpython-39.pyc
ADDED
|
Binary file (8 kB). View file
|
|
|
src/sparse_voxel_gears/__pycache__/io.cpython-39.pyc
ADDED
|
Binary file (4.76 kB). View file
|
|
|
src/sparse_voxel_gears/__pycache__/pooling.cpython-39.pyc
ADDED
|
Binary file (1.81 kB). View file
|
|
|
src/sparse_voxel_gears/__pycache__/properties.cpython-39.pyc
ADDED
|
Binary file (5.06 kB). View file
|
|
|
src/sparse_voxel_gears/__pycache__/renderer.cpython-39.pyc
ADDED
|
Binary file (3.74 kB). View file
|
|
|
src/sparse_voxel_gears/adaptive.py
ADDED
|
@@ -0,0 +1,296 @@
|
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from src.utils import octree_utils
|
| 12 |
+
|
| 13 |
+
'''
|
| 14 |
+
Adaptive sparse voxel pruning and subdivision.
|
| 15 |
+
There are three kinds of data fields to handle.
|
| 16 |
+
|
| 17 |
+
1. Per-voxel attribute:
|
| 18 |
+
Each voxel has its own non-trainable data field.
|
| 19 |
+
|
| 20 |
+
2. Per-voxel parameters:
|
| 21 |
+
Similar to per-voxel attribute but these are trainable parameters.
|
| 22 |
+
|
| 23 |
+
3. Grid points parameters:
|
| 24 |
+
The trainable parameters are attached to the eight grid points of each voxel.
|
| 25 |
+
A grid point parameter can be shared by adjacent voxels.
|
| 26 |
+
'''
|
| 27 |
+
|
| 28 |
+
class SVAdaptive:
|
| 29 |
+
|
| 30 |
+
@torch.no_grad()
|
| 31 |
+
def pruning(self, prune_mask):
|
| 32 |
+
'''
|
| 33 |
+
Prune sparse voxels. The grid points are updated accordingly.
|
| 34 |
+
|
| 35 |
+
Input:
|
| 36 |
+
@prune_mask [N] Mask indicating the voxels to prune.
|
| 37 |
+
'''
|
| 38 |
+
if len(prune_mask.shape) == 2:
|
| 39 |
+
assert prune_mask.shape[1] == 1
|
| 40 |
+
prune_mask = prune_mask.squeeze(1)
|
| 41 |
+
assert prune_mask.shape == (self.num_voxels, )
|
| 42 |
+
kept_idx = (~prune_mask).argwhere().squeeze(1)
|
| 43 |
+
if len(kept_idx) == 0:
|
| 44 |
+
return
|
| 45 |
+
|
| 46 |
+
old_vox_key = self.vox_key.clone()
|
| 47 |
+
|
| 48 |
+
# Prune non-trainable per-voxel attributes.
|
| 49 |
+
for name in self.per_voxel_attr_lst:
|
| 50 |
+
ori_attr = getattr(self, name)
|
| 51 |
+
new_attr = mask_cat_perm(ori_attr, kept_idx=kept_idx)
|
| 52 |
+
setattr(self, name, new_attr)
|
| 53 |
+
if name == '_subdiv_p' and ori_attr.grad is not None:
|
| 54 |
+
self._subdiv_p.grad = mask_cat_perm(ori_attr.grad, kept_idx=kept_idx)
|
| 55 |
+
self._subdiv_p.requires_grad_()
|
| 56 |
+
del ori_attr
|
| 57 |
+
torch.cuda.empty_cache()
|
| 58 |
+
|
| 59 |
+
# Prune trainable per-voxel parameters.
|
| 60 |
+
for name in self.per_voxel_param_lst:
|
| 61 |
+
ori_param = getattr(self, name).detach()
|
| 62 |
+
new_param = mask_cat_perm(
|
| 63 |
+
ori_param,
|
| 64 |
+
kept_idx=kept_idx).requires_grad_()
|
| 65 |
+
setattr(self, name, new_param)
|
| 66 |
+
del ori_param, new_param
|
| 67 |
+
torch.cuda.empty_cache()
|
| 68 |
+
|
| 69 |
+
# Prune trainable grid points parameters (on voxel corners).
|
| 70 |
+
for name in self.grid_pts_param_lst:
|
| 71 |
+
ori_grid_pts = getattr(self, name).detach()
|
| 72 |
+
|
| 73 |
+
# Update parameter
|
| 74 |
+
ori_vox_grid_pts_val = ori_grid_pts[old_vox_key]
|
| 75 |
+
new_vox_val = mask_cat_perm(
|
| 76 |
+
ori_vox_grid_pts_val,
|
| 77 |
+
kept_idx=kept_idx)
|
| 78 |
+
new_param = agg_voxel_into_grid_pts(
|
| 79 |
+
self.num_grid_pts, # It's the updated one
|
| 80 |
+
self.vox_key,
|
| 81 |
+
new_vox_val).requires_grad_()
|
| 82 |
+
setattr(self, name, new_param)
|
| 83 |
+
del ori_grid_pts, ori_vox_grid_pts_val, new_vox_val, new_param
|
| 84 |
+
torch.cuda.empty_cache()
|
| 85 |
+
|
| 86 |
+
@torch.no_grad()
|
| 87 |
+
def subdividing(self, subdivide_mask):
|
| 88 |
+
'''
|
| 89 |
+
Subdivide sparse voxels. The grid points are updated accordingly.
|
| 90 |
+
|
| 91 |
+
Input:
|
| 92 |
+
@subdivide_mask [N] Mask indicating the voxels to subdivide.
|
| 93 |
+
'''
|
| 94 |
+
# Compute voxel index to keep and to subdivided
|
| 95 |
+
if len(subdivide_mask.shape) == 2:
|
| 96 |
+
assert subdivide_mask.shape[1] == 1
|
| 97 |
+
subdivide_mask = subdivide_mask.squeeze(1)
|
| 98 |
+
assert subdivide_mask.shape == (self.num_voxels, )
|
| 99 |
+
kept_idx = (~subdivide_mask).argwhere().squeeze(1)
|
| 100 |
+
subdivide_idx = subdivide_mask.argwhere().squeeze(1)
|
| 101 |
+
if len(subdivide_idx) == 0:
|
| 102 |
+
return
|
| 103 |
+
|
| 104 |
+
old_vox_key = self.vox_key.clone()
|
| 105 |
+
|
| 106 |
+
# Subdivide non-trainable per-voxel attributes.
|
| 107 |
+
octpath, octlevel = octree_utils.gen_children(
|
| 108 |
+
self.octpath[subdivide_idx],
|
| 109 |
+
self.octlevel[subdivide_idx])
|
| 110 |
+
|
| 111 |
+
special_subdiv = dict(
|
| 112 |
+
octpath=octpath,
|
| 113 |
+
octlevel=octlevel,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
for name in self.per_voxel_attr_lst:
|
| 117 |
+
ori_attr = getattr(self, name)
|
| 118 |
+
if name in special_subdiv:
|
| 119 |
+
subdiv_attr = special_subdiv.pop(name)
|
| 120 |
+
else:
|
| 121 |
+
subdiv_attr = ori_attr[subdivide_idx].repeat_interleave(8, dim=0)
|
| 122 |
+
new_attr = mask_cat_perm(
|
| 123 |
+
ori_attr,
|
| 124 |
+
kept_idx=kept_idx,
|
| 125 |
+
cat_tensor=subdiv_attr)
|
| 126 |
+
setattr(self, name, new_attr)
|
| 127 |
+
if name == '_subdiv_p' and ori_attr.grad is not None:
|
| 128 |
+
self._subdiv_p.grad = mask_cat_perm(
|
| 129 |
+
ori_attr.grad,
|
| 130 |
+
kept_idx=kept_idx,
|
| 131 |
+
cat_tensor=subdiv_attr)
|
| 132 |
+
self._subdiv_p.requires_grad_()
|
| 133 |
+
del ori_attr, subdiv_attr
|
| 134 |
+
|
| 135 |
+
assert len(special_subdiv) == 0
|
| 136 |
+
torch.cuda.empty_cache()
|
| 137 |
+
|
| 138 |
+
# Subdivide trainable per-voxel parameters.
|
| 139 |
+
for name in self.per_voxel_param_lst:
|
| 140 |
+
ori_param = getattr(self, name).detach()
|
| 141 |
+
|
| 142 |
+
# Update parameter
|
| 143 |
+
subdiv_param = ori_param[subdivide_idx].repeat_interleave(8, dim=0)
|
| 144 |
+
new_param = mask_cat_perm(
|
| 145 |
+
ori_param,
|
| 146 |
+
kept_idx=kept_idx,
|
| 147 |
+
cat_tensor=subdiv_param).requires_grad_()
|
| 148 |
+
setattr(self, name, new_param)
|
| 149 |
+
del ori_param, subdiv_param, new_param
|
| 150 |
+
torch.cuda.empty_cache()
|
| 151 |
+
|
| 152 |
+
# Subdivide grid points parameters (on voxel corners).
|
| 153 |
+
for name in self.grid_pts_param_lst:
|
| 154 |
+
ori_grid_pts = getattr(self, name).detach()
|
| 155 |
+
|
| 156 |
+
# Update parameter
|
| 157 |
+
# First, we gather the grid_pts values into each voxel.
|
| 158 |
+
# The voxel is then subdivided by trilinear interpolation.
|
| 159 |
+
# Finally, we gather voxel values back to the grid_pts.
|
| 160 |
+
ori_vox_grid_pts_val = ori_grid_pts[old_vox_key]
|
| 161 |
+
subdiv_vox_grid_pts_val = subdivide_by_interp(
|
| 162 |
+
ori_vox_grid_pts_val[subdivide_idx])
|
| 163 |
+
new_vox_val = mask_cat_perm(
|
| 164 |
+
ori_vox_grid_pts_val,
|
| 165 |
+
kept_idx=kept_idx,
|
| 166 |
+
cat_tensor=subdiv_vox_grid_pts_val)
|
| 167 |
+
del ori_grid_pts, ori_vox_grid_pts_val, subdiv_vox_grid_pts_val
|
| 168 |
+
|
| 169 |
+
new_param = agg_voxel_into_grid_pts(
|
| 170 |
+
self.num_grid_pts, # It's the updated one
|
| 171 |
+
self.vox_key,
|
| 172 |
+
new_vox_val).cuda().requires_grad_()
|
| 173 |
+
setattr(self, name, new_param)
|
| 174 |
+
del new_vox_val, new_param
|
| 175 |
+
torch.cuda.empty_cache()
|
| 176 |
+
|
| 177 |
+
@torch.no_grad()
|
| 178 |
+
def sh_degree_add1(self):
|
| 179 |
+
if self.active_sh_degree < self.max_sh_degree:
|
| 180 |
+
self.active_sh_degree += 1
|
| 181 |
+
|
| 182 |
+
@torch.no_grad()
|
| 183 |
+
def compute_training_stat(self, camera_lst):
|
| 184 |
+
'''
|
| 185 |
+
Compute the following statistics of each voxel from the given cameras.
|
| 186 |
+
1. max_w: the maximum blending weight.
|
| 187 |
+
2. min_samp_interval: the minimum sampling interval (inverse of maximum sampling rate).
|
| 188 |
+
3. view_cnt: number of cameras with non-zero blending weight.
|
| 189 |
+
|
| 190 |
+
Input:
|
| 191 |
+
@camera_lst [Camera, ...] A list of cameras.
|
| 192 |
+
'''
|
| 193 |
+
self.freeze_vox_geo()
|
| 194 |
+
max_w = torch.zeros([self.num_voxels, 1], dtype=torch.float32, device="cuda")
|
| 195 |
+
min_samp_interval = torch.full([self.num_voxels, 1], 1e30, dtype=torch.float32, device="cuda")
|
| 196 |
+
view_cnt = torch.zeros([self.num_voxels, 1], dtype=torch.float32, device="cuda")
|
| 197 |
+
for camera in camera_lst:
|
| 198 |
+
max_w_i = self.render(camera, color_mode='dontcare', track_max_w=True)['max_w']
|
| 199 |
+
max_w = torch.maximum(max_w, max_w_i)
|
| 200 |
+
|
| 201 |
+
vis_idx = (max_w_i > 0).squeeze().argwhere().squeeze()
|
| 202 |
+
zdist = ((self.vox_center[vis_idx] - camera.position) * camera.lookat).sum(-1, keepdims=True)
|
| 203 |
+
samp_interval = zdist * camera.pix_size
|
| 204 |
+
min_samp_interval[vis_idx] = torch.minimum(min_samp_interval[vis_idx], samp_interval)
|
| 205 |
+
|
| 206 |
+
view_cnt[vis_idx] += 1
|
| 207 |
+
|
| 208 |
+
stat_pkg = {
|
| 209 |
+
'max_w': max_w,
|
| 210 |
+
'min_samp_interval': min_samp_interval,
|
| 211 |
+
'view_cnt': view_cnt,
|
| 212 |
+
}
|
| 213 |
+
self.unfreeze_vox_geo()
|
| 214 |
+
return stat_pkg
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# Some helpful functions
|
| 218 |
+
def mask_cat_perm(tensor, kept_idx=None, cat_tensor=None, perm=None):
|
| 219 |
+
'''
|
| 220 |
+
Perform tensor masking, concatenation, and permutation.
|
| 221 |
+
'''
|
| 222 |
+
if kept_idx is None and cat_tensor is None and perm is None:
|
| 223 |
+
raise Exception("No op for mask_cat_perm??")
|
| 224 |
+
device = tensor.device
|
| 225 |
+
if kept_idx is not None:
|
| 226 |
+
tensor = tensor[kept_idx.to(device)]
|
| 227 |
+
if cat_tensor is not None:
|
| 228 |
+
tensor = torch.cat([tensor, cat_tensor.to(device)])
|
| 229 |
+
if perm is not None:
|
| 230 |
+
assert len(perm) == len(tensor)
|
| 231 |
+
tensor = tensor[perm.to(device)]
|
| 232 |
+
return tensor.contiguous()
|
| 233 |
+
|
| 234 |
+
def agg_voxel_into_grid_pts(num_grid_pts, vox_key, vox_val, reduce='mean'):
|
| 235 |
+
'''
|
| 236 |
+
Aggregate per-voxel data into their eight grid points.
|
| 237 |
+
Input:
|
| 238 |
+
@num_grid_pts Number of final grid points.
|
| 239 |
+
@vox_key [N, 8] Index to the eight grid points of each voxel.
|
| 240 |
+
@vox_val [N, 8, *] Data of the eight grid points of each voxel.
|
| 241 |
+
Output:
|
| 242 |
+
@new_param [num_grid_pts, *] Grid points data aggregated from vox_val.
|
| 243 |
+
'''
|
| 244 |
+
ch = vox_val.shape[2:]
|
| 245 |
+
device = vox_val.device
|
| 246 |
+
vox_key = vox_key.to(device)
|
| 247 |
+
new_param = torch.zeros([num_grid_pts, *ch], dtype=torch.float32, device=device)
|
| 248 |
+
new_param.index_reduce_(
|
| 249 |
+
dim=0,
|
| 250 |
+
index=vox_key.flatten(),
|
| 251 |
+
source=vox_val.flatten(0,1),
|
| 252 |
+
reduce=reduce,
|
| 253 |
+
include_self=False)
|
| 254 |
+
# Equivalent implementation by old API
|
| 255 |
+
# new_param /= vox_key.flatten().bincount(minlength=num_grid_pts).unsqueeze(-1)
|
| 256 |
+
# new_param.nan_to_num_()
|
| 257 |
+
return new_param.contiguous()
|
| 258 |
+
|
| 259 |
+
def subdivide_by_interp(vox_val):
|
| 260 |
+
'''
|
| 261 |
+
Subdivide grid point data by trilinear interpolation.
|
| 262 |
+
The subdivided children order is the same as those from `_subdivide_attr` and `gen_children`.
|
| 263 |
+
Input:
|
| 264 |
+
@vox_val [N, 8, *] Data of the eight grid points of each voxel.
|
| 265 |
+
Output:
|
| 266 |
+
@new_vox_val [8N, 8, *] Data of the eight grid points of the subdivided voxel.
|
| 267 |
+
'''
|
| 268 |
+
vox_val = vox_val.contiguous()
|
| 269 |
+
main_idx = torch.arange(8, dtype=torch.int64, device=vox_val.device)
|
| 270 |
+
new_vox_val = torch.zeros([len(vox_val), 8, *vox_val.shape[1:]], device=vox_val.device)
|
| 271 |
+
new_vox_val[:, main_idx, main_idx] = vox_val
|
| 272 |
+
new_vox_val[:, main_idx, main_idx^0b001] = 0.5 * (vox_val + vox_val[:, main_idx^0b001])
|
| 273 |
+
new_vox_val[:, main_idx, main_idx^0b010] = 0.5 * (vox_val + vox_val[:, main_idx^0b010])
|
| 274 |
+
new_vox_val[:, main_idx, main_idx^0b100] = 0.5 * (vox_val + vox_val[:, main_idx^0b100])
|
| 275 |
+
new_vox_val[:, main_idx, main_idx^0b011] = 0.25 * (
|
| 276 |
+
vox_val + \
|
| 277 |
+
vox_val[:, main_idx^0b001] + \
|
| 278 |
+
vox_val[:, main_idx^0b010] + \
|
| 279 |
+
vox_val[:, main_idx^0b011]
|
| 280 |
+
)
|
| 281 |
+
new_vox_val[:, main_idx, main_idx^0b101] = 0.25 * (
|
| 282 |
+
vox_val + \
|
| 283 |
+
vox_val[:, main_idx^0b001] + \
|
| 284 |
+
vox_val[:, main_idx^0b100] + \
|
| 285 |
+
vox_val[:, main_idx^0b101]
|
| 286 |
+
)
|
| 287 |
+
new_vox_val[:, main_idx, main_idx^0b110] = 0.25 * (
|
| 288 |
+
vox_val + \
|
| 289 |
+
vox_val[:, main_idx^0b010] + \
|
| 290 |
+
vox_val[:, main_idx^0b100] + \
|
| 291 |
+
vox_val[:, main_idx^0b110]
|
| 292 |
+
)
|
| 293 |
+
new_vox_val[:, main_idx, main_idx^0b111] = vox_val.mean(1, keepdim=True)
|
| 294 |
+
|
| 295 |
+
new_vox_val = new_vox_val.reshape(len(vox_val)*8, *vox_val.shape[1:])
|
| 296 |
+
return new_vox_val.contiguous()
|
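As a quick correctness check on the interpolation weights above, the following standalone CPU sketch (illustration only, not repo code; plain PyTorch) verifies that the XOR-based corner averaging in subdivide_by_interp is exact for a trilinear field: each child corner value must equal the field evaluated at the corresponding half-resolution corner position.

import torch

# Corner c has axis offsets given by its bits; ordering follows the docstrings above:
# [0,0,0] => [0,0,1] => [0,1,0] => [0,1,1] => ...
bits = torch.tensor([[(c >> 2) & 1, (c >> 1) & 1, c & 1] for c in range(8)], dtype=torch.float32)

def f(xyz):
    # An arbitrary trilinear (multilinear) test field.
    x, y, z = xyz.unbind(-1)
    return 0.3 + 1.7 * x - 0.9 * y + 0.4 * z + 2.1 * x * y * z

parent_vals = f(bits)                  # [8] values at the parent cube corners
child_vals = torch.empty(8, 8)         # [child octant, child corner]
for i in range(8):
    for j in range(8):
        m = i ^ j
        src = torch.tensor([i ^ s for s in range(8) if (s & ~m) == 0])
        child_vals[i, j] = parent_vals[src].mean()   # same averaging as subdivide_by_interp

child_pos = 0.5 * bits[:, None, :] + 0.5 * bits[None, :, :]   # [child, corner, 3] positions
assert torch.allclose(child_vals, f(child_pos), atol=1e-6)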
src/sparse_voxel_gears/constructor.py
ADDED
|
@@ -0,0 +1,425 @@
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
import svraster_cuda
|
| 12 |
+
|
| 13 |
+
from src.utils.activation_utils import rgb2shzero
|
| 14 |
+
from src.utils import octree_utils
|
| 15 |
+
|
| 16 |
+
class SVConstructor:
|
| 17 |
+
|
| 18 |
+
def model_init(self,
|
| 19 |
+
bounding, # Scene bound [min_xyz, max_xyz]
|
| 20 |
+
outside_level, # Number of Octree levels for background
|
| 21 |
+
init_n_level=6, # Starting from (2^init_n_level)^3 voxels
|
| 22 |
+
init_out_ratio=2.0, # Ratio of the number of outside (background) voxels to inside voxels
|
| 23 |
+
sh_degree_init=3, # Initial activated sh degree
|
| 24 |
+
geo_init=-10.0, # Init pre-activation density
|
| 25 |
+
sh0_init=0.5, # Init voxel colors in range [0,1]
|
| 26 |
+
shs_init=0.0, # Init coefficients of higher-degree sh
|
| 27 |
+
cameras=None, # Cameras that helps voxel allocation
|
| 28 |
+
):
|
| 29 |
+
|
| 30 |
+
assert outside_level <= svraster_cuda.meta.MAX_NUM_LEVELS
|
| 31 |
+
|
| 32 |
+
# Define scene bound
|
| 33 |
+
center = (bounding[0] + bounding[1]) * 0.5
|
| 34 |
+
extent = max(bounding[1] - bounding[0])
|
| 35 |
+
self.scene_center, self.scene_extent, self.inside_extent = get_scene_bound_tensor(
|
| 36 |
+
center=center, extent=extent, outside_level=outside_level)
|
| 37 |
+
|
| 38 |
+
# Init voxel layout.
|
| 39 |
+
# The world is separated into inside (main foreground) and outside (background) regions.
|
| 40 |
+
in_path, in_level = octlayout_inside_uniform(
|
| 41 |
+
scene_center=self.scene_center,
|
| 42 |
+
scene_extent=self.scene_extent,
|
| 43 |
+
outside_level=outside_level,
|
| 44 |
+
n_level=init_n_level,
|
| 45 |
+
cameras=cameras,
|
| 46 |
+
filter_zero_visiblity=(cameras is not None),
|
| 47 |
+
filter_near=-1)
|
| 48 |
+
|
| 49 |
+
if outside_level == 0:
|
| 50 |
+
# Object centric bounded scenes
|
| 51 |
+
ou_path = torch.empty([0, 1], dtype=in_path.dtype, device="cuda")
|
| 52 |
+
ou_level = torch.empty([0, 1], dtype=in_level.dtype, device="cuda")
|
| 53 |
+
else:
|
| 54 |
+
min_num = len(in_path) * init_out_ratio
|
| 55 |
+
max_level = outside_level + init_n_level
|
| 56 |
+
ou_path, ou_level = octlayout_outside_heuristic(
|
| 57 |
+
scene_center=self.scene_center,
|
| 58 |
+
scene_extent=self.scene_extent,
|
| 59 |
+
outside_level=outside_level,
|
| 60 |
+
cameras=cameras,
|
| 61 |
+
min_num=min_num,
|
| 62 |
+
max_level=max_level,
|
| 63 |
+
filter_near=-1)
|
| 64 |
+
|
| 65 |
+
self.octpath = torch.cat([ou_path, in_path])
|
| 66 |
+
self.octlevel = torch.cat([ou_level, in_level])
|
| 67 |
+
|
| 68 |
+
self.active_sh_degree = min(sh_degree_init, self.max_sh_degree)
|
| 69 |
+
|
| 70 |
+
# Init trainable parameters
|
| 71 |
+
self._geo_grid_pts = torch.full(
|
| 72 |
+
[self.num_grid_pts, 1], geo_init,
|
| 73 |
+
dtype=torch.float32, device="cuda").requires_grad_()
|
| 74 |
+
|
| 75 |
+
self._sh0 = torch.full(
|
| 76 |
+
[self.num_voxels, 3], rgb2shzero(sh0_init),
|
| 77 |
+
dtype=torch.float32, device="cuda").requires_grad_()
|
| 78 |
+
|
| 79 |
+
self._shs = torch.full(
|
| 80 |
+
[self.num_voxels, (self.max_sh_degree+1)**2 - 1, 3], shs_init,
|
| 81 |
+
dtype=torch.float32, device="cuda").requires_grad_()
|
| 82 |
+
|
| 83 |
+
# Subdivision priority tracker
|
| 84 |
+
self._subdiv_p = torch.ones(
|
| 85 |
+
[self.num_voxels, 1],
|
| 86 |
+
dtype=torch.float32, device="cuda").requires_grad_()
|
| 87 |
+
|
| 88 |
+
def octpath_init(self,
|
| 89 |
+
scene_center,
|
| 90 |
+
scene_extent,
|
| 91 |
+
octpath, # Nx1 octpath.
|
| 92 |
+
octlevel, # Nx1 or scalar for the Octree level of each voxel.
|
| 93 |
+
|
| 94 |
+
# The following are model parameters.
|
| 95 |
+
# If the inputs are tensors, the rendering gradients can be backpropagated to them.
|
| 96 |
+
# Otherwise, it creates new trainable tensors.
|
| 97 |
+
rgb=0.5, # Nx3 or scalar for voxel color in range of 0~1.
|
| 98 |
+
shs=0.0, # NxDx3 or scalar for voxel higher-deg sh coefficient.
|
| 99 |
+
density=-10., # Nx8 or Ngridx1 or scalar for voxel density field.
|
| 100 |
+
# The order is [0,0,0] => [0,0,1] => [0,1,0] => [0,1,1] ...
|
| 101 |
+
reduce_density=False, # Whether to merge grid points if density is Nx8.
|
| 102 |
+
):
|
| 103 |
+
|
| 104 |
+
self.scene_center, self.scene_extent, self.inside_extent = get_scene_bound_tensor(
|
| 105 |
+
center=scene_center, extent=scene_extent)
|
| 106 |
+
|
| 107 |
+
assert torch.is_tensor(octpath)
|
| 108 |
+
octlevel = get_octlevel_tensor(octlevel, num_voxels=len(octpath))
|
| 109 |
+
|
| 110 |
+
self.octpath = octpath.view(-1, 1).contiguous()
|
| 111 |
+
self.octlevel = octlevel.view(-1, 1).contiguous()
|
| 112 |
+
assert len(self.octpath) == len(self.octlevel)
|
| 113 |
+
|
| 114 |
+
# Subdivision priority tracker
|
| 115 |
+
self._subdiv_p = torch.ones(
|
| 116 |
+
[self.num_voxels, 1],
|
| 117 |
+
dtype=torch.float32, device="cuda").requires_grad_()
|
| 118 |
+
|
| 119 |
+
# Setup appearance parameters
|
| 120 |
+
if torch.is_tensor(rgb):
|
| 121 |
+
assert rgb.shape == (self.num_voxels, 3)
|
| 122 |
+
self._sh0 = rgb2shzero(rgb.contiguous().cuda())
|
| 123 |
+
else:
|
| 124 |
+
self._sh0 = torch.full(
|
| 125 |
+
[self.num_voxels, 3], rgb2shzero(rgb),
|
| 126 |
+
dtype=torch.float32, device="cuda").requires_grad_()
|
| 127 |
+
|
| 128 |
+
if torch.is_tensor(shs):
|
| 129 |
+
assert shs.shape == (self.num_voxels, (self.max_sh_degree+1)**2 - 1, 3)
|
| 130 |
+
self._shs = shs.contiguous().cuda()
|
| 131 |
+
else:
|
| 132 |
+
self._shs = torch.full(
|
| 133 |
+
[self.num_voxels, (self.max_sh_degree+1)**2 - 1, 3], shs,
|
| 134 |
+
dtype=torch.float32, device="cuda").requires_grad_()
|
| 135 |
+
|
| 136 |
+
# Setup geometry parameters
|
| 137 |
+
if torch.is_tensor(density):
|
| 138 |
+
if density.shape == (self.num_grid_pts, 1):
|
| 139 |
+
self._geo_grid_pts = density.contiguous().cuda()
|
| 140 |
+
elif density.shape == (self.num_voxels, 8):
|
| 141 |
+
if reduce_density:
|
| 142 |
+
self._geo_grid_pts = torch.zeros(
|
| 143 |
+
[self.num_grid_pts, 1], dtype=torch.float32, device="cuda")
|
| 144 |
+
self._geo_grid_pts.index_reduce_(
|
| 145 |
+
dim=0,
|
| 146 |
+
index=self.vox_key.flatten(),
|
| 147 |
+
source=density.flatten(),
|
| 148 |
+
reduce="mean",
|
| 149 |
+
include_self=False)
|
| 150 |
+
else:
|
| 151 |
+
self.frozen_vox_geo = density.contiguous().cuda()
|
| 152 |
+
else:
|
| 153 |
+
raise Exception(f"Unexpected density shape. "
|
| 154 |
+
f"It should be either {(self.num_grid_pts,1)} or {(self.num_voxels,8)}")
|
| 155 |
+
else:
|
| 156 |
+
self._geo_grid_pts = torch.full(
|
| 157 |
+
[self.num_grid_pts, 1], density,
|
| 158 |
+
dtype=torch.float32, device="cuda").requires_grad_()
|
| 159 |
+
|
| 160 |
+
def ijkl_init(self,
|
| 161 |
+
scene_center,
|
| 162 |
+
scene_extent,
|
| 163 |
+
ijk, # Nx3 integer coordinates of each voxel.
|
| 164 |
+
octlevel, # Nx1 or scalar for the Octree level of each voxel.
|
| 165 |
+
|
| 166 |
+
# The following are model parameters.
|
| 167 |
+
# If the inputs are tensors, the rendering gradients can be backpropagated to them.
|
| 168 |
+
# Otherwise, it creates new trainable tensors.
|
| 169 |
+
rgb=0.5, # Nx3 or scalar for voxel color in range of 0~1.
|
| 170 |
+
shs=0.0, # NxDx3 or scalar for voxel higher-deg sh coefficient.
|
| 171 |
+
density=-10., # Nx8 or Ngridx1 or scalar for voxel density field.
|
| 172 |
+
# The order is [0,0,0] => [0,0,1] => [0,1,0] => [0,1,1] ...
|
| 173 |
+
reduce_density=False, # Whether to merge grid points if density is Nx8.
|
| 174 |
+
):
|
| 175 |
+
|
| 176 |
+
scene_center, scene_extent, _ = get_scene_bound_tensor(
|
| 177 |
+
center=scene_center, extent=scene_extent)
|
| 178 |
+
|
| 179 |
+
# Convert to ijkl to octpath
|
| 180 |
+
octlevel = get_octlevel_tensor(octlevel, num_voxels=len(ijk))
|
| 181 |
+
|
| 182 |
+
assert torch.is_tensor(ijk)
|
| 183 |
+
assert len(ijk.shape) == 2 and ijk.shape[1] == 3
|
| 184 |
+
assert len(ijk) == len(octlevel)
|
| 185 |
+
ijk = ijk.long()
|
| 186 |
+
if (ijk < 0).any():
|
| 187 |
+
raise Exception("xyz out of scene bound")
|
| 188 |
+
if (ijk >= (1 << octlevel.long())).any():
|
| 189 |
+
raise Exception("xyz out of scene bound")
|
| 190 |
+
octpath = svraster_cuda.utils.ijk_2_octpath(ijk, octlevel)
|
| 191 |
+
|
| 192 |
+
self.octpath_init(
|
| 193 |
+
scene_center=scene_center,
|
| 194 |
+
scene_extent=scene_extent,
|
| 195 |
+
octpath=octpath,
|
| 196 |
+
octlevel=octlevel,
|
| 197 |
+
rgb=rgb,
|
| 198 |
+
shs=shs,
|
| 199 |
+
density=density,
|
| 200 |
+
reduce_density=reduce_density)
|
| 201 |
+
|
| 202 |
+
def points_init(self,
|
| 203 |
+
scene_center,
|
| 204 |
+
scene_extent,
|
| 205 |
+
xyz, # Nx3 point coordinates in world space.
|
| 206 |
+
octlevel=None, # Nx1 or scalar for the Octree level of each voxel.
|
| 207 |
+
expected_vox_size=None,
|
| 208 |
+
level_round_mode='nearest',
|
| 209 |
+
|
| 210 |
+
# The following are model parameters.
|
| 211 |
+
# If the inputs are tensors, the rendering gradients can be backpropagated to them.
|
| 212 |
+
# Otherwise, it creates new trainable tensors.
|
| 213 |
+
rgb=0.5, # Nx3 or scalar for voxel color in range of 0~1.
|
| 214 |
+
shs=0.0, # NxDx3 or scalar for voxel higher-deg sh coefficient.
|
| 215 |
+
density=-10., # Nx8 or scalar for voxel density field.
|
| 216 |
+
# The order is [0,0,0] => [0,0,1] => [0,1,0] => [0,1,1] ...
|
| 217 |
+
reduce_density=False, # Whether to merge grid points if density is Nx8.
|
| 218 |
+
):
|
| 219 |
+
|
| 220 |
+
scene_center, scene_extent, _ = get_scene_bound_tensor(center=scene_center, extent=scene_extent)
|
| 221 |
+
|
| 222 |
+
# Compute voxel level
|
| 223 |
+
if octlevel is not None:
|
| 224 |
+
assert expected_vox_size is None
|
| 225 |
+
octlevel = get_octlevel_tensor(octlevel, num_voxels=len(xyz))
|
| 226 |
+
elif expected_vox_size is not None:
|
| 227 |
+
octlevel_fp32 = octree_utils.vox_size_2_level(scene_extent, expected_vox_size)
|
| 228 |
+
if level_round_mode == "nearest":
|
| 229 |
+
octlevel_fp32 = octlevel_fp32.round()
|
| 230 |
+
elif level_round_mode == "down":
|
| 231 |
+
octlevel_fp32 = octlevel_fp32.floor()
|
| 232 |
+
elif level_round_mode == "up":
|
| 233 |
+
octlevel_fp32 = octlevel_fp32.ceil()
|
| 234 |
+
else:
|
| 235 |
+
raise Exception("Unknonw level_round_mode")
|
| 236 |
+
octlevel_fp32 = octlevel_fp32.clamp(1, svraster_cuda.meta.MAX_NUM_LEVELS)
|
| 237 |
+
octlevel = get_octlevel_tensor(octlevel_fp32.to(torch.int8), num_voxels=len(xyz))
|
| 238 |
+
else:
|
| 239 |
+
raise Exception("Either octlevel or expected_vox_size should be given.")
|
| 240 |
+
|
| 241 |
+
# Transform point to ijk integer coordinate
|
| 242 |
+
scene_min_xyz = scene_center - 0.5 * scene_extent
|
| 243 |
+
vox_size = octree_utils.level_2_vox_size(scene_extent, octlevel)
|
| 244 |
+
ijk = ((xyz - scene_min_xyz) / vox_size).long()
|
| 245 |
+
|
| 246 |
+
# Reduce duplicated tensor
|
| 247 |
+
ijkl = torch.cat([ijk, octlevel], dim=1)
|
| 248 |
+
ijkl_unq, invmap = ijkl.unique(dim=0, return_inverse=True)
|
| 249 |
+
ijk, octlevel = ijkl_unq.split([3, 1], dim=1)
|
| 250 |
+
octlevel = octlevel.to(torch.int8)
|
| 251 |
+
|
| 252 |
+
if torch.is_tensor(rgb):
|
| 253 |
+
assert rgb.shape == (len(invmap), 3)
|
| 254 |
+
new_shape = (len(ijk), 3)
|
| 255 |
+
rgb = torch.zeros(new_shape, dtype=torch.float32, device="cuda").index_reduce_(
|
| 256 |
+
dim=0,
|
| 257 |
+
index=invmap,
|
| 258 |
+
source=rgb,
|
| 259 |
+
reduce="mean",
|
| 260 |
+
include_self=False)
|
| 261 |
+
|
| 262 |
+
if torch.is_tensor(shs):
|
| 263 |
+
assert shs.shape == (len(invmap), (self.max_sh_degree+1)**2 - 1, 3)
|
| 264 |
+
new_shape = (len(ijk), (self.max_sh_degree+1)**2 - 1, 3)
|
| 265 |
+
shs = torch.zeros(new_shape, dtype=torch.float32, device="cuda").index_reduce_(
|
| 266 |
+
dim=0,
|
| 267 |
+
index=invmap,
|
| 268 |
+
source=shs,
|
| 269 |
+
reduce="mean",
|
| 270 |
+
include_self=False)
|
| 271 |
+
|
| 272 |
+
if torch.is_tensor(density):
|
| 273 |
+
assert density.shape == (len(invmap), 8)
|
| 274 |
+
new_shape = (len(ijk), 8)
|
| 275 |
+
density = torch.zeros(new_shape, dtype=torch.float32, device="cuda").index_reduce_(
|
| 276 |
+
dim=0,
|
| 277 |
+
index=invmap,
|
| 278 |
+
source=density,
|
| 279 |
+
reduce="mean",
|
| 280 |
+
include_self=False)
|
| 281 |
+
|
| 282 |
+
# Allocate voxel using ijkl coordinate
|
| 283 |
+
self.ijkl_init(
|
| 284 |
+
scene_center=scene_center,
|
| 285 |
+
scene_extent=scene_extent,
|
| 286 |
+
ijk=ijk,
|
| 287 |
+
octlevel=octlevel,
|
| 288 |
+
rgb=rgb,
|
| 289 |
+
shs=shs,
|
| 290 |
+
density=density,
|
| 291 |
+
reduce_density=reduce_density)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
#################################################
|
| 295 |
+
# Helper function
|
| 296 |
+
#################################################
|
| 297 |
+
def get_scene_bound_tensor(center, extent, outside_level=0):
|
| 298 |
+
if torch.is_tensor(center):
|
| 299 |
+
scene_center = center.float().clone().cuda()
|
| 300 |
+
else:
|
| 301 |
+
scene_center = torch.tensor(center, dtype=torch.float32, device="cuda")
|
| 302 |
+
|
| 303 |
+
if torch.is_tensor(extent):
|
| 304 |
+
inside_extent = extent.float().clone().cuda()
|
| 305 |
+
else:
|
| 306 |
+
inside_extent = torch.tensor(extent, dtype=torch.float32, device="cuda")
|
| 307 |
+
|
| 308 |
+
scene_extent = inside_extent * (2 ** outside_level)
|
| 309 |
+
|
| 310 |
+
assert scene_center.shape == (3,)
|
| 311 |
+
assert scene_extent.numel() == 1
|
| 312 |
+
|
| 313 |
+
return scene_center, scene_extent, inside_extent
|
| 314 |
+
|
| 315 |
+
def get_octlevel_tensor(octlevel, num_voxels=None):
|
| 316 |
+
if not torch.is_tensor(octlevel):
|
| 317 |
+
assert np.all(octlevel > 0)
|
| 318 |
+
assert np.all(octlevel <= svraster_cuda.meta.MAX_NUM_LEVELS)
|
| 319 |
+
octlevel = torch.tensor(octlevel, dtype=torch.int8, device="cuda")
|
| 320 |
+
if octlevel.numel() == 1:
|
| 321 |
+
octlevel = octlevel.view(1, 1).repeat(num_voxels, 1).contiguous()
|
| 322 |
+
octlevel = octlevel.reshape(-1, 1)
|
| 323 |
+
assert octlevel.dtype == torch.int8
|
| 324 |
+
assert num_voxels is None or octlevel.numel() == num_voxels
|
| 325 |
+
|
| 326 |
+
return octlevel
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
#################################################
|
| 330 |
+
# Octree layout construction heuristic
|
| 331 |
+
#################################################
|
| 332 |
+
def octlayout_filtering(octpath, octlevel, scene_center, scene_extent, cameras=None, filter_zero_visiblity=True, filter_near=-1):
|
| 333 |
+
|
| 334 |
+
vox_center, vox_size = octree_utils.octpath_decoding(
|
| 335 |
+
octpath, octlevel,
|
| 336 |
+
scene_center, scene_extent)
|
| 337 |
+
|
| 338 |
+
# Filtering
|
| 339 |
+
kept_mask = torch.ones([len(octpath)], dtype=torch.bool, device="cuda")
|
| 340 |
+
if filter_zero_visiblity:
|
| 341 |
+
assert cameras is not None, "Cameras should be given to filter invisible voxels"
|
| 342 |
+
rate = svraster_cuda.renderer.mark_max_samp_rate(
|
| 343 |
+
cameras, octpath, vox_center, vox_size)
|
| 344 |
+
kept_mask &= (rate > 0)
|
| 345 |
+
if filter_near > 0:
|
| 346 |
+
is_near = svraster_cuda.renderer.mark_near(
|
| 347 |
+
cameras, octpath, vox_center, vox_size, near=filter_near)
|
| 348 |
+
kept_mask &= (~is_near)
|
| 349 |
+
kept_idx = torch.where(kept_mask)[0]
|
| 350 |
+
octpath = octpath[kept_idx]
|
| 351 |
+
octlevel = octlevel[kept_idx]
|
| 352 |
+
return octpath, octlevel
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def octlayout_inside_uniform(scene_center, scene_extent, outside_level, n_level, cameras=None, filter_zero_visiblity=True, filter_near=-1):
|
| 356 |
+
octpath, octlevel = octree_utils.gen_octpath_dense(
|
| 357 |
+
outside_level=outside_level,
|
| 358 |
+
n_level_inside=n_level)
|
| 359 |
+
|
| 360 |
+
octpath, octlevel = octlayout_filtering(
|
| 361 |
+
octpath=octpath,
|
| 362 |
+
octlevel=octlevel,
|
| 363 |
+
scene_center=scene_center,
|
| 364 |
+
scene_extent=scene_extent,
|
| 365 |
+
cameras=cameras,
|
| 366 |
+
filter_zero_visiblity=filter_zero_visiblity,
|
| 367 |
+
filter_near=filter_near)
|
| 368 |
+
return octpath, octlevel
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def octlayout_outside_heuristic(scene_center, scene_extent, outside_level, cameras, min_num, max_level, filter_near=-1):
|
| 372 |
+
|
| 373 |
+
assert cameras is not None, "Cameras should provided in this mode."
|
| 374 |
+
|
| 375 |
+
# Init by adding one sub-level in each shell level
|
| 376 |
+
octpath = []
|
| 377 |
+
octlevel = []
|
| 378 |
+
for lv in range(1, 1+outside_level):
|
| 379 |
+
path, lv = octree_utils.gen_octpath_shell(
|
| 380 |
+
shell_level=lv,
|
| 381 |
+
n_level_inside=1)
|
| 382 |
+
octpath.append(path)
|
| 383 |
+
octlevel.append(lv)
|
| 384 |
+
octpath = torch.cat(octpath)
|
| 385 |
+
octlevel = torch.cat(octlevel)
|
| 386 |
+
|
| 387 |
+
# Iteratively subdivide voxels with maximum sampling rate
|
| 388 |
+
while True:
|
| 389 |
+
vox_center, vox_size = octree_utils.octpath_decoding(
|
| 390 |
+
octpath, octlevel, scene_center, scene_extent)
|
| 391 |
+
samp_rate = svraster_cuda.renderer.mark_max_samp_rate(
|
| 392 |
+
cameras, octpath, vox_center, vox_size)
|
| 393 |
+
|
| 394 |
+
kept_idx = torch.where((samp_rate > 0))[0]
|
| 395 |
+
octpath = octpath[kept_idx]
|
| 396 |
+
octlevel = octlevel[kept_idx]
|
| 397 |
+
octlevel_mask = (octlevel.squeeze(1) < max_level)
|
| 398 |
+
samp_rate = samp_rate[kept_idx] * octlevel_mask
|
| 399 |
+
vox_size = vox_size[kept_idx]
|
| 400 |
+
still_need_n = (min_num - len(octpath)) // 7
|
| 401 |
+
still_need_n = min(len(octpath), round(still_need_n))
|
| 402 |
+
if still_need_n <= 0:
|
| 403 |
+
break
|
| 404 |
+
rank = samp_rate * (octlevel.squeeze(1) < svraster_cuda.meta.MAX_NUM_LEVELS)
|
| 405 |
+
subdiv_mask = (rank >= rank.sort().values[-still_need_n])
|
| 406 |
+
subdiv_mask &= (octlevel.squeeze(1) < svraster_cuda.meta.MAX_NUM_LEVELS)
|
| 407 |
+
subdiv_mask &= octlevel_mask
|
| 408 |
+
samp_rate *= subdiv_mask
|
| 409 |
+
subdiv_mask &= (samp_rate >= samp_rate.quantile(0.9)) # Subdivide only 10% each iteration
|
| 410 |
+
if subdiv_mask.sum() == 0:
|
| 411 |
+
break
|
| 412 |
+
octpath_children, octlevel_children = octree_utils.gen_children(
|
| 413 |
+
octpath[subdiv_mask], octlevel[subdiv_mask])
|
| 414 |
+
octpath = torch.cat([octpath[~subdiv_mask], octpath_children])
|
| 415 |
+
octlevel = torch.cat([octlevel[~subdiv_mask], octlevel_children])
|
| 416 |
+
|
| 417 |
+
octpath, octlevel = octlayout_filtering(
|
| 418 |
+
octpath=octpath,
|
| 419 |
+
octlevel=octlevel,
|
| 420 |
+
scene_center=scene_center,
|
| 421 |
+
scene_extent=scene_extent,
|
| 422 |
+
cameras=cameras,
|
| 423 |
+
filter_zero_visiblity=True,
|
| 424 |
+
filter_near=filter_near)
|
| 425 |
+
return octpath, octlevel
|
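The point-to-voxel mapping used by points_init can be illustrated with a small standalone CPU sketch (illustration only; it assumes octree_utils.level_2_vox_size(extent, level) equals extent / 2**level and uses a uniform level for every point):

import torch

scene_center = torch.tensor([0.0, 0.0, 0.0])
scene_extent = torch.tensor(4.0)
level = 6                                         # uniform Octree level for all points
vox_size = scene_extent / (2 ** level)            # assumed behavior of level_2_vox_size

xyz = (torch.rand(1000, 3) - 0.5) * scene_extent  # random points inside the scene bound
scene_min_xyz = scene_center - 0.5 * scene_extent
ijk = ((xyz - scene_min_xyz) / vox_size).long()   # integer grid coordinate per point

# Deduplicate points that fall into the same voxel, as in points_init
octlevel = torch.full([len(xyz), 1], level, dtype=torch.long)
ijkl = torch.cat([ijk, octlevel], dim=1)
ijkl_unq, invmap = ijkl.unique(dim=0, return_inverse=True)

# Per-voxel mean color via the same index_reduce_ pattern used above
rgb = torch.rand(len(xyz), 3)
vox_rgb = torch.zeros(len(ijkl_unq), 3).index_reduce_(
    dim=0, index=invmap, source=rgb, reduce="mean", include_self=False)
print(len(xyz), "points ->", len(ijkl_unq), "voxels at level", level)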
src/sparse_voxel_gears/io.py
ADDED
|
@@ -0,0 +1,156 @@
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from src.utils import octree_utils
|
| 14 |
+
|
| 15 |
+
class SVInOut:
|
| 16 |
+
|
| 17 |
+
def save(self, path, quantize=False):
|
| 18 |
+
'''
|
| 19 |
+
Save the necessary attributes and parameters for reproducing rendering.
|
| 20 |
+
'''
|
| 21 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 22 |
+
state_dict = {
|
| 23 |
+
'active_sh_degree': self.active_sh_degree,
|
| 24 |
+
'scene_center': self.scene_center.data.contiguous(),
|
| 25 |
+
'inside_extent': self.inside_extent.data.contiguous(),
|
| 26 |
+
'scene_extent': self.scene_extent.data.contiguous(),
|
| 27 |
+
'octpath': self.octpath.data.contiguous(),
|
| 28 |
+
'octlevel': self.octlevel.data.contiguous(),
|
| 29 |
+
'_geo_grid_pts': self._geo_grid_pts.data.contiguous(),
|
| 30 |
+
'_sh0': self._sh0.data.contiguous(),
|
| 31 |
+
'_shs': self._shs.data.contiguous(),
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
if quantize:
|
| 35 |
+
quantize_state_dict(state_dict)
|
| 36 |
+
state_dict['quantized'] = True
|
| 37 |
+
else:
|
| 38 |
+
state_dict['quantized'] = False
|
| 39 |
+
|
| 40 |
+
for k, v in state_dict.items():
|
| 41 |
+
if torch.is_tensor(v):
|
| 42 |
+
state_dict[k] = v.cpu()
|
| 43 |
+
torch.save(state_dict, path)
|
| 44 |
+
self.latest_save_path = path
|
| 45 |
+
|
| 46 |
+
def load(self, path):
|
| 47 |
+
'''
|
| 48 |
+
Load the saved models.
|
| 49 |
+
'''
|
| 50 |
+
self.loaded_path = path
|
| 51 |
+
state_dict = torch.load(path, map_location="cpu", weights_only=False)
|
| 52 |
+
|
| 53 |
+
if state_dict.get('quantized', False):
|
| 54 |
+
dequantize_state_dict(state_dict)
|
| 55 |
+
|
| 56 |
+
self.active_sh_degree = state_dict['active_sh_degree']
|
| 57 |
+
|
| 58 |
+
self.scene_center = state_dict['scene_center'].cuda()
|
| 59 |
+
self.inside_extent = state_dict['inside_extent'].cuda()
|
| 60 |
+
self.scene_extent = state_dict['scene_extent'].cuda()
|
| 61 |
+
|
| 62 |
+
self.octpath = state_dict['octpath'].cuda()
|
| 63 |
+
self.octlevel = state_dict['octlevel'].cuda().to(torch.int8)
|
| 64 |
+
|
| 65 |
+
self._geo_grid_pts = state_dict['_geo_grid_pts'].cuda().requires_grad_()
|
| 66 |
+
self._sh0 = state_dict['_sh0'].cuda().requires_grad_()
|
| 67 |
+
self._shs = state_dict['_shs'].cuda().requires_grad_()
|
| 68 |
+
|
| 69 |
+
# Subdivision priority tracker
|
| 70 |
+
self._subdiv_p = torch.ones(
|
| 71 |
+
[self.num_voxels, 1],
|
| 72 |
+
dtype=torch.float32, device="cuda").requires_grad_()
|
| 73 |
+
|
| 74 |
+
def save_iteration(self, model_path, iteration, quantize=False):
|
| 75 |
+
path = os.path.join(model_path, "checkpoints", f"iter{iteration:06d}_model.pt")
|
| 76 |
+
self.save(path, quantize=quantize)
|
| 77 |
+
self.latest_save_iter = iteration
|
| 78 |
+
|
| 79 |
+
def load_iteration(self, model_path, iteration=-1):
|
| 80 |
+
if iteration == -1:
|
| 81 |
+
# Find the maximum iteration if it is -1.
|
| 82 |
+
fnames = os.listdir(os.path.join(model_path, "checkpoints"))
|
| 83 |
+
loaded_iter = max(int(re.sub("[^0-9]", "", fname)) for fname in fnames)
|
| 84 |
+
else:
|
| 85 |
+
loaded_iter = iteration
|
| 86 |
+
|
| 87 |
+
path = os.path.join(model_path, "checkpoints", f"iter{loaded_iter:06d}_model.pt")
|
| 88 |
+
self.load(path)
|
| 89 |
+
|
| 90 |
+
self.loaded_iter = loaded_iter
|
| 91 |
+
|
| 92 |
+
return loaded_iter
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Quantization utilities to reduce size when saving model.
|
| 96 |
+
# It can reduce the model size by ~70% with a minor PSNR drop.
|
| 97 |
+
def quantize_state_dict(state_dict):
|
| 98 |
+
state_dict['_geo_grid_pts'] = quantization(state_dict['_geo_grid_pts'])
|
| 99 |
+
state_dict['_sh0'] = [quantization(v) for v in state_dict['_sh0'].split(1, dim=1)]
|
| 100 |
+
state_dict['_shs'] = [quantization(v) for v in state_dict['_shs'].split(1, dim=1)]
|
| 101 |
+
|
| 102 |
+
def dequantize_state_dict(state_dict):
|
| 103 |
+
state_dict['_geo_grid_pts'] = dequantization(state_dict['_geo_grid_pts'])
|
| 104 |
+
state_dict['_sh0'] = torch.cat(
|
| 105 |
+
[dequantization(v) for v in state_dict['_sh0']], dim=1)
|
| 106 |
+
state_dict['_shs'] = torch.cat(
|
| 107 |
+
[dequantization(v) for v in state_dict['_shs']], dim=1)
|
| 108 |
+
|
| 109 |
+
def quantization(src_tensor, max_iter=10):
|
| 110 |
+
src_shape = src_tensor.shape
|
| 111 |
+
src_vals = src_tensor.flatten().contiguous()
|
| 112 |
+
order = src_vals.argsort()
|
| 113 |
+
quantile_ind = (torch.linspace(0,1,257) * (len(order) - 1)).long().clamp_(0, len(order)-1)
|
| 114 |
+
codebook = src_vals[order[quantile_ind]].contiguous()
|
| 115 |
+
codebook[0] = -torch.inf
|
| 116 |
+
ind = torch.searchsorted(codebook, src_vals)
|
| 117 |
+
|
| 118 |
+
codebook = codebook[1:]
|
| 119 |
+
ind = (ind - 1).clamp_(0, 255)
|
| 120 |
+
|
| 121 |
+
diff_l = (src_vals - codebook[ind-1]).abs()
|
| 122 |
+
diff_m = (src_vals - codebook[ind]).abs()
|
| 123 |
+
ind = ind - 1 + (diff_m < diff_l)
|
| 124 |
+
ind.clamp_(0, 255)
|
| 125 |
+
|
| 126 |
+
for _ in range(max_iter):
|
| 127 |
+
codebook = torch.zeros_like(codebook).index_reduce_(
|
| 128 |
+
dim=0,
|
| 129 |
+
index=ind,
|
| 130 |
+
source=src_vals,
|
| 131 |
+
reduce='mean',
|
| 132 |
+
include_self=False)
|
| 133 |
+
diff_l = (src_vals - codebook[ind-1]).abs()
|
| 134 |
+
diff_r = (src_vals - codebook[(ind+1).clamp_max_(255)]).abs()
|
| 135 |
+
diff_m = (src_vals - codebook[ind]).abs()
|
| 136 |
+
upd_mask = torch.minimum(diff_l, diff_r) < diff_m
|
| 137 |
+
if upd_mask.sum() == 0:
|
| 138 |
+
break
|
| 139 |
+
shift = (diff_r < diff_l) * 2 - 1
|
| 140 |
+
ind[upd_mask] += shift[upd_mask]
|
| 141 |
+
ind.clamp_(0, 255)
|
| 142 |
+
|
| 143 |
+
codebook = torch.zeros_like(codebook).index_reduce_(
|
| 144 |
+
dim=0,
|
| 145 |
+
index=ind,
|
| 146 |
+
source=src_vals,
|
| 147 |
+
reduce='mean',
|
| 148 |
+
include_self=False)
|
| 149 |
+
|
| 150 |
+
return dict(
|
| 151 |
+
index=ind.reshape(src_shape).to(torch.uint8),
|
| 152 |
+
codebook=codebook,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
def dequantization(quant_dict):
|
| 156 |
+
return quant_dict['codebook'][quant_dict['index'].long()]
|
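A quick roundtrip of the 256-entry codebook quantization above can be run on CPU (a sketch; it assumes the repo root is on PYTHONPATH and that the imports of src.sparse_voxel_gears.io, e.g. svraster_cuda pulled in via src.utils.octree_utils, resolve):

import torch
from src.sparse_voxel_gears.io import quantization, dequantization

x = torch.randn(100_000, 1)            # stand-in for a _geo_grid_pts tensor
q = quantization(x)                    # dict with uint8 'index' and float32 'codebook'
x_rec = dequantization(q)

bytes_raw = x.numel() * 4              # float32 storage
bytes_q = q['index'].numel() + q['codebook'].numel() * 4
print("size:", bytes_raw, "->", bytes_q, "bytes,",
      "max abs error:", (x - x_rec).abs().max().item())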
src/sparse_voxel_gears/pooling.py
ADDED
|
@@ -0,0 +1,68 @@
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import svraster_cuda
|
| 11 |
+
|
| 12 |
+
from src.utils import octree_utils
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SVPooling:
|
| 16 |
+
|
| 17 |
+
def pooling_to_level(self, max_level, octpath=None, octlevel=None):
|
| 18 |
+
octpath = self.octpath if octpath is None else octpath
|
| 19 |
+
octlevel = self.octlevel if octlevel is None else octlevel
|
| 20 |
+
|
| 21 |
+
num_bit_to_mask = 3 * max(0, svraster_cuda.meta.MAX_NUM_LEVELS - max_level)
|
| 22 |
+
octpath = (octpath >> num_bit_to_mask) << num_bit_to_mask
|
| 23 |
+
octlevel = octlevel.clamp_max(max_level)
|
| 24 |
+
octpack, invmap = torch.stack([octpath, octlevel]).unique(sorted=True, dim=1, return_inverse=True)
|
| 25 |
+
octpath, octlevel = octpack
|
| 26 |
+
octlevel = octlevel.to(torch.int8)
|
| 27 |
+
|
| 28 |
+
vox_center, vox_size = octree_utils.octpath_decoding(
|
| 29 |
+
octpath, octlevel, self.scene_center, self.scene_extent)
|
| 30 |
+
|
| 31 |
+
return dict(
|
| 32 |
+
invmap=invmap,
|
| 33 |
+
octpath=octpath,
|
| 34 |
+
octlevel=octlevel,
|
| 35 |
+
vox_center=vox_center,
|
| 36 |
+
vox_size=vox_size,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def pooling_to_rate(self, cameras, max_rate, octpath=None, octlevel=None):
|
| 40 |
+
octpath = self.octpath.clone() if octpath is None else octpath
|
| 41 |
+
octlevel = self.octlevel.clone() if octlevel is None else octlevel
|
| 42 |
+
invmap = torch.arange(len(octpath), device="cuda")
|
| 43 |
+
|
| 44 |
+
for _ in range(svraster_cuda.meta.MAX_NUM_LEVELS):
|
| 45 |
+
vox_center, vox_size = octree_utils.octpath_decoding(octpath, octlevel, self.scene_center, self.scene_extent)
|
| 46 |
+
samp_rate = svraster_cuda.renderer.mark_max_samp_rate(cameras, octpath, vox_center, vox_size)
|
| 47 |
+
pool_mask = (samp_rate < max_rate) & (octlevel.squeeze(1) > 1)
|
| 48 |
+
if pool_mask.sum() == 0:
|
| 49 |
+
break
|
| 50 |
+
octlevel[pool_mask] = octlevel[pool_mask] - 1
|
| 51 |
+
num_bit_to_mask = 3 * (svraster_cuda.meta.MAX_NUM_LEVELS - octlevel[pool_mask])
|
| 52 |
+
octpath[pool_mask] = octpath[pool_mask] >> num_bit_to_mask << num_bit_to_mask
|
| 53 |
+
|
| 54 |
+
octpack, cur_invmap = torch.stack([octpath, octlevel]).unique(sorted=True, dim=1, return_inverse=True)
|
| 55 |
+
octpath, octlevel = octpack
|
| 56 |
+
octlevel = octlevel.to(torch.int8)
|
| 57 |
+
invmap = cur_invmap[invmap]
|
| 58 |
+
|
| 59 |
+
vox_center, vox_size = octree_utils.octpath_decoding(
|
| 60 |
+
octpath, octlevel, self.scene_center, self.scene_extent)
|
| 61 |
+
|
| 62 |
+
return dict(
|
| 63 |
+
invmap=invmap,
|
| 64 |
+
octpath=octpath,
|
| 65 |
+
octlevel=octlevel,
|
| 66 |
+
vox_center=vox_center,
|
| 67 |
+
vox_size=vox_size,
|
| 68 |
+
)
|
src/sparse_voxel_gears/properties.py
ADDED
|
@@ -0,0 +1,146 @@
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from src.utils import octree_utils
|
| 12 |
+
from src.utils.fuser_utils import rgb_fusion
|
| 13 |
+
from src.utils.activation_utils import rgb2shzero
|
| 14 |
+
|
| 15 |
+
import svraster_cuda
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SVProperties:
|
| 19 |
+
|
| 20 |
+
@property
|
| 21 |
+
def num_voxels(self):
|
| 22 |
+
return len(self.octpath)
|
| 23 |
+
|
| 24 |
+
@property
|
| 25 |
+
def num_grid_pts(self):
|
| 26 |
+
return len(self.grid_pts_key)
|
| 27 |
+
|
| 28 |
+
@property
|
| 29 |
+
def scene_min(self):
|
| 30 |
+
return self.scene_center - 0.5 * self.scene_extent
|
| 31 |
+
|
| 32 |
+
@property
|
| 33 |
+
def scene_max(self):
|
| 34 |
+
return self.scene_center + 0.5 * self.scene_extent
|
| 35 |
+
|
| 36 |
+
@property
|
| 37 |
+
def inside_min(self):
|
| 38 |
+
return self.scene_center - 0.5 * self.inside_extent
|
| 39 |
+
|
| 40 |
+
@property
|
| 41 |
+
def inside_max(self):
|
| 42 |
+
return self.scene_center + 0.5 * self.inside_extent
|
| 43 |
+
|
| 44 |
+
@property
|
| 45 |
+
def outside_level(self):
|
| 46 |
+
return (self.scene_extent / self.inside_extent).log2().round().long().item()
|
| 47 |
+
|
| 48 |
+
@property
|
| 49 |
+
def bounding(self):
|
| 50 |
+
return torch.stack([self.scene_min, self.scene_max])
|
| 51 |
+
|
| 52 |
+
@property
|
| 53 |
+
def inside_mask(self):
|
| 54 |
+
isin = ((self.inside_min < self.vox_center) & (self.vox_center < self.inside_max)).all(1)
|
| 55 |
+
return isin
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def sh0(self):
|
| 59 |
+
return self._sh0
|
| 60 |
+
|
| 61 |
+
@property
|
| 62 |
+
def shs(self):
|
| 63 |
+
return self._shs
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def subdivision_priority(self):
|
| 67 |
+
return self._subdiv_p.grad
|
| 68 |
+
|
| 69 |
+
def reset_subdivision_priority(self):
|
| 70 |
+
self._subdiv_p.grad = None
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def signature(self):
|
| 74 |
+
# Signature to check if the voxel grid layout is updated
|
| 75 |
+
return (self.num_voxels, id(self.octpath), id(self.octlevel))
|
| 76 |
+
|
| 77 |
+
def _check_derived_voxel_attr(self):
|
| 78 |
+
# Lazy computation of derived voxel attributes
|
| 79 |
+
signature = self.signature
|
| 80 |
+
need_recompute = not hasattr(self, '_check_derived_voxel_attr_signature') or \
|
| 81 |
+
self._check_derived_voxel_attr_signature != signature
|
| 82 |
+
if need_recompute:
|
| 83 |
+
self._vox_center, self._vox_size = octree_utils.octpath_decoding(
|
| 84 |
+
self.octpath, self.octlevel, self.scene_center, self.scene_extent)
|
| 85 |
+
self._grid_pts_key, self._vox_key = octree_utils.build_grid_pts_link(self.octpath, self.octlevel)
|
| 86 |
+
self._check_derived_voxel_attr_signature = signature
|
| 87 |
+
|
| 88 |
+
@property
|
| 89 |
+
def vox_center(self):
|
| 90 |
+
self._check_derived_voxel_attr()
|
| 91 |
+
return self._vox_center
|
| 92 |
+
|
| 93 |
+
@property
|
| 94 |
+
def vox_size(self):
|
| 95 |
+
self._check_derived_voxel_attr()
|
| 96 |
+
return self._vox_size
|
| 97 |
+
|
| 98 |
+
@property
|
| 99 |
+
def grid_pts_key(self):
|
| 100 |
+
self._check_derived_voxel_attr()
|
| 101 |
+
return self._grid_pts_key
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def vox_key(self):
|
| 105 |
+
self._check_derived_voxel_attr()
|
| 106 |
+
return self._vox_key
|
| 107 |
+
|
| 108 |
+
@property
|
| 109 |
+
def vox_size_inv(self):
|
| 110 |
+
# Lazy computation of inverse voxel sizes
|
| 111 |
+
signature = self.signature
|
| 112 |
+
need_recompute = not hasattr(self, '_vox_size_inv_signature') or \
|
| 113 |
+
self._vox_size_inv_signature != signature
|
| 114 |
+
if need_recompute:
|
| 115 |
+
self._vox_size_inv = 1 / self.vox_size
|
| 116 |
+
self._vox_size_inv_signature = signature
|
| 117 |
+
return self._vox_size_inv
|
| 118 |
+
|
| 119 |
+
@property
|
| 120 |
+
def grid_pts_xyz(self):
|
| 121 |
+
# Lazy computation of grid points xyz
|
| 122 |
+
signature = self.signature
|
| 123 |
+
need_recompute = not hasattr(self, '_grid_pts_xyz_signature') or \
|
| 124 |
+
self._grid_pts_xyz_signature != signature
|
| 125 |
+
if need_recompute:
|
| 126 |
+
self._grid_pts_xyz = octree_utils.compute_gridpoints_xyz(
|
| 127 |
+
self.grid_pts_key, self.scene_center, self.scene_extent)
|
| 128 |
+
self._grid_pts_xyz_signature = signature
|
| 129 |
+
return self._grid_pts_xyz
|
| 130 |
+
|
| 131 |
+
@torch.no_grad()
|
| 132 |
+
def reset_sh_from_cameras(self, cameras):
|
| 133 |
+
self._sh0.data.copy_(rgb2shzero(rgb_fusion(self, cameras)))
|
| 134 |
+
self._shs.data.zero_()
|
| 135 |
+
|
| 136 |
+
def apply_tv_on_density_field(self, lambda_tv_density):
|
| 137 |
+
if self._geo_grid_pts.grad is None:
|
| 138 |
+
self._geo_grid_pts.grad = torch.zeros_like(self._geo_grid_pts.data)
|
| 139 |
+
svraster_cuda.grid_loss_bw.total_variation(
|
| 140 |
+
grid_pts=self._geo_grid_pts,
|
| 141 |
+
vox_key=self.vox_key,
|
| 142 |
+
weight=lambda_tv_density,
|
| 143 |
+
vox_size_inv=self.vox_size_inv,
|
| 144 |
+
no_tv_s=True,
|
| 145 |
+
tv_sparse=False,
|
| 146 |
+
grid_pts_grad=self._geo_grid_pts.grad)
|
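The lazy, signature-guarded caching used throughout SVProperties boils down to the pattern in this condensed sketch (illustration only, not the repo class; the derived quantity is a placeholder): derived tensors are recomputed only when the (num_voxels, id(octpath), id(octlevel)) triple changes, i.e. when the layout tensors are replaced by new allocations.

import torch

class LazyDerived:
    def __init__(self, octpath, octlevel):
        self.octpath, self.octlevel = octpath, octlevel

    @property
    def signature(self):
        return (len(self.octpath), id(self.octpath), id(self.octlevel))

    @property
    def vox_size_inv(self):
        if getattr(self, '_sig', None) != self.signature:
            self._vox_size_inv = 1.0 / self.octlevel.float().exp2()  # placeholder derivation
            self._sig = self.signature
        return self._vox_size_inv

m = LazyDerived(torch.arange(4), torch.ones(4, dtype=torch.int8))
a = m.vox_size_inv
assert m.vox_size_inv is a                 # cached while the layout tensors are unchanged
m.octlevel = m.octlevel.clone()            # replacing a tensor changes its id()
assert m.vox_size_inv is not a             # so the derived value is recomputed

Note that in-place edits of octpath or octlevel would not change the signature; the layout tensors are always swapped for new allocations when the grid changes.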
src/sparse_voxel_gears/renderer.py
ADDED
|
@@ -0,0 +1,178 @@
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import svraster_cuda
|
| 11 |
+
|
| 12 |
+
from src.utils.image_utils import resize_rendering
|
| 13 |
+
|
| 14 |
+
class SVRenderer:
|
| 15 |
+
|
| 16 |
+
def freeze_vox_geo(self):
|
| 17 |
+
'''
|
| 18 |
+
Freeze grid points parameter and pre-gather them to each voxel.
|
| 19 |
+
'''
|
| 20 |
+
with torch.no_grad():
|
| 21 |
+
self.frozen_vox_geo = svraster_cuda.renderer.GatherGeoParams.apply(
|
| 22 |
+
self.vox_key,
|
| 23 |
+
torch.arange(self.num_voxels, device="cuda"),
|
| 24 |
+
self._geo_grid_pts
|
| 25 |
+
)
|
| 26 |
+
self._geo_grid_pts.requires_grad = False
|
| 27 |
+
|
| 28 |
+
def unfreeze_vox_geo(self):
|
| 29 |
+
'''
|
| 30 |
+
Unfreeze grid points parameter.
|
| 31 |
+
'''
|
| 32 |
+
del self.frozen_vox_geo
|
| 33 |
+
self._geo_grid_pts.requires_grad = True
|
| 34 |
+
|
| 35 |
+
def vox_fn(self, idx, cam_pos, color_mode=None, viewdir=None):
|
| 36 |
+
'''
|
| 37 |
+
Per-frame voxel property processing. Two important operations:
|
| 38 |
+
1. Gather grid points parameter into each voxel.
|
| 39 |
+
2. Compute view-dependent color of each voxel.
|
| 40 |
+
|
| 41 |
+
Input:
|
| 42 |
+
@idx Indices for active voxel for current frame.
|
| 43 |
+
@cam_pos Camera position.
|
| 44 |
+
Output:
|
| 45 |
+
@vox_params A dictionary of the pre-processed voxel properties.
|
| 46 |
+
'''
|
| 47 |
+
|
| 48 |
+
# Gather the density values at the eight corners of each voxel.
|
| 49 |
+
# It defines a trilinear density field.
|
| 50 |
+
# The final tensor is of shape [#vox, 8]
|
| 51 |
+
if hasattr(self, 'frozen_vox_geo'):
|
| 52 |
+
geos = self.frozen_vox_geo
|
| 53 |
+
else:
|
| 54 |
+
geos = svraster_cuda.renderer.GatherGeoParams.apply(
|
| 55 |
+
self.vox_key,
|
| 56 |
+
idx,
|
| 57 |
+
self._geo_grid_pts
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Compute voxel colors
|
| 61 |
+
if color_mode is None or color_mode == "sh":
|
| 62 |
+
active_sh_degree = self.active_sh_degree
|
| 63 |
+
color_mode = "sh"
|
| 64 |
+
elif color_mode.startswith("sh"):
|
| 65 |
+
active_sh_degree = int(color_mode[2])
|
| 66 |
+
color_mode = "sh"
|
| 67 |
+
|
| 68 |
+
if color_mode == "sh":
|
| 69 |
+
rgbs = svraster_cuda.renderer.SH_eval.apply(
|
| 70 |
+
active_sh_degree,
|
| 71 |
+
idx,
|
| 72 |
+
self.vox_center,
|
| 73 |
+
cam_pos,
|
| 74 |
+
viewdir, # Ignore above two when viewdir is not None
|
| 75 |
+
self.sh0,
|
| 76 |
+
self.shs,
|
| 77 |
+
)
|
| 78 |
+
elif color_mode == "rand":
|
| 79 |
+
rgbs = torch.rand([self.num_voxels, 3], dtype=torch.float32, device="cuda")
|
| 80 |
+
elif color_mode == "dontcare":
|
| 81 |
+
rgbs = torch.empty([self.num_voxels, 3], dtype=torch.float32, device="cuda")
|
| 82 |
+
else:
|
| 83 |
+
raise NotImplementedError
|
| 84 |
+
|
| 85 |
+
# Pack everything
|
| 86 |
+
vox_params = {
|
| 87 |
+
'geos': geos,
|
| 88 |
+
'rgbs': rgbs,
|
| 89 |
+
'subdiv_p': self._subdiv_p, # Dummy param to record subdivision priority
|
| 90 |
+
}
|
| 91 |
+
if vox_params['subdiv_p'] is None:
|
| 92 |
+
vox_params['subdiv_p'] = torch.ones([self.num_voxels, 1], device="cuda")
|
| 93 |
+
|
| 94 |
+
return vox_params
|
| 95 |
+
|
| 96 |
+
def render(
|
| 97 |
+
self,
|
| 98 |
+
camera,
|
| 99 |
+
color_mode=None,
|
| 100 |
+
track_max_w=False,
|
| 101 |
+
ss=None,
|
| 102 |
+
output_depth=False,
|
| 103 |
+
output_normal=False,
|
| 104 |
+
output_T=False,
|
| 105 |
+
rand_bg=False,
|
| 106 |
+
use_auto_exposure=False,
|
| 107 |
+
**other_opt):
|
| 108 |
+
|
| 109 |
+
###################################
|
| 110 |
+
# Pre-processing
|
| 111 |
+
###################################
|
| 112 |
+
if ss is None:
|
| 113 |
+
ss = self.ss
|
| 114 |
+
w_src, h_src = camera.image_width, camera.image_height
|
| 115 |
+
w, h = round(w_src * ss), round(h_src * ss)
|
| 116 |
+
w_ss, h_ss = w / w_src, h / h_src
|
| 117 |
+
if ss != 1.0 and 'gt_color' in other_opt:
|
| 118 |
+
other_opt['gt_color'] = resize_rendering(other_opt['gt_color'], size=(h, w))
|
| 119 |
+
|
| 120 |
+
n_samp_per_vox = other_opt.pop('n_samp_per_vox', self.n_samp_per_vox)
|
| 121 |
+
|
| 122 |
+
###################################
|
| 123 |
+
# Call low-level rasterization API
|
| 124 |
+
###################################
|
| 125 |
+
raster_settings = svraster_cuda.renderer.RasterSettings(
|
| 126 |
+
color_mode=color_mode,
|
| 127 |
+
n_samp_per_vox=n_samp_per_vox,
|
| 128 |
+
image_width=w,
|
| 129 |
+
image_height=h,
|
| 130 |
+
tanfovx=camera.tanfovx,
|
| 131 |
+
tanfovy=camera.tanfovy,
|
| 132 |
+
cx=camera.cx * w_ss,
|
| 133 |
+
cy=camera.cy * h_ss,
|
| 134 |
+
w2c_matrix=camera.w2c,
|
| 135 |
+
c2w_matrix=camera.c2w,
|
| 136 |
+
bg_color=float(self.white_background),
|
| 137 |
+
near=camera.near,
|
| 138 |
+
need_depth=output_depth,
|
| 139 |
+
need_normal=output_normal,
|
| 140 |
+
track_max_w=track_max_w,
|
| 141 |
+
**other_opt)
|
| 142 |
+
color, depth, normal, T, max_w = svraster_cuda.renderer.rasterize_voxels(
|
| 143 |
+
raster_settings,
|
| 144 |
+
self.octpath,
|
| 145 |
+
self.vox_center,
|
| 146 |
+
self.vox_size,
|
| 147 |
+
self.vox_fn)
|
| 148 |
+
|
| 149 |
+
###################################
|
| 150 |
+
# Post-processing and pack output
|
| 151 |
+
###################################
|
| 152 |
+
if rand_bg:
|
| 153 |
+
color = color + T * torch.rand_like(color, requires_grad=False)
|
| 154 |
+
elif not self.white_background and not self.black_background:
|
| 155 |
+
color = color + T * color.mean((1,2), keepdim=True)
|
| 156 |
+
|
| 157 |
+
if use_auto_exposure:
|
| 158 |
+
color = camera.auto_exposure_apply(color)
|
| 159 |
+
|
| 160 |
+
render_pkg = {
|
| 161 |
+
'color': color,
|
| 162 |
+
'depth': depth if output_depth else None,
|
| 163 |
+
'normal': normal if output_normal else None,
|
| 164 |
+
'T': T if output_T else None,
|
| 165 |
+
'max_w': max_w,
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
for k in ['color', 'depth', 'normal', 'T']:
|
| 169 |
+
render_pkg[f'raw_{k}'] = render_pkg[k]
|
| 170 |
+
|
| 171 |
+
# Post process super-sampling
|
| 172 |
+
if render_pkg[k] is not None and render_pkg[k].shape[-2:] != (h_src, w_src):
|
| 173 |
+
render_pkg[k] = resize_rendering(render_pkg[k], size=(h_src, w_src))
|
| 174 |
+
|
| 175 |
+
# Clip intensity
|
| 176 |
+
render_pkg['color'] = render_pkg['color'].clamp(0, 1)
|
| 177 |
+
|
| 178 |
+
return render_pkg
|
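The super-sampling bookkeeping at the top of render() amounts to the following arithmetic (a sketch with hypothetical numbers; the camera values are placeholders): rasterize at ss times the source resolution, scale the principal point by the same per-axis factors, then resize the outputs back to (h_src, w_src). The color_mode string is parsed the same way as in vox_fn, e.g. "sh2" forces SH degree 2.

w_src, h_src = 1600, 1063            # hypothetical source resolution
ss = 1.5                             # super-sampling factor
w, h = round(w_src * ss), round(h_src * ss)
w_ss, h_ss = w / w_src, h / h_src    # exact per-axis scales after rounding
cx, cy = 800.0, 531.5                # hypothetical principal point in source pixels
print(f"rasterize at {w}x{h} with cx={cx * w_ss:.1f}, cy={cy * h_ss:.1f}; resize back to {w_src}x{h_src}")
color_mode = "sh2"
active_sh_degree = int(color_mode[2])   # -> 2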
src/sparse_voxel_gears/renderer_copy.py
ADDED
|
@@ -0,0 +1,178 @@
| 1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import svraster_cuda
|
| 11 |
+
|
| 12 |
+
from src.utils.image_utils import resize_rendering
|
| 13 |
+
|
| 14 |
+
class SVRenderer:
|
| 15 |
+
|
| 16 |
+
def freeze_vox_geo(self):
|
| 17 |
+
'''
|
| 18 |
+
Freeze grid points parameter and pre-gather them to each voxel.
|
| 19 |
+
'''
|
| 20 |
+
with torch.no_grad():
|
| 21 |
+
self.frozen_vox_geo = svraster_cuda.renderer.GatherGeoParams.apply(
|
| 22 |
+
self.vox_key,
|
| 23 |
+
torch.arange(self.num_voxels, device="cuda"),
|
| 24 |
+
self._geo_grid_pts
|
| 25 |
+
)
|
| 26 |
+
self._geo_grid_pts.requires_grad = False
|
| 27 |
+
|
| 28 |
+
def unfreeze_vox_geo(self):
|
| 29 |
+
'''
|
| 30 |
+
Unfreeze grid points parameter.
|
| 31 |
+
'''
|
| 32 |
+
del self.frozen_vox_geo
|
| 33 |
+
self._geo_grid_pts.requires_grad = True
|
| 34 |
+
|
| 35 |
+
def vox_fn(self, idx, cam_pos, color_mode=None, viewdir=None):
|
| 36 |
+
'''
|
| 37 |
+
Per-frame voxel property processing. Two important operations:
|
| 38 |
+
1. Gather grid points parameter into each voxel.
|
| 39 |
+
2. Compute view-dependent color of each voxel.
|
| 40 |
+
|
| 41 |
+
Input:
|
| 42 |
+
@idx Indices for active voxel for current frame.
|
| 43 |
+
@cam_pos Camera position.
|
| 44 |
+
Output:
|
| 45 |
+
@vox_params A dictionary of the pre-processed voxel properties.
|
| 46 |
+
'''
|
| 47 |
+
|
| 48 |
+
# Gather the density values at the eight corners of each voxel.
|
| 49 |
+
# It defines a trilinear density field.
|
| 50 |
+
# The final tensor is of shape [#vox, 8]
|
| 51 |
+
if hasattr(self, 'frozen_vox_geo'):
|
| 52 |
+
geos = self.frozen_vox_geo
|
| 53 |
+
else:
|
| 54 |
+
geos = svraster_cuda.renderer.GatherGeoParams.apply(
|
| 55 |
+
self.vox_key,
|
| 56 |
+
idx,
|
| 57 |
+
self._geo_grid_pts
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Compute voxel colors
|
| 61 |
+
if color_mode is None or color_mode == "sh":
|
| 62 |
+
active_sh_degree = self.active_sh_degree
|
| 63 |
+
color_mode = "sh"
|
| 64 |
+
elif color_mode.startswith("sh"):
|
| 65 |
+
active_sh_degree = int(color_mode[2])
|
| 66 |
+
color_mode = "sh"
|
| 67 |
+
|
| 68 |
+
if color_mode == "sh":
|
| 69 |
+
rgbs = svraster_cuda.renderer.SH_eval.apply(
|
| 70 |
+
active_sh_degree,
|
| 71 |
+
idx,
|
| 72 |
+
self.vox_center,
|
| 73 |
+
cam_pos,
|
| 74 |
+
viewdir, # Ignore above two when viewdir is not None
|
| 75 |
+
self.sh0,
|
| 76 |
+
self.shs,
|
| 77 |
+
)
|
| 78 |
+
elif color_mode == "rand":
|
| 79 |
+
rgbs = torch.rand([self.num_voxels, 3], dtype=torch.float32, device="cuda")
|
| 80 |
+
elif color_mode == "dontcare":
|
| 81 |
+
rgbs = torch.empty([self.num_voxels, 3], dtype=torch.float32, device="cuda")
|
| 82 |
+
else:
|
| 83 |
+
raise NotImplementedError
|
| 84 |
+
|
| 85 |
+
# Pack everything
|
| 86 |
+
vox_params = {
|
| 87 |
+
'geos': geos,
|
| 88 |
+
'rgbs': rgbs,
|
| 89 |
+
'subdiv_p': self._subdiv_p, # Dummy param to record subdivision priority
|
| 90 |
+
}
|
| 91 |
+
if vox_params['subdiv_p'] is None:
|
| 92 |
+
vox_params['subdiv_p'] = torch.ones([self.num_voxels, 1], device="cuda")
|
| 93 |
+
|
| 94 |
+
return vox_params
|
| 95 |
+
|
| 96 |
+
def render(
|
| 97 |
+
self,
|
| 98 |
+
camera,
|
| 99 |
+
color_mode=None,
|
| 100 |
+
track_max_w=False,
|
| 101 |
+
ss=None,
|
| 102 |
+
output_depth=False,
|
| 103 |
+
output_normal=False,
|
| 104 |
+
output_T=False,
|
| 105 |
+
rand_bg=False,
|
| 106 |
+
use_auto_exposure=False,
|
| 107 |
+
**other_opt):
|
| 108 |
+
|
| 109 |
+
###################################
|
| 110 |
+
# Pre-processing
|
| 111 |
+
###################################
|
| 112 |
+
if ss is None:
|
| 113 |
+
ss = self.ss
|
| 114 |
+
w_src, h_src = camera.image_width, camera.image_height
|
| 115 |
+
w, h = round(w_src * ss), round(h_src * ss)
|
| 116 |
+
w_ss, h_ss = w / w_src, h / h_src
|
| 117 |
+
if ss != 1.0 and 'gt_color' in other_opt:
|
| 118 |
+
other_opt['gt_color'] = resize_rendering(other_opt['gt_color'], size=(h, w))
|
| 119 |
+
|
| 120 |
+
n_samp_per_vox = other_opt.pop('n_samp_per_vox', self.n_samp_per_vox)
|
| 121 |
+
|
| 122 |
+
###################################
|
| 123 |
+
# Call low-level rasterization API
|
| 124 |
+
###################################
|
| 125 |
+
raster_settings = svraster_cuda.renderer.RasterSettings(
|
| 126 |
+
color_mode=color_mode,
|
| 127 |
+
n_samp_per_vox=n_samp_per_vox,
|
| 128 |
+
image_width=w,
|
| 129 |
+
image_height=h,
|
| 130 |
+
tanfovx=camera.tanfovx,
|
| 131 |
+
tanfovy=camera.tanfovy,
|
| 132 |
+
cx=camera.cx * w_ss,
|
| 133 |
+
cy=camera.cy * h_ss,
|
| 134 |
+
w2c_matrix=camera.w2c,
|
| 135 |
+
c2w_matrix=camera.c2w,
|
| 136 |
+
bg_color=float(self.white_background),
|
| 137 |
+
near=camera.near,
|
| 138 |
+
need_depth=output_depth,
|
| 139 |
+
need_normal=output_normal,
|
| 140 |
+
track_max_w=track_max_w,
|
| 141 |
+
**other_opt)
|
| 142 |
+
color, depth, normal, T, max_w = svraster_cuda.renderer.rasterize_voxels(
|
| 143 |
+
raster_settings,
|
| 144 |
+
self.octpath,
|
| 145 |
+
self.vox_center,
|
| 146 |
+
self.vox_size,
|
| 147 |
+
self.vox_fn)
|
| 148 |
+
|
| 149 |
+
###################################
|
| 150 |
+
# Post-processing and pack output
|
| 151 |
+
###################################
|
| 152 |
+
if rand_bg:
|
| 153 |
+
color = color + T * torch.rand_like(color, requires_grad=False)
|
| 154 |
+
elif not self.white_background and not self.black_background:
|
| 155 |
+
color = color + T * color.mean((1,2), keepdim=True)
|
| 156 |
+
|
| 157 |
+
if use_auto_exposure:
|
| 158 |
+
color = camera.auto_exposure_apply(color)
|
| 159 |
+
|
| 160 |
+
render_pkg = {
|
| 161 |
+
'color': color,
|
| 162 |
+
'depth': depth if output_depth else None,
|
| 163 |
+
'normal': normal if output_normal else None,
|
| 164 |
+
'T': T if output_T else None,
|
| 165 |
+
'max_w': max_w,
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
for k in ['color', 'depth', 'normal', 'T']:
|
| 169 |
+
render_pkg[f'raw_{k}'] = render_pkg[k]
|
| 170 |
+
|
| 171 |
+
# Post process super-sampling
|
| 172 |
+
if render_pkg[k] is not None and render_pkg[k].shape[-2:] != (h_src, w_src):
|
| 173 |
+
render_pkg[k] = resize_rendering(render_pkg[k], size=(h_src, w_src))
|
| 174 |
+
|
| 175 |
+
# Clip intensity
|
| 176 |
+
render_pkg['color'] = render_pkg['color'].clamp(0, 1)
|
| 177 |
+
|
| 178 |
+
return render_pkg
|
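
The `render` method above is the public entry point of the rasterizer. Below is a minimal usage sketch, assuming a model restored via `model_load` and a camera object from `src/cameras.py`; the checkpoint path and the `cam` variable are placeholders, and the exact `model_load` call form should be checked against `src/sparse_voxel_gears/io.py`.

# Sketch only: the checkpoint path and `cam` are placeholders; model_load's
# exact signature is defined in src/sparse_voxel_gears/io.py.
import torch
from src.sparse_voxel_model import SparseVoxelModel

model = SparseVoxelModel(sh_degree=3, ss=1.5)
model.model_load("output/scene")     # restore octree layout and parameters (assumed call form)
model.freeze_vox_geo()               # optional: pre-gather grid params for faster inference

with torch.no_grad():
    render_pkg = model.render(
        cam,                         # a camera from src/cameras.py
        output_depth=True,
        output_normal=True,
        output_T=True)

rgb = render_pkg['color']            # clamped to [0, 1]
depth = render_pkg['depth']          # available because output_depth=True
alpha = 1 - render_pkg['T']          # T is the remaining transmittance
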
src/sparse_voxel_model.py
ADDED
@@ -0,0 +1,67 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from src.sparse_voxel_gears.constructor import SVConstructor
from src.sparse_voxel_gears.properties import SVProperties
from src.sparse_voxel_gears.renderer import SVRenderer
from src.sparse_voxel_gears.adaptive import SVAdaptive
from src.sparse_voxel_gears.io import SVInOut
from src.sparse_voxel_gears.pooling import SVPooling


class SparseVoxelModel(SVConstructor, SVProperties, SVRenderer, SVAdaptive, SVInOut, SVPooling):

    def __init__(self,
                 n_samp_per_vox=1,        # Number of sampled points per visited voxel
                 sh_degree=3,             # Use 3 * (k+1)^2 params per voxel for view-dependent colors
                 ss=1.5,                  # Super-sampling rate for anti-aliasing
                 white_background=False,  # Assume white background
                 black_background=False,  # Assume black background
                 ):
        '''
        Setup of the model meta. At this point, no voxel is allocated.
        Use the following methods to allocate voxels and parameters.

        1. `model_load` defined in `src/sparse_voxel_gears/io.py`.
           Load a saved model from a given path.

        2. `model_init` defined in `src/sparse_voxel_gears/constructor.py`.
           Heuristically initialize the sparse grid layout and parameters from the training data.
        '''
        super().__init__()

        self.n_samp_per_vox = n_samp_per_vox
        self.max_sh_degree = sh_degree
        self.ss = ss
        self.white_background = white_background
        self.black_background = black_background

        # List the variable names
        self.per_voxel_attr_lst = [
            'octpath', 'octlevel',
            '_subdiv_p',
        ]
        self.per_voxel_param_lst = [
            '_sh0', '_shs',
        ]
        self.grid_pts_param_lst = [
            '_geo_grid_pts',
        ]

        # To be init from model_init
        self.scene_center = None
        self.scene_extent = None
        self.inside_extent = None
        self.octpath = None
        self.octlevel = None
        self.active_sh_degree = sh_degree

        self._geo_grid_pts = None
        self._sh0 = None
        self._shs = None
        self._subdiv_p = None
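
The docstring above names two allocation paths. The short sketch below illustrates them; the keyword arguments passed to `model_init` are illustrative only, since its real signature lives in `src/sparse_voxel_gears/constructor.py`.

# Sketch only: argument names for model_init are placeholders.
from src.sparse_voxel_model import SparseVoxelModel

model = SparseVoxelModel(n_samp_per_vox=1, sh_degree=3, ss=1.5)

# Path 1: restore a previously saved model
model.model_load("output/scene")

# Path 2: heuristic initialization from the training data
# (keyword names are illustrative; see SVConstructor.model_init)
# model.model_init(bounding=bounding, cameras=tr_cams)
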
src/sparse_voxel_model_copy.py
ADDED
@@ -0,0 +1,67 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from src.sparse_voxel_gears.constructor import SVConstructor
from src.sparse_voxel_gears.properties import SVProperties
from src.sparse_voxel_gears.renderer import SVRenderer
from src.sparse_voxel_gears.adaptive import SVAdaptive
from src.sparse_voxel_gears.io import SVInOut
from src.sparse_voxel_gears.pooling import SVPooling


class SparseVoxelModel(SVConstructor, SVProperties, SVRenderer, SVAdaptive, SVInOut, SVPooling):

    def __init__(self,
                 n_samp_per_vox=1,        # Number of sampled points per visited voxel
                 sh_degree=3,             # Use 3 * (k+1)^2 params per voxel for view-dependent colors
                 ss=1.5,                  # Super-sampling rate for anti-aliasing
                 white_background=False,  # Assume white background
                 black_background=False,  # Assume black background
                 ):
        '''
        Setup of the model meta. At this point, no voxel is allocated.
        Use the following methods to allocate voxels and parameters.

        1. `model_load` defined in `src/sparse_voxel_gears/io.py`.
           Load a saved model from a given path.

        2. `model_init` defined in `src/sparse_voxel_gears/constructor.py`.
           Heuristically initialize the sparse grid layout and parameters from the training data.
        '''
        super().__init__()

        self.n_samp_per_vox = n_samp_per_vox
        self.max_sh_degree = sh_degree
        self.ss = ss
        self.white_background = white_background
        self.black_background = black_background

        # List the variable names
        self.per_voxel_attr_lst = [
            'octpath', 'octlevel',
            '_subdiv_p',
        ]
        self.per_voxel_param_lst = [
            '_sh0', '_shs',
        ]
        self.grid_pts_param_lst = [
            '_geo_grid_pts',
        ]

        # To be init from model_init
        self.scene_center = None
        self.scene_extent = None
        self.inside_extent = None
        self.octpath = None
        self.octlevel = None
        self.active_sh_degree = sh_degree

        self._geo_grid_pts = None
        self._sh0 = None
        self._shs = None
        self._subdiv_p = None
src/utils/__pycache__/activation_utils.cpython-39.pyc
ADDED
Binary file (2.16 kB)

src/utils/__pycache__/bounding_utils.cpython-39.pyc
ADDED
Binary file (3.05 kB)

src/utils/__pycache__/camera_utils.cpython-39.pyc
ADDED
Binary file (2.38 kB)

src/utils/__pycache__/colmap_utils.cpython-39.pyc
ADDED
Binary file (1.77 kB)

src/utils/__pycache__/fuser_utils.cpython-39.pyc
ADDED
Binary file (3.87 kB)

src/utils/__pycache__/image_utils.cpython-39.pyc
ADDED
Binary file (2.51 kB)

src/utils/__pycache__/loss_utils.cpython-39.pyc
ADDED
Binary file (8.78 kB)

src/utils/__pycache__/marching_cubes_utils.cpython-39.pyc
ADDED
Binary file (25.1 kB)

src/utils/__pycache__/mono_utils.cpython-39.pyc
ADDED
Binary file (4.88 kB)

src/utils/__pycache__/octree_utils.cpython-39.pyc
ADDED
Binary file (7.49 kB)

src/utils/__pycache__/system_utils.cpython-39.pyc
ADDED
Binary file (372 Bytes)
src/utils/activation_utils.py
ADDED
@@ -0,0 +1,49 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import torch
from svraster_cuda.meta import STEP_SZ_SCALE

def softplus(x):
    return torch.nn.functional.softplus(x)

def exp_linear_10(x):
    return torch.where(x > 1, x, torch.exp(x - 1))

def exp_linear_11(x):
    return torch.where(x > 1.1, x, torch.exp(0.909090909091 * x - 0.904689820196))

def exp_linear_20(x):
    return torch.where(x > 2.0, x, torch.exp(0.5 * x - 0.30685281944))

def softplus_inverse(y):
    return y + torch.log(-torch.expm1(-y))

def exp_linear_10_inverse(y):
    return torch.where(y > 1, y, torch.log(y) + 1)

def exp_linear_11_inverse(y):
    return torch.where(y > 1.1, y, (torch.log(y) + 0.904689820196) / 0.909090909091)

def exp_linear_20_inverse(y):
    return torch.where(y > 2.0, y, (torch.log(y) + 0.30685281944) / 0.5)

def smooth_clamp_max(x, max_val):
    return max_val - torch.nn.functional.softplus(max_val - x)

def density2alpha(density, interval):
    return 1 - torch.exp(-STEP_SZ_SCALE * interval * density)

def alpha2density(alpha, interval):
    return torch.log(1 - alpha) / (-STEP_SZ_SCALE * interval)

def rgb2shzero(x):
    return (x - 0.5) / 0.28209479177387814

def shzero2rgb(x):
    return x * 0.28209479177387814 + 0.5
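
The density/alpha conversions above are exact inverses of each other for a fixed step `interval`. A small self-check sketch, assuming `svraster_cuda` is installed so that `STEP_SZ_SCALE` can be imported; the numbers are arbitrary:

# Round-trip check of density2alpha / alpha2density defined above.
import torch
from src.utils.activation_utils import density2alpha, alpha2density

density = torch.tensor([0.1, 1.0, 10.0])
interval = 0.01                                # ray step length inside a voxel (illustrative)

alpha = density2alpha(density, interval)       # 1 - exp(-STEP_SZ_SCALE * interval * density)
recovered = alpha2density(alpha, interval)     # inverts the mapping above

assert torch.allclose(recovered, density, atol=1e-5)
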
src/utils/bounding_utils.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np


def decide_main_bounding(bound_mode="default",
                         forward_dist_scale=1.0,   # For "forward" mode
                         pcd_density_rate=0.1,     # For "pcd" mode
                         bound_scale=1.0,          # Scaling of the bounding
                         tr_cams=None,             # Cameras
                         pcd=None,                 # Point cloud
                         suggested_bounding=None):
    if bound_mode == "default" and suggested_bounding is not None:
        print("Use suggested bounding")
        center = suggested_bounding.mean(0)
        radius = (suggested_bounding[1] - suggested_bounding[0]) * 0.5
    elif bound_mode in ["camera_max", "camera_median"]:
        center, radius = main_scene_bound_camera_heuristic(
            cams=tr_cams, bound_mode=bound_mode)
    elif bound_mode == "forward":
        center, radius = main_scene_bound_forward_heuristic(
            cams=tr_cams, forward_dist_scale=forward_dist_scale)
    elif bound_mode == "pcd":
        center, radius = main_scene_bound_pcd_heuristic(
            pcd=pcd, pcd_density_rate=pcd_density_rate)
    elif bound_mode == "default":
        cam_lookats = np.stack([cam.lookat.tolist() for cam in tr_cams])
        lookat_dots = (cam_lookats[:,None] * cam_lookats).sum(-1)
        is_forward_facing = lookat_dots.min() > 0

        if is_forward_facing:
            center, radius = main_scene_bound_forward_heuristic(
                cams=tr_cams, forward_dist_scale=forward_dist_scale)
        else:
            center, radius = main_scene_bound_camera_heuristic(
                cams=tr_cams, bound_mode="camera_median")
    else:
        raise NotImplementedError

    radius = radius * bound_scale

    bounding = np.array([
        center - radius,
        center + radius,
    ], dtype=np.float32)
    return bounding


def main_scene_bound_camera_heuristic(cams, bound_mode):
    print("Heuristic bounding:", bound_mode)
    cam_positions = np.stack([cam.position.tolist() for cam in cams])
    center = cam_positions.mean(0)
    dists = np.linalg.norm(cam_positions - center, axis=1)
    if bound_mode == "camera_max":
        radius = np.max(dists)
    elif bound_mode == "camera_median":
        radius = np.median(dists)
    else:
        raise NotImplementedError
    return center, radius


def main_scene_bound_forward_heuristic(cams, forward_dist_scale):
    print("Heuristic bounding: forward")
    positions = np.stack([cam.position.tolist() for cam in cams])
    cam_center = positions.mean(0)
    cam_lookat = np.stack([cam.lookat.tolist() for cam in cams]).mean(0)
    cam_lookat /= np.linalg.norm(cam_lookat)
    cam_extent = 2 * np.linalg.norm(positions - cam_center, axis=1).max()

    center = cam_center + forward_dist_scale * cam_extent * cam_lookat
    radius = 0.8 * forward_dist_scale * cam_extent

    return center, radius


def main_scene_bound_pcd_heuristic(pcd, pcd_density_rate):
    print("Heuristic bounding: pcd")
    center = np.median(pcd.points, axis=0)
    dist = np.abs(pcd.points - center).max(axis=1)
    dist = np.sort(dist)
    density = (1 + np.arange(len(dist))) * (dist > 0) / ((2 * dist) ** 3 + 1e-6)

    # Should cover at least 5% of the points
    begin_idx = round(len(density) * 0.05)

    # Find the radius with maximum point density
    max_idx = begin_idx + density[begin_idx:].argmax()

    # Find the smallest radius with point density equal to pcd_density_rate of the maximum
    target_density = pcd_density_rate * density[max_idx]
    target_idx = max_idx + np.where(density[max_idx:] < target_density)[0][0]

    radius = dist[target_idx]

    return center, radius
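
A toy example of the camera-based heuristic above, using stand-in objects that only expose the `.position` and `.lookat` attributes the helpers read; the real `Camera` class in `src/cameras.py` provides these and more:

# Toy example with stand-in cameras on a unit circle looking at the origin.
import numpy as np
from types import SimpleNamespace
from src.utils.bounding_utils import decide_main_bounding

cams = [
    SimpleNamespace(position=np.array([np.cos(t), 0.0, np.sin(t)]),
                    lookat=np.array([-np.cos(t), 0.0, -np.sin(t)]))
    for t in np.linspace(0, 2 * np.pi, 8, endpoint=False)
]

bounding = decide_main_bounding(bound_mode="camera_median", tr_cams=cams)
print(bounding)   # 2x3 array: [center - radius, center + radius]
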
src/utils/camera_utils.py
ADDED
@@ -0,0 +1,79 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
from scipy.interpolate import make_interp_spline


def fov2focal(fov, pixels):
    return pixels / (2 * np.tan(0.5 * fov))

def focal2fov(focal, pixels):
    return 2 * np.arctan(pixels / (2 * focal))


def interpolate_poses(poses, n_frame, periodic=True):

    assert len(poses) > 1

    poses = list(poses)
    bc_type = None

    if periodic:
        poses.append(poses[0])
        bc_type = "periodic"

    pos_lst = np.stack([pose[:3, 3] for pose in poses])
    lookat_lst = np.stack([pose[:3, 2] for pose in poses])
    right_lst = np.stack([pose[:3, 0] for pose in poses])

    ts = np.linspace(0, 1, len(poses))
    pos_interp_f = make_interp_spline(ts, pos_lst, bc_type=bc_type)
    lookat_interp_f = make_interp_spline(ts, lookat_lst, bc_type=bc_type)
    right_interp_f = make_interp_spline(ts, right_lst, bc_type=bc_type)

    samps = np.linspace(0, 1, n_frame+1)[:n_frame]
    pos_video = pos_interp_f(samps)
    lookat_video = lookat_interp_f(samps)
    right_video = right_interp_f(samps)
    interp_poses = []
    for i in range(n_frame):
        pos = pos_video[i]
        lookat = lookat_video[i] / np.linalg.norm(lookat_video[i])
        right_ = right_video[i] / np.linalg.norm(right_video[i])
        down = np.cross(lookat, right_)
        right = np.cross(down, lookat)
        c2w = np.eye(4, dtype=np.float32)
        c2w[:3, 0] = right
        c2w[:3, 1] = down
        c2w[:3, 2] = lookat
        c2w[:3, 3] = pos
        interp_poses.append(c2w)

    return interp_poses


def gen_circular_poses(radius,
                       n_frame,
                       starting=1.5 * np.pi,  # Starting from -z
                       ):
    poses = []
    for rad in np.linspace(starting, starting + 2 * np.pi, n_frame):
        pos = radius * np.array([np.cos(rad), 0, np.sin(rad)])
        lookat = -pos / np.linalg.norm(pos)
        down = np.array([0, 1, 0])
        right = np.cross(down, lookat)
        right = right / np.linalg.norm(right)
        down = np.cross(lookat, right)
        c2w = np.eye(4, dtype=np.float32)
        c2w[:3, 0] = right
        c2w[:3, 1] = down
        c2w[:3, 2] = lookat
        c2w[:3, 3] = pos
        poses.append(c2w)
    return poses
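
Typical usage of the two helpers above is to generate a few keyframe poses and then densify them into a smooth fly-through; the frame counts below are arbitrary:

# Sketch: 8 keyframes on a circle, interpolated into a 120-frame periodic path.
from src.utils.camera_utils import gen_circular_poses, interpolate_poses

key_poses = gen_circular_poses(radius=3.0, n_frame=8)
video_poses = interpolate_poses(key_poses, n_frame=120)   # list of 4x4 c2w matrices

assert video_poses[0].shape == (4, 4)
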
src/utils/colmap_utils.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import pycolmap
import numpy as np

from typing import NamedTuple


class PointCloud(NamedTuple):
    points: np.array
    colors: np.array
    errors: np.array
    corr: dict


def parse_colmap_pts(sfm: pycolmap.Reconstruction, transform: np.array = None):
    """
    Parse COLMAP points and correspondences.

    Input:
        @sfm        Reconstruction from COLMAP.
        @transform  3x3 matrix to transform xyz.
    Output:
        @xyz   Nx3 point positions.
        @rgb   Nx3 point colors.
        @err   N errors.
        @corr  Dictionary from file name to point indices.
    """

    xyz = []
    rgb = []
    err = []
    points_id = []
    for k, v in sfm.points3D.items():
        points_id.append(k)
        xyz.append(v.xyz)
        rgb.append(v.color)
        err.append(v.error)
        if transform is not None:
            xyz[-1] = transform @ xyz[-1]

    xyz = np.array(xyz)
    rgb = np.array(rgb)
    err = np.array(err)
    points_id = np.array(points_id)

    points_idmap = np.full([points_id.max()+2], -1, dtype=np.int64)
    points_idmap[points_id] = np.arange(len(xyz))

    corr = {}
    for image in sfm.images.values():
        idx = np.array([p.point3D_id for p in image.points2D if p.has_point3D()])
        corr[image.name] = points_idmap[idx]
        assert corr[image.name].min() >= 0 and corr[image.name].max() < len(xyz)

    return PointCloud(points=xyz, colors=rgb, errors=err, corr=corr)
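
A short usage sketch for `parse_colmap_pts`; the reconstruction path is a placeholder for wherever a COLMAP `sparse/0` model is stored:

# Sketch: load a COLMAP reconstruction and parse its sparse points.
import pycolmap
from src.utils.colmap_utils import parse_colmap_pts

sfm = pycolmap.Reconstruction("data/scene/sparse/0")
pcd = parse_colmap_pts(sfm)

print(pcd.points.shape, pcd.colors.shape)   # (N, 3) positions and colors
print(len(pcd.corr))                        # filename -> point-index correspondences
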
src/utils/fuser_utils.py
ADDED
@@ -0,0 +1,185 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

'''
Reference: KinectFusion algorithm.
'''

import numpy as np

import torch


class Fuser:
    def __init__(self,
                 xyz,
                 bandwidth,
                 use_trunc=True,
                 fuse_tsdf=True,
                 feat_dim=0,
                 alpha_thres=0.5,
                 crop_border=0.0,
                 normal_weight=False,
                 depth_weight=False,
                 border_weight=False,
                 max_norm_dist=10.,
                 use_half=False):
        assert len(xyz.shape) == 2
        assert xyz.shape[1] == 3
        self.xyz = xyz
        self.bandwidth = bandwidth
        self.use_trunc = use_trunc
        self.fuse_tsdf = fuse_tsdf
        self.feat_dim = feat_dim
        self.alpha_thres = alpha_thres
        self.crop_border = crop_border
        self.normal_weight = normal_weight
        self.depth_weight = depth_weight
        self.border_weight = border_weight
        self.max_norm_dist = max_norm_dist

        self.dtype = torch.float16 if use_half else torch.float32
        self.weight = torch.zeros([len(xyz), 1], dtype=self.dtype, device="cuda")
        self.feat = torch.zeros([len(xyz), feat_dim], dtype=self.dtype, device="cuda")
        if self.fuse_tsdf:
            self.sd_val = torch.zeros([len(xyz), 1], dtype=self.dtype, device="cuda")
        else:
            self.sd_val = None

    def integrate(self, cam, depth, feat=None, alpha=None):
        # Project grid points to image
        xyz_uv = cam.project(self.xyz)
        xyz_front = ((self.xyz - cam.position) @ cam.lookat) > cam.near

        # Filter points projected outside
        filter_idx = torch.where((xyz_uv.abs() <= 1-self.crop_border).all(-1) & xyz_front)[0]
        valid_idx = filter_idx
        valid_xyz = self.xyz[filter_idx]
        valid_uv = xyz_uv[filter_idx]

        # Compute projective sdf
        valid_frame_depth = torch.nn.functional.grid_sample(
            depth.view(1,1,*depth.shape[-2:]),
            valid_uv.view(1,1,-1,2),
            mode='bilinear',
            align_corners=False).flatten()
        valid_xyz_depth = (valid_xyz - cam.position) @ cam.lookat
        valid_sdf = valid_frame_depth - valid_xyz_depth

        if torch.is_tensor(self.bandwidth):
            bandwidth = self.bandwidth[valid_idx]
        else:
            bandwidth = self.bandwidth

        valid_sdf *= (1 / bandwidth)

        if self.use_trunc:
            # Filter occluded
            filter_idx = torch.where(valid_sdf >= -1)[0]
            valid_idx = valid_idx[filter_idx]
            valid_uv = valid_uv[filter_idx]
            valid_frame_depth = valid_frame_depth[filter_idx]
            valid_sdf = valid_sdf[filter_idx]
            valid_sdf = valid_sdf.clamp_(-1, 1)

            # Init weighting
            w = torch.ones_like(valid_frame_depth)
        else:
            norm_dist = valid_sdf.abs()
            w = torch.exp(-norm_dist.clamp_max(self.max_norm_dist))

        # Alpha filtering
        if alpha is not None:
            valid_alpha = torch.nn.functional.grid_sample(
                alpha.view(1,1,*alpha.shape[-2:]),
                valid_uv.view(1,1,-1,2),
                mode='bilinear',
                align_corners=False).flatten()
            w *= valid_alpha

            filter_idx = torch.where(valid_alpha >= self.alpha_thres)[0]
            valid_idx = valid_idx[filter_idx]
            valid_uv = valid_uv[filter_idx]
            valid_frame_depth = valid_frame_depth[filter_idx]
            valid_sdf = valid_sdf[filter_idx]
            w = w[filter_idx]

        # Compute geometric weighting
        if self.depth_weight:
            w *= 1 / valid_frame_depth.clamp_min(0.1)

        if self.normal_weight:
            normal = cam.depth2normal(depth)
            rd = torch.nn.functional.normalize(cam.depth2pts(depth) - cam.position.view(3,1,1), dim=0)
            cos_theta = (normal * rd).sum(0).clamp_min(0)
            valid_cos_theta = torch.nn.functional.grid_sample(
                cos_theta.view(1,1,*cos_theta.shape[-2:]),
                valid_uv.view(1,1,-1,2),
                mode='bilinear',
                align_corners=False).flatten()
            w *= valid_cos_theta

        if self.border_weight:
            # The image center gets 1.0; corners get 0.1
            w *= 1 / (1 + 9/np.sqrt(2) * valid_uv.square().sum(1).sqrt())

        # Reshape integration weight
        w = w.unsqueeze(-1).to(self.dtype)

        # Integrate weight
        self.weight[valid_idx] += w

        # Integrate tsdf
        if self.fuse_tsdf:
            valid_sdf = valid_sdf.unsqueeze(-1).to(self.dtype)
            self.sd_val[valid_idx] += w * valid_sdf

        # Sample feature
        if self.feat_dim > 0:
            valid_feat = torch.nn.functional.grid_sample(
                feat.view(1,self.feat_dim,*feat.shape[-2:]).to(self.dtype),
                valid_uv.view(1,1,-1,2).to(self.dtype),
                mode='bilinear',
                align_corners=False)[0,:,0].T
            self.feat[valid_idx] += w * valid_feat

    @property
    def feature(self):
        return self.feat / self.weight

    @property
    def tsdf(self):
        return self.sd_val / self.weight


@torch.no_grad()
def rgb_fusion(voxel_model, cameras):

    from .octree_utils import level_2_vox_size

    # Define volume integrator
    finest_vox_size = level_2_vox_size(voxel_model.scene_extent, voxel_model.octlevel.max()).item()
    feat_volume = Fuser(
        xyz=voxel_model.vox_center,
        bandwidth=10 * finest_vox_size,
        use_trunc=False,
        fuse_tsdf=False,
        feat_dim=3,
        crop_border=0.,
        normal_weight=False,
        depth_weight=False,
        border_weight=False,
        use_half=True)

    # Run RGB map fusion over all cameras
    for cam in cameras:
        render_pkg = voxel_model.render(cam, color_mode="dontcare", output_depth=True)
        depth = render_pkg['depth'][2]
        feat_volume.integrate(cam=cam, feat=cam.image.cuda(), depth=depth)

    return feat_volume.feature.nan_to_num_(0.5).float()
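
`rgb_fusion` above bakes per-voxel colors by back-projecting the training images along the rendered depth. A minimal usage sketch, where `voxel_model` and `tr_cams` are hypothetical variables assumed to come from the training pipeline in this repo:

# Sketch only: voxel_model is a SparseVoxelModel with allocated voxels and
# tr_cams is a list of training cameras providing .project, .position,
# .lookat, .near, and .image (both variable names are placeholders).
from src.utils.fuser_utils import rgb_fusion
from src.utils.activation_utils import rgb2shzero

vox_rgb = rgb_fusion(voxel_model, tr_cams)   # [num_voxels, 3] fused, view-averaged colors
sh0_init = rgb2shzero(vox_rgb)               # e.g. convert to SH DC terms for initialization
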