import os
import json
import math
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid, save_image
from einops import rearrange
from mediapy import read_video
from pathlib import Path
from rembg import remove, new_session
import pytorch_lightning as pl

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank
from datasets.ortho import (
    inv_RT,
    camNormal2worldNormal,
    RT_opengl2opencv,
    normal_opengl2opencv,
)
from utils.dpt import DPT


def get_c2w_from_up_and_look_at(
    up,
    look_at,
    pos,
    opengl=False,
):
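    """Build a 4x4 camera-to-world matrix for a camera at ``pos`` looking at ``look_at``.

    The rotation columns are (x, y, z) with z pointing from the camera toward the
    target and y roughly opposite ``up`` (OpenCV convention); pass ``opengl=True``
    to flip the y/z columns into the OpenGL convention (see the "opencv to opengl"
    branch below).
    """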
    up = up / np.linalg.norm(up)
    z = look_at - pos
    z = z / np.linalg.norm(z)
    y = -up
    x = np.cross(y, z)
    x /= np.linalg.norm(x)
    y = np.cross(z, x)

    c2w = np.zeros([4, 4], dtype=np.float32)
    c2w[:3, 0] = x
    c2w[:3, 1] = y
    c2w[:3, 2] = z
    c2w[:3, 3] = pos
    c2w[3, 3] = 1.0

    # opencv to opengl
    if opengl:
        c2w[..., 1:3] *= -1

    return c2w


def get_uniform_poses(num_frames, radius, elevation, opengl=False):
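    """Generate ``num_frames`` camera-to-world poses evenly spaced in azimuth on a
    circle of the given radius and elevation (in degrees), all looking at the origin."""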
    T = num_frames
    azimuths = np.deg2rad(np.linspace(0, 360, T + 1)[:T])
    elevations = np.full_like(azimuths, np.deg2rad(elevation))
    cam_dists = np.full_like(azimuths, radius)

    campos = np.stack(
        [
            cam_dists * np.cos(elevations) * np.cos(azimuths),
            cam_dists * np.cos(elevations) * np.sin(azimuths),
            cam_dists * np.sin(elevations),
        ],
        axis=-1,
    )

    center = np.array([0, 0, 0], dtype=np.float32)
    up = np.array([0, 0, 1], dtype=np.float32)
    poses = []
    for t in range(T):
        poses.append(get_c2w_from_up_and_look_at(up, center, campos[t], opengl=opengl))

    return np.stack(poses, axis=0)


def blender2midas(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


def midas2blender(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


class BlenderDatasetBase:
    def setup(self, config, split):
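        """Load the scene video, strip backgrounds, build an orbiting camera path, and
        predict per-frame normal maps.

        Frames are read from ``root_dir/scene``, the foreground is extracted with
        rembg, camera-to-world poses are placed uniformly around the object, and a
        DPT model predicts camera-space normals that are then lifted to world space.
        """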
        self.config = config
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        dpt = DPT(device=self.rank, mode="normal")

        # with open(
        #     os.path.join(
        #         self.config.root_dir, self.config.scene, f"transforms_train.json"
        #     ),
        #     "r",
        # ) as f:
        #     meta = json.load(f)
        # if "w" in meta and "h" in meta:
        #     W, H = int(meta["w"]), int(meta["h"])
        # else:
        #     W, H = 800, 800

        frames = read_video(Path(self.config.root_dir) / f"{self.config.scene}")
        rembg_session = new_session()
        num_frames, H, W = frames.shape[:3]

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        # self.near, self.far = self.config.near_plane, self.config.far_plane

        self.focal = 0.5 * w / math.tan(0.5 * np.deg2rad(60))  # scaled focal length (assumes a 60-degree FOV)

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(
            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
        ).to(
            self.rank
        )  # (h, w, 3)

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        radius = 2.0
        elevation = 0.0
        poses = get_uniform_poses(num_frames, radius, elevation, opengl=True)
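        # For each frame: drop the background with rembg, resize, and split the RGBA
        # result into an RGB image and a foreground mask.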
        for i, (c2w, frame) in enumerate(zip(poses, frames)):
            c2w = torch.from_numpy(np.array(c2w)[:3, :4])
            self.all_c2w.append(c2w)

            img = Image.fromarray(frame)
            img = remove(img, session=rembg_session)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)

            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )
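        # DPT predicts camera-space normals in [0, 1]; map them to [-1, 1], convert
        # from the MiDaS frame to the Blender frame, flip x, and zero out the
        # background using the foreground masks.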
        self.normals = dpt(self.all_images)

        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1

        self.normals = self.normals * 2.0 - 1.0
        self.normals = midas2blender(self.normals).cpu().numpy()
        # self.normals = self.normals.cpu().numpy()
        self.normals[..., 0] *= -1
        self.normals[~self.all_masks] = [0, 0, 0]

        normals = rearrange(self.normals, "b h w c -> b c h w")
        normals = normals * 0.5 + 0.5
        normals = torch.from_numpy(normals)
        # save_image(make_grid(normals, nrow=4), "tmp/normals.png")
        # exit(0)

        (
            self.all_poses,
            self.all_normals,
            self.all_normals_world,
            self.all_w2cs,
            self.all_color_masks,
        ) = ([], [], [], [], [])
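        # Convert each OpenGL camera-to-world pose to the OpenCV convention and lift
        # the camera-space normals into world space with the corresponding rotation.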
        for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
            RT_opengl = inv_RT(c2w_opengl)
            RT_opencv = RT_opengl2opencv(RT_opengl)
            c2w_opencv = inv_RT(RT_opencv)
            self.all_poses.append(c2w_opencv)
            self.all_w2cs.append(RT_opencv)
            normal = normal_opengl2opencv(normal)
            normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
            self.all_normals.append(normal)
            self.all_normals_world.append(normal_world)

        self.directions = torch.stack([self.directions] * len(self.all_images))
        self.origins = self.directions

        self.all_poses = np.stack(self.all_poses)
        self.all_normals = np.stack(self.all_normals)
        self.all_normals_world = np.stack(self.all_normals_world)
        self.all_w2cs = np.stack(self.all_w2cs)

        self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
        self.all_images = self.all_images.to(self.rank)
        self.all_fg_masks = self.all_fg_masks.to(self.rank)
        self.all_rgb_masks = self.all_fg_masks.to(self.rank)
        self.all_normals_world = (
            torch.from_numpy(self.all_normals_world).float().to(self.rank)
        )


class BlenderDataset(Dataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        # Only the frame index is returned here; downstream code is expected to read
        # images, masks, and rays from the tensors preloaded in setup().
        return {"index": index}


class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        while True:
            yield {}


class BlenderDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
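

# A minimal sketch of the config keys this module reads (key names are taken from the
# code above; the example values are illustrative assumptions, not project defaults):
#
#   root_dir: ./data            # directory containing the scene video
#   scene: example.mp4          # video file name, loaded with mediapy.read_video
#   img_wh: [256, 256]          # or: img_downscale: 2
#   train_split: train
#   val_split: val
#   test_split: test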