Spaces:
Configuration error
Configuration error
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import models | |
| from models.base import BaseModel | |
| from models.utils import chunk_batch | |
| from systems.utils import update_module_step | |
| from nerfacc import ( | |
| ContractionType, | |
| OccupancyGrid, | |
| ray_marching, | |
| render_weight_from_density, | |
| render_weight_from_alpha, | |
| accumulate_along_rays, | |
| ) | |
| from nerfacc.intersection import ray_aabb_intersect | |
| import pdb | |
class VarianceNetwork(nn.Module):
    """Learnable global inverse standard deviation (``inv_s``) used by NeuS.

    A single scalar parameter ``variance`` is mapped to
    ``inv_s = exp(10 * variance)``.

    When ``config.modulate`` is true, once ``global_step`` exceeds
    ``config.mod_start_steps`` the value is clamped from above by ``mod_val``.
    ``mod_val`` ramps linearly in ``global_step`` from the last unclamped value
    towards ``config.max_inv_s``, reaching it at ``config.reach_max_steps``.
    """

    def __init__(self, config):
        super(VarianceNetwork, self).__init__()
        self.config = config
        self.init_val = self.config.init_val
        # float(): an integer init_val would otherwise produce an integer tensor,
        # and nn.Parameter cannot require gradients on integer dtypes.
        self.register_parameter(
            "variance", nn.Parameter(torch.tensor(float(self.config.init_val)))
        )
        self.modulate = self.config.get("modulate", False)
        # Defaults so that `inv_s` is readable before the first update_step()
        # call, and so that a first call already past mod_start_steps
        # (e.g. after resuming training) does not crash.
        self.do_mod = False
        self.prev_inv_s = None
        self.mod_val = None
        if self.modulate:
            self.mod_start_steps = self.config.mod_start_steps
            self.reach_max_steps = self.config.reach_max_steps
            self.max_inv_s = self.config.max_inv_s

    def _raw_inv_s(self):
        """Return ``exp(10 * variance)`` without any modulation clamp."""
        return torch.exp(self.variance * 10.0)

    @property
    def inv_s(self):
        """Current inverse standard deviation as a 0-dim tensor.

        Clamped to ``mod_val`` only while modulation is active.
        """
        val = self._raw_inv_s()
        if self.modulate and self.do_mod:
            val = val.clamp_max(self.mod_val)
        return val

    def forward(self, x):
        """Return ``inv_s`` broadcast to shape ``(len(x), 1)`` on the parameter's device.

        Only ``len(x)`` is used; the contents of ``x`` are ignored.
        """
        return torch.ones([len(x), 1], device=self.variance.device) * self.inv_s

    def update_step(self, epoch, global_step):
        """Advance the modulation schedule; a no-op when modulation is disabled.

        Args:
            epoch: unused, kept for the common ``update_step`` interface.
            global_step: current optimisation step.
        """
        if not self.modulate:
            return
        self.do_mod = global_step > self.mod_start_steps
        if not self.do_mod:
            # Track the unclamped value; the ramp below starts from it.
            self.prev_inv_s = self._raw_inv_s().item()
            return
        if self.prev_inv_s is None:
            # First call is already past mod_start_steps (e.g. resumed run):
            # fall back to the current unclamped value instead of crashing.
            self.prev_inv_s = self._raw_inv_s().item()
        self.mod_val = min(
            (global_step / self.reach_max_steps)
            * (self.max_inv_s - self.prev_inv_s)
            + self.prev_inv_s,
            self.max_inv_s,
        )
class NeuSModel(BaseModel):
    """NeuS-style implicit surface model rendered with nerfacc.

    The foreground consists of a signed-distance network (``self.geometry``)
    and a colour network (``self.texture``). They are evaluated inside the
    axis-aligned box ``[-radius, radius]^3``, and SDF samples are converted to
    per-sample opacities with the NeuS estimator in :meth:`get_alpha`.

    When ``config.learned_background`` is true, rays are also marched through
    a density-based background model (:meth:`forward_bg_`) using an
    unbounded-sphere contraction. The foreground is then composited over it.
    """

    def setup(self):
        """Build sub-networks, the scene AABB, occupancy grids and step sizes."""
        self.geometry = models.make(self.config.geometry.name, self.config.geometry)
        self.texture = models.make(self.config.texture.name, self.config.texture)
        self.geometry.contraction_type = ContractionType.AABB

        if self.config.learned_background:
            self.geometry_bg = models.make(
                self.config.geometry_bg.name, self.config.geometry_bg
            )
            self.texture_bg = models.make(
                self.config.texture_bg.name, self.config.texture_bg
            )
            self.geometry_bg.contraction_type = ContractionType.UN_BOUNDED_SPHERE
            self.near_plane_bg, self.far_plane_bg = 0.1, 1e3
            # Chosen so that (1 + cone_angle_bg) ** num_samples_per_ray_bg
            # equals far_plane_bg.
            self.cone_angle_bg = (
                10
                ** (math.log10(self.far_plane_bg) / self.config.num_samples_per_ray_bg)
                - 1.0
            )
            self.render_step_size_bg = 0.01

        self.variance = VarianceNetwork(self.config.variance)
        radius = self.config.radius
        self.register_buffer(
            "scene_aabb",
            torch.as_tensor(
                [-radius, -radius, -radius, radius, radius, radius],
                dtype=torch.float32,
            ),
        )
        if self.config.grid_prune:
            self.occupancy_grid = OccupancyGrid(
                roi_aabb=self.scene_aabb,
                resolution=128,
                contraction_type=ContractionType.AABB,
            )
            if self.config.learned_background:
                self.occupancy_grid_bg = OccupancyGrid(
                    roi_aabb=self.scene_aabb,
                    resolution=256,
                    contraction_type=ContractionType.UN_BOUNDED_SPHERE,
                )
        self.randomized = self.config.randomized
        self.background_color = None
        # Box diagonal (~sqrt(3) * 2 * radius) split into num_samples_per_ray steps.
        self.render_step_size = (
            1.732 * 2 * self.config.radius / self.config.num_samples_per_ray
        )

    def update_step(self, epoch, global_step):
        """Per-step bookkeeping.

        Propagates the step to all sub-modules and updates the cos-anneal
        ratio. In training with ``grid_prune``, it also refreshes the
        occupancy grid(s).
        """
        update_module_step(self.geometry, epoch, global_step)
        update_module_step(self.texture, epoch, global_step)
        if self.config.learned_background:
            update_module_step(self.geometry_bg, epoch, global_step)
            update_module_step(self.texture_bg, epoch, global_step)
        update_module_step(self.variance, epoch, global_step)

        # Linear ramp 0 -> 1 over cos_anneal_end steps; 0 disables annealing.
        cos_anneal_end = self.config.get("cos_anneal_end", 0)
        self.cos_anneal_ratio = (
            1.0 if cos_anneal_end == 0 else min(1.0, global_step / cos_anneal_end)
        )

        def occ_eval_fn(x):
            # NeuS alpha over one render step, using a +/- half-step SDF offset
            # (the iter_cos == -1 case of get_alpha).
            sdf = self.geometry(x, with_grad=False, with_feature=False)
            inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)
            inv_s = inv_s.expand(sdf.shape[0], 1)
            estimated_next_sdf = sdf[..., None] - self.render_step_size * 0.5
            estimated_prev_sdf = sdf[..., None] + self.render_step_size * 0.5
            prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
            next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
            p = prev_cdf - next_cdf
            c = prev_cdf
            alpha = ((p + 1e-5) / (c + 1e-5)).view(-1, 1).clip(0.0, 1.0)
            return alpha

        def occ_eval_fn_bg(x):
            density, _ = self.geometry_bg(x)
            # First-order Taylor approximation of
            # 1 - exp(-density * render_step_size_bg).
            return density[..., None] * self.render_step_size_bg

        if self.training and self.config.grid_prune:
            self.occupancy_grid.every_n_step(
                step=global_step,
                occ_eval_fn=occ_eval_fn,
                occ_thre=self.config.get("grid_prune_occ_thre", 0.01),
            )
            if self.config.learned_background:
                self.occupancy_grid_bg.every_n_step(
                    step=global_step,
                    occ_eval_fn=occ_eval_fn_bg,
                    occ_thre=self.config.get("grid_prune_occ_thre_bg", 0.01),
                )

    def isosurface(self):
        """Return the mesh extracted by ``self.geometry.isosurface()``."""
        mesh = self.geometry.isosurface()
        return mesh

    def get_alpha(self, sdf, normal, dirs, dists):
        """NeuS discrete opacity for each sample.

        The SDF is extrapolated half an interval backwards and forwards along
        the ray using the (annealed) cosine between ray direction and normal.
        The result is
        ``alpha = (Phi(prev) - Phi(next)) / Phi(prev)``, where
        ``Phi = sigmoid(inv_s * .)``, clipped to [0, 1].

        Args:
            sdf: SDF value per sample (indexed ``sdf[..., None]``, so
                presumably shape (N,) — confirm).
            normal: unit normals per sample.
            dirs: ray directions per sample.
            dists: interval length per sample (reshaped to (N, 1)).

        Returns:
            Flattened alpha tensor of shape (N,).
        """
        inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(
            1e-6, 1e6
        )  # Single parameter
        inv_s = inv_s.expand(sdf.shape[0], 1)

        true_cos = (dirs * normal).sum(-1, keepdim=True)

        # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
        # the cos value "not dead" at the beginning training iterations, for better convergence.
        iter_cos = -(
            F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio)
            + F.relu(-true_cos) * self.cos_anneal_ratio
        )  # always non-positive

        # Estimate signed distances at section points
        estimated_next_sdf = sdf[..., None] + iter_cos * dists.reshape(-1, 1) * 0.5
        estimated_prev_sdf = sdf[..., None] - iter_cos * dists.reshape(-1, 1) * 0.5

        prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
        next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)

        p = prev_cdf - next_cdf
        c = prev_cdf

        alpha = ((p + 1e-5) / (c + 1e-5)).view(-1).clip(0.0, 1.0)
        return alpha

    def forward_bg_(self, rays):
        """Render the learned background for a batch of rays.

        Marching starts where the ray leaves ``scene_aabb``, or at
        ``near_plane_bg`` for rays that miss it, and ends at ``far_plane_bg``.

        Args:
            rays: tensor whose columns ``[:, 0:3]`` are origins and
                ``[:, 3:6]`` are directions.

        Returns:
            dict with ``comp_rgb`` (composited over ``background_color``),
            ``opacity``, ``depth``, ``rays_valid`` and ``num_samples``. In
            training it also includes the flattened per-sample tensors.
        """
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)

        def sigma_fn(t_starts, t_ends, ray_indices):
            ray_indices = ray_indices.long()
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
            density, _ = self.geometry_bg(positions)
            return density[..., None]

        _, t_max = ray_aabb_intersect(rays_o, rays_d, self.scene_aabb)
        # If the ray intersects the bounding box, start from the farther
        # intersection point; otherwise start from self.near_plane_bg.
        # Note that nerfacc sets t_max to 1e10 when there is no intersection.
        near_plane = torch.where(t_max > 1e9, self.near_plane_bg, t_max)

        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o,
                rays_d,
                scene_aabb=None,
                grid=self.occupancy_grid_bg if self.config.grid_prune else None,
                sigma_fn=sigma_fn,
                near_plane=near_plane,
                far_plane=self.far_plane_bg,
                render_step_size=self.render_step_size_bg,
                stratified=self.randomized,
                cone_angle=self.cone_angle_bg,
                alpha_thre=0.0,
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.0
        positions = t_origins + t_dirs * midpoints
        intervals = t_ends - t_starts

        density, feature = self.geometry_bg(positions)
        rgb = self.texture_bg(feature, t_dirs)

        weights = render_weight_from_density(
            t_starts, t_ends, density[..., None], ray_indices=ray_indices, n_rays=n_rays
        )
        opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays)
        depth = accumulate_along_rays(
            weights, ray_indices, values=midpoints, n_rays=n_rays
        )
        comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays)
        comp_rgb = comp_rgb + self.background_color * (1.0 - opacity)

        out = {
            "comp_rgb": comp_rgb,
            "opacity": opacity,
            "depth": depth,
            "rays_valid": opacity > 0,
            "num_samples": torch.as_tensor(
                [len(t_starts)], dtype=torch.int32, device=rays.device
            ),
        }

        if self.training:
            out.update(
                {
                    "weights": weights.view(-1),
                    "points": midpoints.view(-1),
                    "intervals": intervals.view(-1),
                    "ray_indices": ray_indices.view(-1),
                }
            )

        return out

    def _sample_random_sdf(self, ref):
        """Evaluate the SDF at random points for the training regularizers.

        Draws 2048 points uniformly in [-1, 1)^3. ``ref`` only supplies the
        dtype and device. It then evaluates the SDF and its gradient there,
        plus the gradient at the same points jittered by Gaussian noise with
        std 1e-2. Both branches now use the same random points.
        NOTE(review): the sampling cube is independent of ``config.radius`` —
        confirm this is intended.

        Returns:
            ``(random_sdf, random_sdf_grad, normal_perturb)``.
        """
        pts_random = (
            torch.rand([1024 * 2, 3]).to(ref.dtype).to(ref.device) * 2 - 1
        )  # normalized to (-1, 1)
        if self.config.geometry.grad_type == "finite_difference":
            random_sdf, random_sdf_grad, _ = self.geometry(
                pts_random, with_grad=True, with_feature=False, with_laplace=True
            )
            _, normal_perturb, _ = self.geometry(
                pts_random + torch.randn_like(pts_random) * 1e-2,
                with_grad=True,
                with_feature=False,
                with_laplace=True,
            )
        else:
            random_sdf, random_sdf_grad = self.geometry(
                pts_random, with_grad=True, with_feature=False
            )
            # Previously this perturbed the ray-sample `positions`, which
            # mismatched the finite-difference branch and random_sdf_grad.
            _, normal_perturb = self.geometry(
                pts_random + torch.randn_like(pts_random) * 1e-2,
                with_grad=True,
                with_feature=False,
            )
        return random_sdf, random_sdf_grad, normal_perturb

    def forward_(self, rays):
        """Render the foreground SDF (and background) for one batch of rays.

        Args:
            rays: tensor whose columns ``[:, 0:3]`` are origins and
                ``[:, 3:6]`` are directions.

        Returns:
            A dict with three groups of keys:

            - Foreground outputs: ``comp_rgb``, ``comp_normal``, ``opacity``,
              ``depth``, ``rays_valid``, ``num_samples``. In training it adds
              the per-sample and random-point regularization tensors.
            - Background outputs, with an ``_bg`` suffix.
            - Composited foreground-over-background outputs, with a ``_full``
              suffix.
        """
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)

        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o,
                rays_d,
                scene_aabb=self.scene_aabb,
                grid=self.occupancy_grid if self.config.grid_prune else None,
                alpha_fn=None,
                near_plane=None,
                far_plane=None,
                render_step_size=self.render_step_size,
                stratified=self.randomized,
                cone_angle=0.0,
                alpha_thre=0.0,
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.0
        positions = t_origins + t_dirs * midpoints
        dists = t_ends - t_starts

        if self.config.geometry.grad_type == "finite_difference":
            sdf, sdf_grad, feature, sdf_laplace = self.geometry(
                positions, with_grad=True, with_feature=True, with_laplace=True
            )
        else:
            sdf, sdf_grad, feature = self.geometry(
                positions, with_grad=True, with_feature=True
            )
        normal = F.normalize(sdf_grad, p=2, dim=-1)
        alpha = self.get_alpha(sdf, normal, t_dirs, dists)[..., None]
        rgb = self.texture(feature, t_dirs, normal)

        weights = render_weight_from_alpha(alpha, ray_indices=ray_indices, n_rays=n_rays)
        opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays)
        depth = accumulate_along_rays(
            weights, ray_indices, values=midpoints, n_rays=n_rays
        )
        comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays)
        comp_normal = accumulate_along_rays(
            weights, ray_indices, values=normal, n_rays=n_rays
        )
        comp_normal = F.normalize(comp_normal, p=2, dim=-1)

        out = {
            "comp_rgb": comp_rgb,
            "comp_normal": comp_normal,
            "opacity": opacity,
            "depth": depth,
            "rays_valid": opacity > 0,
            "num_samples": torch.as_tensor(
                [len(t_starts)], dtype=torch.int32, device=rays.device
            ),
        }

        if self.training:
            # Random-point regularization terms are only returned in training,
            # so the extra geometry passes are skipped at evaluation time.
            random_sdf, random_sdf_grad, normal_perturb = self._sample_random_sdf(sdf)
            out.update(
                {
                    "sdf_samples": sdf,
                    "sdf_grad_samples": sdf_grad,
                    "random_sdf": random_sdf,
                    "random_sdf_grad": random_sdf_grad,
                    "normal_perturb": normal_perturb,
                    "weights": weights.view(-1),
                    "points": midpoints.view(-1),
                    "intervals": dists.view(-1),
                    "ray_indices": ray_indices.view(-1),
                }
            )
            if self.config.geometry.grad_type == "finite_difference":
                out.update({"sdf_laplace_samples": sdf_laplace})

        if self.config.learned_background:
            out_bg = self.forward_bg_(rays)
        else:
            out_bg = {
                "comp_rgb": self.background_color[None, :].expand(*comp_rgb.shape),
                "num_samples": torch.zeros_like(out["num_samples"]),
                "rays_valid": torch.zeros_like(out["rays_valid"]),
            }

        out_full = {
            "comp_rgb": out["comp_rgb"] + out_bg["comp_rgb"] * (1.0 - out["opacity"]),
            "num_samples": out["num_samples"] + out_bg["num_samples"],
            "rays_valid": out["rays_valid"] | out_bg["rays_valid"],
        }

        return {
            **out,
            **{k + "_bg": v for k, v in out_bg.items()},
            **{k + "_full": v for k, v in out_full.items()},
        }

    def forward(self, rays):
        """Render ``rays``; in eval mode, process them in ``config.ray_chunk`` chunks.

        Also returns the variance module's current ``inv_s`` under the key
        ``"inv_s"``.
        """
        if self.training:
            out = self.forward_(rays)
        else:
            out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays)
        return {**out, "inv_s": self.variance.inv_s}

    def train(self, mode=True):
        """Toggle training mode; stratified sampling follows ``mode`` and the config."""
        self.randomized = mode and self.config.randomized
        return super().train(mode=mode)

    def eval(self):
        """Switch to eval mode with deterministic (non-stratified) sampling."""
        self.randomized = False
        return super().eval()

    def regularizations(self, out):
        """Merge the regularization losses reported by geometry and texture."""
        losses = {}
        losses.update(self.geometry.regularizations(out))
        losses.update(self.texture.regularizations(out))
        return losses

    def export(self, export_config):
        """Extract the isosurface mesh, optionally with per-vertex colours.

        Vertex colours are queried with the view direction set to ``-normal``,
        which this code uses as an approximate "albedo".
        """
        mesh = self.isosurface()
        if export_config.export_vertex_color:
            _, sdf_grad, feature = chunk_batch(
                self.geometry,
                export_config.chunk_size,
                False,
                mesh["v_pos"].to(self.rank),
                with_grad=True,
                with_feature=True,
            )
            normal = F.normalize(sdf_grad, p=2, dim=-1)
            rgb = self.texture(
                feature, -normal, normal
            )  # set the viewing directions to the normal to get "albedo"
            mesh["v_rgb"] = rgb.cpu()
        return mesh