Spaces:
Configuration error
Configuration error
| import torch | |
| import torch.nn as nn | |
| import models | |
| from models.utils import get_activation | |
| from models.network_utils import get_encoding, get_mlp | |
| from systems.utils import update_module_step | |
class VolumeRadiance(nn.Module):
    """MLP color head mapping per-sample features to a 3-channel color.

    The MLP input is the flattened ``features`` tensor. When ``with_viewdir``
    is enabled, an encoding of the view direction is appended to it. Any extra
    tensors passed positionally to ``forward`` are appended after that.

    Config keys read here:
        input_feature_dim: MLP input width excluding the direction encoding.
            It must equal the last-axis width of ``features`` plus the widths
            of any extra ``*args`` tensors given to ``forward``.
        mlp_network_config: forwarded to ``get_mlp``.
        with_viewdir (optional, default False): append an encoded view
            direction to the MLP input.
        n_dir_dims (optional, default 3): last-axis width of ``dirs``.
        dir_encoding_config: forwarded to ``get_encoding``. Only read when
            ``with_viewdir`` is true.
        color_activation (optional): name passed to ``get_activation``. The
            resulting activation is applied to the output.
    """

    def __init__(self, config):
        super(VolumeRadiance, self).__init__()
        self.config = config
        # Previously hard-coded to False. It is now configurable, and the
        # default keeps the old behavior for configs that do not set the key.
        self.with_viewdir = self.config.get('with_viewdir', False)
        self.n_dir_dims = self.config.get('n_dir_dims', 3)
        self.n_output_dims = 3
        if self.with_viewdir:
            encoding = get_encoding(self.n_dir_dims, self.config.dir_encoding_config)
            self.n_input_dims = self.config.input_feature_dim + encoding.n_output_dims
        else:
            encoding = None
            self.n_input_dims = self.config.input_feature_dim
        # Keep registration order (encoding, then network) unchanged so that
        # state_dict / parameter ordering matches previously saved checkpoints.
        self.encoding = encoding
        self.network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config)

    def forward(self, features, dirs, *args):
        """Predict colors for every sample in ``features``.

        Args:
            features: tensor whose last axis holds per-sample features. Its
                leading axes are flattened for the MLP and restored on output.
            dirs: view directions. They are remapped from (-1, 1) to (0, 1)
                before encoding. They are ignored unless ``with_viewdir`` is
                true.
            *args: extra tensors. Each is flattened to ``(-1, last_dim)`` and
                appended to the MLP input, so each must flatten to the same
                number of rows as ``features``.

        Returns:
            A float32 tensor of shape ``features.shape[:-1] + (3,)``. The
            configured ``color_activation`` is applied if one is set.
        """
        inputs = [features.view(-1, features.shape[-1])]
        if self.with_viewdir:
            dirs = (dirs + 1.) / 2.  # (-1, 1) => (0, 1)
            inputs.append(self.encoding(dirs.view(-1, self.n_dir_dims)))
        inputs.extend(arg.view(-1, arg.shape[-1]) for arg in args)
        network_inp = torch.cat(inputs, dim=-1)
        color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float()
        if 'color_activation' in self.config:
            color = get_activation(self.config.color_activation)(color)
        return color

    def update_step(self, epoch, global_step):
        """Forward the training step to the direction encoding, if there is one."""
        # With view directions disabled there is no encoding to update.
        if self.encoding is not None:
            update_module_step(self.encoding, epoch, global_step)

    def regularizations(self, out):
        """Return this head's regularization terms. It has none, so this is always ``{}``."""
        return {}
class VolumeColor(nn.Module):
    """Direction-independent color head: features -> MLP -> 3 color channels.

    Config keys read here:
        input_feature_dim: last-axis width of ``features``. Used as the MLP
            input size.
        mlp_network_config: forwarded to ``get_mlp``.
        color_activation (optional): name passed to ``get_activation``. The
            resulting activation is applied to the output.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_output_dims = 3
        self.n_input_dims = self.config.input_feature_dim
        self.network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config)

    def forward(self, features, *args):
        """Predict colors from ``features``.

        Any extra positional arguments are accepted and ignored.

        Returns:
            A float32 tensor of shape ``features.shape[:-1] + (3,)``. The
            configured ``color_activation`` is applied if one is set.
        """
        feat_dim = features.shape[-1]
        lead_shape = features.shape[:-1]
        raw = self.network(features.view(-1, feat_dim))
        color = raw.view(*lead_shape, self.n_output_dims).float()
        if 'color_activation' not in self.config:
            return color
        activate = get_activation(self.config.color_activation)
        return activate(color)

    def regularizations(self, out):
        """Return this head's regularization terms. It has none, so this is always ``{}``."""
        return {}