Spaces:
Build error
Build error
| # coding=utf-8 | |
| # Copyright 2021 The Google Research Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # Lint as: python3 | |
| """Different datasets implementation plus a general port for all the datasets.""" | |
# Build-mode switch. No active code in this file reads it; it was used by the
# commented-out `if not INTERNAL:` guard around the cv2 import below.
INTERNAL = False # pylint: disable=g-statement-before-imports
| import json | |
| import os | |
| from os import path | |
| import queue | |
| import threading | |
| # if not INTERNAL: | |
| # import cv2 # pylint: disable=g-import-not-at-top | |
| import jax | |
| import numpy as np | |
| from PIL import Image | |
| from jaxnerf.nerf import utils | |
| from jaxnerf.nerf import clip_utils | |
def get_dataset(split, args, clip_model=None):
  """Instantiate the dataset class registered under `args.dataset`.

  Args:
    split: str, "train" or "test".
    args: flags object; `args.dataset` is the key looked up in `dataset_dict`.
    clip_model: optional CLIP model forwarded unchanged to the constructor.

  Returns:
    An instance of the selected Dataset subclass.
  """
  dataset_cls = dataset_dict[args.dataset]
  return dataset_cls(split, args, clip_model)
def convert_to_ndc(origins, directions, focal, w, h, near=1.):
  """Convert a set of rays to NDC coordinates.

  Each origin is first slid along its ray onto the plane z = -near. The
  forward-facing NDC projection is then applied to origins and directions.

  Args:
    origins: array [..., 3], ray origins.
    directions: array [..., 3], ray directions.
    focal: float, focal length in pixels.
    w: image width in pixels.
    h: image height in pixels.
    near: float, distance of the near plane.

  Returns:
    (origins, directions), both [..., 3], in NDC space.
  """
  # Advance every origin to the near plane.
  t_near = -(near + origins[..., 2]) / directions[..., 2]
  origins = origins + t_near[..., None] * directions

  ox, oy, oz = origins[..., 0], origins[..., 1], origins[..., 2]
  dx, dy, dz = directions[..., 0], directions[..., 1], directions[..., 2]

  x_scale = -((2 * focal) / w)
  y_scale = -((2 * focal) / h)

  ndc_origins = np.stack(
      [x_scale * (ox / oz),
       y_scale * (oy / oz),
       1 + 2 * near / oz], -1)
  ndc_directions = np.stack(
      [x_scale * (dx / dz - ox / oz),
       y_scale * (dy / dz - oy / oz),
       -2 * near / oz], -1)
  return ndc_origins, ndc_directions
class Dataset(threading.Thread):
  """Dataset Base Class.

  Data is loaded in the constructor. The object then runs as a daemon thread
  that keeps a bounded prefetch queue filled with training batches (split
  "train") or test examples (split "test"). Subclasses implement
  `_load_renderings(flags, clip_model=None)`, which must set `images`,
  `camtoworlds`, `focal`, `h`, `w`, `resolution` and `n_examples`.
  """

  def __init__(self, split, flags, clip_model):
    super(Dataset, self).__init__()
    self.queue = queue.Queue(3)  # Set prefetch buffer to 3 batches.
    self.daemon = True
    self.use_pixel_centers = flags.use_pixel_centers
    self.split = split
    if split == "train":
      self._train_init(flags, clip_model)
    elif split == "test":
      self._test_init(flags)
    else:
      raise ValueError(
          "the split argument should be either \"train\" or \"test\", set "
          "to {} here.".format(split))
    self.batch_size = flags.batch_size // jax.process_count()
    self.batching = flags.batching
    self.render_path = flags.render_path
    self.far = flags.far
    self.near = flags.near
    self.max_steps = flags.max_steps
    self.sc_loss_factor = flags.sc_loss_factor
    # Start prefetching only once every attribute the producer reads is set.
    self.start()

  def __iter__(self):
    return self

  def __next__(self):
    """Get the next training batch or test example.

    Returns:
      batch: dict, has "pixels" and "rays" (only "rays" when rendering a
        test render path). Training batches are sharded across devices,
        test examples are moved to device.
    """
    x = self.queue.get()
    if self.split == "train":
      return utils.shard(x)
    else:
      return utils.to_device(x)

  def peek(self):
    """Peek at the next training batch or test example without dequeuing it.

    Blocks until the prefetch thread has produced at least one item, so it is
    safe to call immediately after construction.

    Returns:
      batch: dict, has "pixels" and "rays".
    """
    # start() runs at the end of __init__, so the producer may not have queued
    # anything yet; indexing an empty queue would raise IndexError. Wait on
    # the queue's `not_empty` condition, which put() signals.
    with self.queue.not_empty:
      while not self.queue.queue:
        self.queue.not_empty.wait()
      x = self.queue.queue[0].copy()  # Make a copy of the front of the queue.
      # We may have absorbed a put() notification without dequeuing; pass it
      # on in case a get() is also waiting.
      self.queue.not_empty.notify()
    if self.split == "train":
      return utils.shard(x)
    else:
      return utils.to_device(x)

  def run(self):
    """Producer loop: keep putting batches into the queue (blocks when full)."""
    if self.split == "train":
      next_func = self._next_train
    else:
      next_func = self._next_test
    while True:
      self.queue.put(next_func())

  def size(self):
    """Return the number of examples (`n_examples`)."""
    return self.n_examples

  def _train_init(self, flags, clip_model):
    """Initialize training: load data, build rays, and lay them out for batching.

    "all_images" flattens every pixel of every image into one pool of rays;
    "single_image" keeps a per-image axis of `resolution` rays.

    Raises:
      NotImplementedError: for any other `flags.batching`.
    """
    self._load_renderings(flags, clip_model)
    self._generate_rays()
    if flags.batching == "all_images":
      # flatten the ray and image dimension together.
      self.images = self.images.reshape([-1, 3])
      self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
                                       self.rays)
    elif flags.batching == "single_image":
      self.images = self.images.reshape([-1, self.resolution, 3])
      self.rays = utils.namedtuple_map(
          lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
    else:
      raise NotImplementedError(
          f"{flags.batching} batching strategy is not implemented.")

  def _test_init(self, flags):
    """Initialize testing: load data, build rays, reset the example cursor."""
    self._load_renderings(flags, clip_model=None)
    self._generate_rays()
    self.it = 0

  def _next_train(self):
    """Sample next training batch.

    Only "single_image" batching is supported: one random image is chosen and
    `batch_size` random rays are drawn from it.

    Returns:
      dict with "pixels", "rays" and "image_index" (the chosen image).

    Raises:
      NotImplementedError: for "all_images" (no per-batch image index) and for
        any unknown batching strategy.
    """
    if self.batching == "all_images":
      raise NotImplementedError(
          "image_index not implemented for batching=all_images")
    elif self.batching == "single_image":
      image_index = np.random.randint(0, self.n_examples, ())
      ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
                                      (self.batch_size,))
      batch_pixels = self.images[image_index][ray_indices]
      batch_rays = utils.namedtuple_map(lambda r: r[image_index][ray_indices],
                                        self.rays)
    else:
      raise NotImplementedError(
          f"{self.batching} batching strategy is not implemented.")
    return {"pixels": batch_pixels, "rays": batch_rays,
            "image_index": image_index}

  def _next_test(self):
    """Sample next test example, cycling through all examples in order.

    With `render_path`, rays come from `render_rays`, which the subclass must
    provide (LLFF sets it for the test split).
    """
    idx = self.it
    self.it = (self.it + 1) % self.n_examples
    if self.render_path:
      return {"rays": utils.namedtuple_map(lambda r: r[idx], self.render_rays)}
    else:
      return {"pixels": self.images[idx],
              "rays": utils.namedtuple_map(lambda r: r[idx], self.rays),
              "image_index": idx}

  # TODO(bydeng): Swap this function with a more flexible camera model.
  def _generate_rays(self):
    """Generating rays for all images.

    Sets `self.rays` with origins/directions/viewdirs of shape
    [n_cameras, h, w, 3] computed from `self.camtoworlds` with a pinhole model.
    """
    pixel_center = 0.5 if self.use_pixel_centers else 0.0
    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
        np.arange(self.w, dtype=np.float32) + pixel_center,  # X-Axis (columns)
        np.arange(self.h, dtype=np.float32) + pixel_center,  # Y-Axis (rows)
        indexing="xy")
    camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
                            -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
                           axis=-1)
    # Rotate camera-space directions into world space for every camera.
    directions = ((camera_dirs[None, ..., None, :] *
                   self.camtoworlds[:, None, None, :3, :3]).sum(axis=-1))
    origins = np.broadcast_to(self.camtoworlds[:, None, None, :3, -1],
                              directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    self.rays = utils.Rays(
        origins=origins, directions=directions, viewdirs=viewdirs)

  def camtoworld_matrix_to_rays(self, camtoworld, downsample=1):
    """Render one set of rays for a single camera-to-world matrix.

    Args:
      camtoworld: array whose top 3 rows hold the rotation ([:3, :3]) and the
        camera position (last column); a (4, 4) or (3, 4) matrix both work.
      downsample: int, keep every `downsample`-th pixel along each axis.

    Returns:
      utils.Rays with arrays of shape
      [ceil(h / downsample), ceil(w / downsample), 3].
    """
    pixel_center = 0.5 if self.use_pixel_centers else 0.0
    # TODO @Alex: apply mesh downsampling here
    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
        np.arange(self.w, step=downsample, dtype=np.float32) + pixel_center,  # X-Axis (columns)
        np.arange(self.h, step=downsample, dtype=np.float32) + pixel_center,  # Y-Axis (rows)
        indexing="xy")
    camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
                            -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
                           axis=-1)
    directions = (camera_dirs[..., None, :] *
                  camtoworld[None, None, :3, :3]).sum(axis=-1)
    origins = np.broadcast_to(camtoworld[None, None, :3, -1], directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
class Blender(Dataset):
  """Blender Dataset."""

  def _load_renderings(self, flags, clip_model=None):
    """Load images from disk.

    Args:
      flags: config; reads render_path, data_dir, factor, few_shot,
        white_bkgd and use_semantic_loss.
      clip_model: optional CLIP model. When given and `flags.use_semantic_loss`
        is set, a CLIP image embedding is computed for every loaded image and
        stored in `self.embeddings`.

    Raises:
      ValueError: if `flags.render_path` is set.
    """
    if flags.render_path:
      raise ValueError("render_path cannot be used for the blender dataset.")
    cams, images, meta = self.load_files(
        flags.data_dir, self.split, flags.factor, flags.few_shot)
    self.images = np.stack(images, axis=0)
    if flags.white_bkgd:
      # Alpha-composite the RGB channels over a white background.
      self.images = (self.images[..., :3] * self.images[..., -1:] +
                     (1. - self.images[..., -1:]))
    else:
      self.images = self.images[..., :3]
    self.h, self.w = self.images.shape[1:3]
    self.resolution = self.h * self.w
    self.camtoworlds = np.stack(cams, axis=0)
    camera_angle_x = float(meta["camera_angle_x"])
    self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
    self.n_examples = self.images.shape[0]
    if flags.use_semantic_loss and clip_model is not None:
      embs = []
      for img in self.images:
        # HWC -> 1xCxHxW before CLIP preprocessing.
        img = np.expand_dims(np.transpose(img, [2, 0, 1]), 0)
        embs.append(clip_model.get_image_features(
            pixel_values=clip_utils.preprocess_for_CLIP(img)))
      self.embeddings = np.concatenate(embs, 0)
      self._reset_image_idx()

  @staticmethod
  def load_files(data_dir, split, factor, few_shot):
    """Read the frames listed in `transforms_{split}.json` under `data_dir`.

    Called as `self.load_files(...)`, so it must be a staticmethod: it takes
    no instance argument.

    Args:
      data_dir: str, dataset root holding transforms_*.json and the images.
      split: str, split name used to select the transforms file.
      factor: int, 0 (or negative) keeps full resolution; 2 or 4 downsample
        each image by that factor.
      few_shot: int, if > 0 and split == "train", keep only this many randomly
        chosen frames.

    Returns:
      (cams, images, meta): list of float32 transform matrices, list of
      float32 images scaled to [0, 1], and the parsed JSON metadata.

    Raises:
      ValueError: for any positive `factor` other than 2 or 4.
    """
    # Validate before doing any I/O.
    if factor > 0 and factor not in (2, 4):
      raise ValueError("Blender dataset only supports factor=0 or 2 or 4, {} "
                       "set.".format(factor))
    with utils.open_file(
        path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
      meta = json.load(fp)
    images = []
    cams = []
    frames = np.arange(len(meta["frames"]))
    if few_shot > 0 and split == 'train':
      np.random.shuffle(frames)
      frames = frames[:few_shot]
    for i in frames:
      frame = meta["frames"][i]
      fname = os.path.join(data_dir, frame["file_path"] + ".png")
      with utils.open_file(fname, "rb") as imgin:
        image = np.array(Image.open(imgin)).astype(np.float32) / 255.
      if factor in (2, 4):
        # cv2 is not imported by this module, so downsample with numpy.
        image = Blender._downsample(image, factor)
      cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
      images.append(image)
    return cams, images, meta

  @staticmethod
  def _downsample(image, factor):
    """Box-filter downsample an [H, W, ...] image by an integer `factor`.

    Returns an array of shape [H // factor, W // factor, ...]. Each output
    pixel is the mean of one `factor` x `factor` block, which is what OpenCV's
    INTER_AREA resize computes when H and W are multiples of `factor`.
    Trailing rows/columns that do not fill a whole block are dropped.
    """
    out_h, out_w = image.shape[0] // factor, image.shape[1] // factor
    cropped = image[:out_h * factor, :out_w * factor]
    blocks = cropped.reshape((out_h, factor, out_w, factor) + cropped.shape[2:])
    return blocks.mean(axis=(1, 3))

  def _next_train(self):
    """Sample a training batch; only "single_image" batching is supported.

    The "image_index" entry produced by the base class is removed from the
    returned dict.

    Raises:
      NotImplementedError: for any batching other than "single_image".
    """
    batch_dict = super(Blender, self)._next_train()
    if self.batching != "single_image":
      raise NotImplementedError
    batch_dict.pop("image_index")
    return batch_dict

  def _reset_image_idx(self):
    """Refill `image_idx` with a freshly shuffled list of all image indices."""
    image_idx = np.arange(self.images.shape[0])
    np.random.shuffle(image_idx)
    self.image_idx = image_idx.tolist()

  def get_clip_data(self):
    """Sample one CLIP target embedding plus rays from a random camera pose.

    Image indices are drawn without replacement from a shuffled list, which is
    refilled once exhausted. Requires `self.embeddings`, which is only set when
    the dataset was loaded with a CLIP model and `use_semantic_loss`.

    Returns:
      dict with "embedding" (the chosen image's CLIP embedding) and
      "random_rays" (utils.Rays with [N, 3] arrays for a random pose, sampled
      at every 16th pixel).
    """
    if not self.image_idx:
      self._reset_image_idx()
    image_index = self.image_idx.pop()
    batch_dict = {}
    batch_dict["embedding"] = self.embeddings[image_index]
    # source rays for CLIP (for constructing source image later)
    src_seed = int(np.random.randint(0, self.max_steps, ()))
    src_rng = jax.random.PRNGKey(src_seed)
    src_camtoworld = np.array(
        clip_utils.random_pose(src_rng, (self.near, self.far)))
    random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample=16)
    batch_dict["random_rays"] = utils.Rays(
        origins=np.reshape(random_rays[0], [-1, 3]),
        directions=np.reshape(random_rays[1], [-1, 3]),
        viewdirs=np.reshape(random_rays[2], [-1, 3]))
    return batch_dict
class LLFF(Dataset):
  """LLFF Dataset."""

  def _load_renderings(self, flags, clip_model=None):
    """Load images from disk.

    Args:
      flags: config; reads factor, data_dir, spherify, llffhold and
        render_path.
      clip_model: accepted for interface compatibility. The base class passes
        it positionally for the train split and as a keyword for the test
        split. No CLIP features are computed for LLFF.

    Raises:
      ValueError: if the image folder does not exist.
      RuntimeError: if the number of images and poses differ.
    """
    del clip_model  # Unused for LLFF.
    # Load images.
    imgdir_suffix = ""
    if flags.factor > 0:
      imgdir_suffix = "_{}".format(flags.factor)
      factor = flags.factor
    else:
      factor = 1
    imgdir = path.join(flags.data_dir, "images" + imgdir_suffix)
    if not utils.file_exists(imgdir):
      raise ValueError("Image folder {} doesn't exist.".format(imgdir))
    imgfiles = [
        path.join(imgdir, f)
        for f in sorted(utils.listdir(imgdir))
        if f.endswith(("JPG", "jpg", "png"))
    ]
    images = []
    for imgfile in imgfiles:
      with utils.open_file(imgfile, "rb") as imgin:
        image = np.array(Image.open(imgin), dtype=np.float32) / 255.
        images.append(image)
    images = np.stack(images, axis=-1)
    # Load poses and bds.
    with utils.open_file(path.join(flags.data_dir, "poses_bounds.npy"),
                         "rb") as fp:
      poses_arr = np.load(fp)
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])
    if poses.shape[-1] != images.shape[-1]:
      raise RuntimeError("Mismatch between imgs {} and poses {}".format(
          images.shape[-1], poses.shape[-1]))
    # Update poses according to downsampling.
    poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
    poses[2, 4, :] = poses[2, 4, :] * 1. / factor
    # Correct rotation matrix ordering and move variable dim to axis 0.
    poses = np.concatenate(
        [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    images = np.moveaxis(images, -1, 0)
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)
    # Rescale according to a default bd factor.
    scale = 1. / (bds.min() * .75)
    poses[:, :3, 3] *= scale
    bds *= scale
    # Recenter poses.
    poses = self._recenter_poses(poses)
    # Generate a spiral/spherical ray path for rendering videos.
    if flags.spherify:
      poses = self._generate_spherical_poses(poses, bds)
      self.spherify = True
    else:
      self.spherify = False
    if not flags.spherify and self.split == "test":
      self._generate_spiral_poses(poses, bds)
    # Select the split: every `llffhold`-th image is held out for testing.
    n_images = images.shape[0]
    i_test = np.arange(n_images)[::flags.llffhold]
    # setdiff1d keeps an integer dtype even when no training images remain,
    # so the result is always a valid index array.
    i_train = np.setdiff1d(np.arange(n_images), i_test)
    if self.split == "train":
      indices = i_train
    else:
      indices = i_test
    images = images[indices]
    poses = poses[indices]
    self.images = images
    self.camtoworlds = poses[:, :3, :4]
    self.focal = poses[0, -1, -1]
    self.h, self.w = images.shape[1:3]
    self.resolution = self.h * self.w
    if flags.render_path:
      self.n_examples = self.render_poses.shape[0]
    else:
      self.n_examples = images.shape[0]

  def _generate_rays(self):
    """Generate normalized device coordinate rays for llff.

    For the test split, the render-path poses are prepended to the dataset
    poses so rays for both are generated in one pass. They are then split
    into `render_rays` and `rays`. NDC conversion is skipped when spherified.
    """
    if self.split == "test":
      n_render_poses = self.render_poses.shape[0]
      self.camtoworlds = np.concatenate([self.render_poses, self.camtoworlds],
                                        axis=0)
    super()._generate_rays()
    if not self.spherify:
      ndc_origins, ndc_directions = convert_to_ndc(self.rays.origins,
                                                   self.rays.directions,
                                                   self.focal, self.w, self.h)
      self.rays = utils.Rays(
          origins=ndc_origins,
          directions=ndc_directions,
          viewdirs=self.rays.viewdirs)
    # Split poses from the dataset and generated poses
    if self.split == "test":
      self.camtoworlds = self.camtoworlds[n_render_poses:]
      split = [np.split(r, [n_render_poses], 0) for r in self.rays]
      split0, split1 = zip(*split)
      self.render_rays = utils.Rays(*split0)
      self.rays = utils.Rays(*split1)

  def _recenter_poses(self, poses):
    """Recenter poses according to the original NeRF code.

    Every pose is left-multiplied by the inverse of the average pose, so the
    average camera sits at the origin with identity rotation. Columns beyond
    the 3x4 pose (the hwf column) are kept from the input.
    """
    poses_ = poses.copy()
    bottom = np.reshape([0, 0, 0, 1.], [1, 4])
    c2w = self._poses_avg(poses)
    c2w = np.concatenate([c2w[:3, :4], bottom], -2)
    bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
    poses = np.concatenate([poses[:, :3, :4], bottom], -2)
    poses = np.linalg.inv(c2w) @ poses
    poses_[:, :3, :4] = poses[:, :3, :4]
    poses = poses_
    return poses

  def _poses_avg(self, poses):
    """Average poses according to the original NeRF code.

    Returns a 3x5 matrix: a look-at frame at the mean camera position, built
    from the summed z and y axes, plus the first pose's hwf column.
    """
    hwf = poses[0, :3, -1:]
    center = poses[:, :3, 3].mean(0)
    vec2 = self._normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1)
    return c2w

  def _viewmatrix(self, z, up, pos):
    """Construct lookat view matrix (3x4: x, y, z axes and position columns)."""
    vec2 = self._normalize(z)
    vec1_avg = up
    vec0 = self._normalize(np.cross(vec1_avg, vec2))
    vec1 = self._normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m

  def _normalize(self, x):
    """Normalization helper function: scale `x` to unit L2 norm."""
    return x / np.linalg.norm(x)

  def _generate_spiral_poses(self, poses, bds):
    """Generate a spiral path for rendering; stores it in `self.render_poses`."""
    c2w = self._poses_avg(poses)
    # Get average pose.
    up = self._normalize(poses[:, :3, 1].sum(0))
    # Find a reasonable "focus depth" for this dataset.
    close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
    dt = .75
    mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
    focal = mean_dz
    # Get radii for spiral path.
    tt = poses[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0)
    c2w_path = c2w
    n_views = 120
    n_rots = 2
    # Generate poses for spiral path.
    render_poses = []
    rads = np.array(list(rads) + [1.])
    hwf = c2w_path[:, 4:5]
    zrate = .5
    for theta in np.linspace(0., 2. * np.pi * n_rots, n_views + 1)[:-1]:
      c = np.dot(c2w[:3, :4], (np.array(
          [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads))
      # Each camera looks at the point `focal` units in front of the mean pose.
      z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
      render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1))
    self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4]

  def _generate_spherical_poses(self, poses, bds):
    """Generate a 360 degree spherical path for rendering.

    Re-expresses `poses` in a frame centred on the least-squares point closest
    to all camera optical axes, and rescales them to unit RMS distance from
    that centre. For the test split, a 120-view circular path is stored in
    `self.render_poses`. Note: `bds` is rescaled in place.

    Returns:
      The re-centred, rescaled poses, [N, 3, 5] (3x4 pose plus hwf column).
    """
    # pylint: disable=g-long-lambda
    p34_to_44 = lambda p: np.concatenate([
        p,
        np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])
    ], 1)
    rays_d = poses[:, :3, 2:3]
    rays_o = poses[:, :3, 3:4]

    def min_line_dist(rays_o, rays_d):
      # Least-squares point minimizing distance to all lines o + t * d.
      a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
      b_i = -a_i @ rays_o
      pt_mindist = np.squeeze(-np.linalg.inv(
          (np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0)) @ (b_i).mean(0))
      return pt_mindist

    pt_mindist = min_line_dist(rays_o, rays_d)
    center = pt_mindist
    up = (poses[:, :3, 3] - center).mean(0)
    vec0 = self._normalize(up)
    vec1 = self._normalize(np.cross([.1, .2, .3], vec0))
    vec2 = self._normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)
    poses_reset = (
        np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]))
    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
    sc = 1. / rad
    poses_reset[:, :3, 3] *= sc
    bds *= sc
    rad *= sc
    centroid = np.mean(poses_reset[:, :3, 3], 0)
    zh = centroid[2]
    radcircle = np.sqrt(rad ** 2 - zh ** 2)
    new_poses = []
    for th in np.linspace(0., 2. * np.pi, 120):
      camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
      up = np.array([0, 0, -1.])
      vec2 = self._normalize(camorigin)
      vec0 = self._normalize(np.cross(vec2, up))
      vec1 = self._normalize(np.cross(vec2, vec0))
      pos = camorigin
      p = np.stack([vec0, vec1, vec2, pos], 1)
      new_poses.append(p)
    new_poses = np.stack(new_poses, 0)
    new_poses = np.concatenate([
        new_poses,
        np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)
    ], -1)
    poses_reset = np.concatenate([
        poses_reset[:, :3, :4],
        np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)
    ], -1)
    if self.split == "test":
      self.render_poses = new_poses[:, :3, :4]
    return poses_reset
# Registry used by get_dataset: maps the `args.dataset` value to its class.
dataset_dict = dict(
    blender=Blender,
    llff=LLFF,
)