import os
import math
import numpy as np
from PIL import Image

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF

import pytorch_lightning as pl

import datasets
from datasets.colmap_utils import \
    read_cameras_binary, read_images_binary, read_points3d_binary
from models.ray_utils import get_ray_directions
from utils.misc import get_rank


def get_center(pts):
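    # robust centroid: start from the plain mean, drop points whose distance
    # to it fails either a standard-deviation test or an IQR test, then
    # re-average the surviving inliers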
    center = pts.mean(0)
    dis = (pts - center[None,:]).norm(p=2, dim=-1)
    mean, std = dis.mean(), dis.std()
    q25, q75 = torch.quantile(dis, 0.25), torch.quantile(dis, 0.75)
    valid = (dis > mean - 1.5 * std) & (dis < mean + 1.5 * std) & (dis > mean - (q75 - q25) * 1.5) & (dis < mean + (q75 - q25) * 1.5)
    center = pts[valid].mean(0)
    return center
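
# e.g. get_center(torch.randn(1000, 3)) returns a (3,) tensor near the origin;
# the filtering above is meant to keep the estimate stable even when a few
# far-away outlier points are present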


def normalize_poses(poses, pts, up_est_method, center_est_method):
    if center_est_method == 'camera':
        # estimate the scene center as the average of all camera positions
        center = poses[...,3].mean(0)
    elif center_est_method == 'lookat':
        # estimate the scene center as the average of the intersections of selected pairs of camera rays
        cams_ori = poses[...,3]
        cams_dir = poses[:,:3,:3] @ torch.as_tensor([0.,0.,-1.])
        cams_dir = F.normalize(cams_dir, dim=-1)
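        # pair each camera i with camera i-1 (via roll) and solve a least-squares
        # system for the ray parameters t where ori_i + t0 * dir_i comes closest
        # to ori_{i-1} + t1 * dir_{i-1}; averaging the near-intersection points
        # of all pairs gives the shared look-at center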
        A = torch.stack([cams_dir, -cams_dir.roll(1,0)], dim=-1)
        b = -cams_ori + cams_ori.roll(1,0)
        t = torch.linalg.lstsq(A, b).solution
        center = (torch.stack([cams_dir, cams_dir.roll(1,0)], dim=-1) * t[:,None,:] + torch.stack([cams_ori, cams_ori.roll(1,0)], dim=-1)).mean((0,2))
    elif center_est_method == 'point':
        # first estimate the scene center as the average of all camera positions
        # later we'll use the center of all points bounded by the cameras as the final scene center
        center = poses[...,3].mean(0)
    else:
        raise NotImplementedError(f'Unknown center estimation method: {center_est_method}')

    if up_est_method == 'ground':
        # estimate the up direction as the normal of the estimated ground plane
        # use RANSAC to fit the ground plane to the point cloud
        import pyransac3d as pyrsc
        ground = pyrsc.Plane()
        plane_eq, inliers = ground.fit(pts.numpy(), thresh=0.01) # TODO: determine thresh based on scene scale
        plane_eq = torch.as_tensor(plane_eq) # A, B, C, D in Ax + By + Cz + D = 0
        z = F.normalize(plane_eq[:3], dim=-1) # plane normal as up direction
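        # signed distance of every point to the plane: (x, y, z, 1) . (A, B, C, D)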
        signed_distance = (torch.cat([pts, torch.ones_like(pts[...,0:1])], dim=-1) * plane_eq).sum(-1)
        if signed_distance.mean() < 0:
            z = -z # flip the normal if most points lie on its negative side
    elif up_est_method == 'camera':
        # estimate the up direction as the normalized average of the vectors
        # pointing from the scene center to the camera positions
        z = F.normalize((poses[...,3] - center).mean(0), dim=0)
    else:
        raise NotImplementedError(f'Unknown up estimation method: {up_est_method}')

    # build a right-handed orthonormal frame with z as the up axis:
    # y_ is a horizontal vector orthogonal to z, x = normalize(y_ × z), y = z × x
    y_ = torch.as_tensor([z[1], -z[0], 0.])
    x = F.normalize(y_.cross(z), dim=0)
    y = z.cross(x)

    if center_est_method == 'point':
        # rotate the scene so that z becomes the up axis; translation and
        # scaling are deferred until the point-based center is known
        Rc = torch.stack([x, y, z], dim=1)
        R = Rc.T
        poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1)
        inv_trans = torch.cat([torch.cat([R, torch.as_tensor([[0.,0.,0.]]).T], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
        poses_norm = (inv_trans @ poses_homo)[:,:3]
        pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
        # translation and scaling: keep only points inside the cameras' xy bounds,
        # take their robust center, and shift the scene so it sits at the origin
        poses_min, poses_max = poses_norm[...,3].min(0)[0], poses_norm[...,3].max(0)[0]
        pts_fg = pts[(poses_min[0] < pts[:,0]) & (pts[:,0] < poses_max[0]) & (poses_min[1] < pts[:,1]) & (pts[:,1] < poses_max[1])]
        center = get_center(pts_fg)
        tc = center.reshape(3, 1)
        t = -tc
        poses_homo = torch.cat([poses_norm, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses_norm.shape[0], -1, -1)], dim=1)
        inv_trans = torch.cat([torch.cat([torch.eye(3), t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
        poses_norm = (inv_trans @ poses_homo)[:,:3]
        scale = poses_norm[...,3].norm(p=2, dim=-1).min()
        poses_norm[...,3] /= scale
        pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
        pts = pts / scale
    else:
        # rotation and translation
        Rc = torch.stack([x, y, z], dim=1)
        tc = center.reshape(3, 1)
        R, t = Rc.T, -Rc.T @ tc
        poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1)
        inv_trans = torch.cat([torch.cat([R, t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
        poses_norm = (inv_trans @ poses_homo)[:,:3] # (N_images, 3, 4)
        # scaling
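        # divide by the distance of the camera closest to the scene center,
        # so every camera ends up at distance >= 1 from the origin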
        scale = poses_norm[...,3].norm(p=2, dim=-1).min()
        poses_norm[...,3] /= scale
        # apply the transformation to the point cloud
        pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
        pts = pts / scale

    return poses_norm, pts


def create_spheric_poses(cameras, n_steps=120):
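    # generate a circular test trajectory around the normalized scene: cameras
    # sit on a horizontal circle of radius r at the mean camera height, each
    # looking at the origin with +z as the up vector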
    center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device)
    mean_d = (cameras - center[None,:]).norm(p=2, dim=-1).mean()
    mean_h = cameras[:,2].mean()
    r = (mean_d**2 - mean_h**2).sqrt()
    up = torch.as_tensor([0., 0., 1.], dtype=center.dtype, device=center.device)

    all_c2w = []
    for theta in torch.linspace(0, 2 * math.pi, n_steps):
        cam_pos = torch.stack([r * theta.cos(), r * theta.sin(), mean_h])
        l = F.normalize(center - cam_pos, p=2, dim=0)
        s = F.normalize(l.cross(up), p=2, dim=0)
        u = F.normalize(s.cross(l), p=2, dim=0)
        c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], dim=1)
        all_c2w.append(c2w)
    all_c2w = torch.stack(all_c2w, dim=0)

    return all_c2w
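
# e.g. create_spheric_poses(all_c2w[:,:,3]) returns a (120, 3, 4) stack of
# camera-to-world matrices, one per step of the turntable path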


class ColmapDatasetBase():
    # the data only has to be processed once
    initialized = False
    properties = {}

    def setup(self, config, split):
        self.config = config
        self.split = split
        self.rank = get_rank()

        if not ColmapDatasetBase.initialized:
            camdata = read_cameras_binary(os.path.join(self.config.root_dir, 'sparse/0/cameras.bin'))

            H = int(camdata[1].height)
            W = int(camdata[1].width)

            if 'img_wh' in self.config:
                w, h = self.config.img_wh
                assert round(W / w * h) == H
            elif 'img_downscale' in self.config:
                w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
            else:
                raise KeyError("Either img_wh or img_downscale should be specified.")

            img_wh = (w, h)
            factor = w / W
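            # intrinsics scale with image resolution, so focal length and
            # principal point are multiplied by the same downscale factor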
            if camdata[1].model == 'SIMPLE_RADIAL':
                fx = fy = camdata[1].params[0] * factor
                cx = camdata[1].params[1] * factor
                cy = camdata[1].params[2] * factor
            elif camdata[1].model in ['PINHOLE', 'OPENCV']:
                fx = camdata[1].params[0] * factor
                fy = camdata[1].params[1] * factor
                cx = camdata[1].params[2] * factor
                cy = camdata[1].params[3] * factor
            else:
                raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!")

            directions = get_ray_directions(w, h, fx, fy, cx, cy).to(self.rank)

            imdata = read_images_binary(os.path.join(self.config.root_dir, 'sparse/0/images.bin'))

            mask_dir = os.path.join(self.config.root_dir, 'masks')
            has_mask = os.path.exists(mask_dir) # TODO: support partial masks
            apply_mask = has_mask and self.config.apply_mask

            all_c2w, all_images, all_fg_masks = [], [], []

            for i, d in enumerate(imdata.values()):
                R = d.qvec2rotmat()
                t = d.tvec.reshape(3, 1)
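                # COLMAP stores world-to-camera extrinsics (R, t); invert to
                # camera-to-world: c2w = [R^T | -R^T @ t]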
                c2w = torch.from_numpy(np.concatenate([R.T, -R.T@t], axis=1)).float()
                c2w[:,1:3] *= -1. # COLMAP => OpenGL
                all_c2w.append(c2w)
                if self.split in ['train', 'val']:
                    img_path = os.path.join(self.config.root_dir, 'images', d.name)
                    img = Image.open(img_path)
                    img = img.resize(img_wh, Image.BICUBIC)
                    img = TF.to_tensor(img).permute(1, 2, 0)[...,:3]
                    img = img.to(self.rank) if self.config.load_data_on_gpu else img.cpu()
                    if has_mask:
                        mask_paths = [os.path.join(mask_dir, d.name), os.path.join(mask_dir, d.name[3:])]
                        mask_paths = list(filter(os.path.exists, mask_paths))
                        assert len(mask_paths) == 1
                        mask = Image.open(mask_paths[0]).convert('L') # grayscale mask
                        mask = mask.resize(img_wh, Image.BICUBIC)
                        mask = TF.to_tensor(mask)[0] # (h, w)
                    else:
                        mask = torch.ones_like(img[...,0], device=img.device)
                    all_fg_masks.append(mask) # (h, w)
                    all_images.append(img)

            all_c2w = torch.stack(all_c2w, dim=0)

            pts3d = read_points3d_binary(os.path.join(self.config.root_dir, 'sparse/0/points3D.bin'))
            pts3d = torch.from_numpy(np.array([pts3d[k].xyz for k in pts3d])).float()

            all_c2w, pts3d = normalize_poses(all_c2w, pts3d, up_est_method=self.config.up_est_method, center_est_method=self.config.center_est_method)

            ColmapDatasetBase.properties = {
                'w': w,
                'h': h,
                'img_wh': img_wh,
                'factor': factor,
                'has_mask': has_mask,
                'apply_mask': apply_mask,
                'directions': directions,
                'pts3d': pts3d,
                'all_c2w': all_c2w,
                'all_images': all_images,
                'all_fg_masks': all_fg_masks
            }
            ColmapDatasetBase.initialized = True

        for k, v in ColmapDatasetBase.properties.items():
            setattr(self, k, v)

        if self.split == 'test':
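            # the test split renders a synthetic spheric trajectory, so images
            # and masks are zero placeholders (no ground truth along the path)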
            self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps)
            self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
            self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
        else:
            self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0).float(), torch.stack(self.all_fg_masks, dim=0).float()

        """
        # for debug use
        from models.ray_utils import get_rays
        rays_o, rays_d = get_rays(self.directions.cpu(), self.all_c2w, keepdim=True)
        pts_out = []
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 0.0 0.0' for l in rays_o[:,0,0].reshape(-1, 3).tolist()]))

        t_vals = torch.linspace(0, 1, 8)
        z_vals = 0.05 * (1 - t_vals) + 0.5 * t_vals

        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,0][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 0.0' for l in ray_pts.view(-1, 3).tolist()]))

        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,0][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))

        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,self.w-1][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))

        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,self.w-1][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))

        open('cameras.txt', 'w').write('\n'.join(pts_out))
        open('scene.txt', 'w').write('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 0.0' for l in self.pts3d.view(-1, 3).tolist()]))
        exit(1)
        """

        self.all_c2w = self.all_c2w.float().to(self.rank)
        if self.config.load_data_on_gpu:
            self.all_images = self.all_images.to(self.rank)
            self.all_fg_masks = self.all_fg_masks.to(self.rank)


class ColmapDataset(Dataset, ColmapDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        return {
            'index': index
        }


class ColmapIterableDataset(IterableDataset, ColmapDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        while True:
            yield {}


class ColmapDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, 'fit']:
            self.train_dataset = ColmapIterableDataset(self.config, 'train')
        if stage in [None, 'fit', 'validate']:
            self.val_dataset = ColmapDataset(self.config, self.config.get('val_split', 'train'))
        if stage in [None, 'test']:
            self.test_dataset = ColmapDataset(self.config, self.config.get('test_split', 'test'))
        if stage in [None, 'predict']:
            self.predict_dataset = ColmapDataset(self.config, 'train')

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
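

# Minimal usage sketch (illustrative, not from the original file). It assumes
# an OmegaConf-style config, since the code above relies on attribute access,
# `in`, and `.get`; the paths and values below are placeholders.
#
#   from omegaconf import OmegaConf
#   config = OmegaConf.create({
#       'root_dir': '/path/to/scene',   # must contain sparse/0/ and images/
#       'img_downscale': 4,             # or specify 'img_wh': [800, 600]
#       'up_est_method': 'ground',      # 'ground' or 'camera'
#       'center_est_method': 'lookat',  # 'camera', 'lookat' or 'point'
#       'apply_mask': False,
#       'load_data_on_gpu': False,
#       'n_test_traj_steps': 120,
#   })
#   dm = ColmapDataModule(config)
#   dm.setup('fit')
#   train_loader = dm.train_dataloader()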