# Dataset definitions (Img2Plane / Motion2Video) for TH1KH-style binarized talking-head data.
| import glob | |
| import json | |
| import os | |
| import cv2 | |
| import pickle | |
| import random | |
| import re | |
| import subprocess | |
| from functools import partial | |
| import librosa.core | |
| import numpy as np | |
| import torch | |
| import torch.distributions | |
| import torch.distributed as dist | |
| import torch.optim | |
| import torch.utils.data | |
| from utils.commons.indexed_datasets import IndexedDataset | |
| from torch.utils.data import Dataset, DataLoader | |
| import torch.nn.functional as F | |
| import pandas as pd | |
| from tqdm import tqdm | |
| import csv | |
| from utils.commons.hparams import hparams, set_hparams | |
| from utils.commons.meters import Timer | |
| from data_util.face3d_helper import Face3DHelper | |
| from utils.audio import librosa_wav2mfcc | |
| from utils.commons.dataset_utils import collate_xd | |
| from utils.commons.tensor_utils import convert_to_tensor | |
| from data_gen.utils.process_video.extract_segment_imgs import decode_segmap_mask_from_image | |
| from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic | |
| from utils.commons.image_utils import load_image_as_uint8_tensor | |
| from modules.eg3ds.camera_utils.pose_sampler import UnifiedCameraPoseSampler | |
def sample_idx(img_dir, num_frames):
    """Randomly pick a frame index whose gt/head/inpaint_torso/com images all exist.

    Re-samples until a frame is found for which every one of the four parallel
    image directories (gt_imgs, head_imgs, inpaint_torso_imgs, com_imgs)
    contains an image file for that index.

    Args:
        img_dir: path to the ``.../gt_imgs/`` directory of one video.
        num_frames: number of frames in the video; indices are drawn uniformly
            from ``[0, num_frames - 1]``.

    Returns:
        A valid frame index (int).
    """
    cnt = 0
    while True:
        cnt += 1
        # Fix: the original printed the warning on EVERY iteration once
        # cnt exceeded 1000, flooding stdout; warn once per 1000 attempts.
        if cnt > 1000 and cnt % 1000 == 1:
            print(f"recycle for more than 1000 times, check this {img_dir}")
        idx = random.randint(0, num_frames - 1)
        if find_img_name(img_dir, idx) == 'None':
            continue
        if find_img_name(img_dir.replace("/gt_imgs/", "/head_imgs/"), idx) == 'None':
            continue
        if find_img_name(img_dir.replace("/gt_imgs/", "/inpaint_torso_imgs/"), idx) == 'None':
            continue
        if find_img_name(img_dir.replace("/gt_imgs/", "/com_imgs/"), idx) == 'None':
            continue
        return idx
def find_img_name(img_dir, idx):
    """Locate the image file for frame ``idx`` inside ``img_dir``.

    Several naming conventions are tried in a fixed priority order
    (zero-padded to 5 or 8 digits, or unpadded; .jpg before .png).

    Returns:
        The first existing path, or the literal string ``'None'`` if no
        candidate exists on disk.
    """
    candidates = (
        format(idx, "05d") + ".jpg",
        str(idx) + ".jpg",
        format(idx, "08d") + ".jpg",
        format(idx, "08d") + ".png",
        format(idx, "05d") + ".png",
        str(idx) + ".png",
    )
    for fname in candidates:
        path = os.path.join(img_dir, fname)
        if os.path.exists(path):
            return path
    return 'None'
def get_win_from_arr(arr, index, win_size):
    """Extract a window of ``win_size`` rows around ``index`` from ``arr``.

    The window spans ``[index - win_size//2, index + win_size - win_size//2)``;
    positions that fall outside the array are filled with zeros so the result
    always has ``win_size`` rows. Works for both numpy arrays and torch
    tensors (dispatches on the input type).
    """
    start = index - win_size // 2
    stop = index + (win_size - win_size // 2)
    pad_head = max(0, -start)
    pad_tail = max(0, stop - arr.shape[0])
    win = arr[max(start, 0):min(stop, arr.shape[0])]
    is_numpy = isinstance(arr, np.ndarray)
    if pad_head > 0:
        if is_numpy:
            win = np.concatenate([np.zeros_like(win[:pad_head]), win], axis=0)
        else:
            win = torch.cat([torch.zeros_like(win[:pad_head]), win], dim=0)
    if pad_tail > 0:
        if is_numpy:
            win = np.concatenate([win, np.zeros_like(win[:pad_tail])], axis=0)  # [8, 16]
        else:
            win = torch.cat([win, torch.zeros_like(win[:pad_tail])], dim=0)  # [8, 16]
    return win
class Img2Plane_Dataset(Dataset):
    """Dataset yielding camera poses for EG3D-style image-to-plane training.

    Each item carries three 25-dim camera vectors (16 flattened c2w + 9
    flattened intrinsics):
      - ``ws_camera``: always sampled from a narrow random pose range;
      - ``real_camera`` / ``fake_camera``: with ``hparams['random_sample_pose']``
        enabled there is a 50% chance each is sampled from a wider random
        range, otherwise the pose of a random valid video frame is used.
    """

    def __init__(self, prefix='train', data_dir=None):
        self.db_key = prefix
        self.ds = None  # lazily-opened IndexedDataset; must be opened per worker
        self.sizes = None
        self.x_maxframes = 200  # 50 video frames
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        self.pose_sampler = UnifiedCameraPoseSampler()
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def __len__(self):
        # Fix: the original re-opened the IndexedDataset on every len() call;
        # open lazily once and reuse, mirroring _get_item.
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(self.ds)

    def _get_item(self, index):
        """Lazily open the binary dataset; necessary to open files in multi-threads!"""
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    def __getitem__(self, idx):
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }
        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])
        hparams = self.hparams
        # Convert per-frame euler/trans into EG3D-convention c2w + intrinsics.
        camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': convert_to_tensor(raw_item['euler']).cpu(), 'trans': convert_to_tensor(raw_item['trans']).cpu()})
        c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
        raw_item['c2w'] = c2w
        raw_item['intrinsics'] = intrinsics
        # ws camera: narrow pose range (smaller than that of ref/mv).
        max_pitch = 10 / 180 * 3.1415926  # range for mv pitch angle is smaller than that of ref
        min_pitch = -max_pitch
        pitch = random.random() * (max_pitch - min_pitch) + min_pitch
        max_yaw = 16 / 180 * 3.1415926
        min_yaw = - max_yaw
        yaw = random.random() * (max_yaw - min_yaw) + min_yaw
        distance = random.random() * (3.2 - 2.7) + 2.7  # [2.7, 3.2]
        ws_camera = self.pose_sampler.get_camera_pose(pitch, yaw, lookat_location=torch.tensor([0, 0, 0.2]), distance_to_orig=distance)[0]
        # real camera: 50% wide random pose (if enabled) else a real frame's pose.
        if hparams.get("random_sample_pose", False) is True and random.random() < 0.5:
            max_pitch = 26 / 180 * 3.1415926
            min_pitch = -max_pitch
            pitch = random.random() * (max_pitch - min_pitch) + min_pitch
            max_yaw = 38 / 180 * 3.1415926
            min_yaw = - max_yaw
            yaw = random.random() * (max_yaw - min_yaw) + min_yaw
            distance = random.random() * (4.0 - 2.7) + 2.7  # [2.7, 4.0]
            real_camera = self.pose_sampler.get_camera_pose(pitch, yaw, lookat_location=torch.tensor([0, 0, 0.2]), distance_to_orig=distance)[0]
        else:
            real_idx = sample_idx(img_dir, num_frames)
            real_c2w = raw_item['c2w'][real_idx]
            real_intrinsics = raw_item['intrinsics'][real_idx]
            real_camera = np.concatenate([real_c2w.reshape([16, ]), real_intrinsics.reshape([9, ])], axis=0)
            real_camera = convert_to_tensor(real_camera)
        # fake camera: same sampling scheme, drawn independently.
        if hparams.get("random_sample_pose", False) is True and random.random() < 0.5:
            max_pitch = 26 / 180 * 3.1415926
            min_pitch = -max_pitch
            pitch = random.random() * (max_pitch - min_pitch) + min_pitch
            max_yaw = 38 / 180 * 3.1415926
            min_yaw = - max_yaw
            yaw = random.random() * (max_yaw - min_yaw) + min_yaw
            distance = random.random() * (4.0 - 2.7) + 2.7  # [2.7, 4.0]
            fake_camera = self.pose_sampler.get_camera_pose(pitch, yaw, lookat_location=torch.tensor([0, 0, 0.2]), distance_to_orig=distance)[0]
        else:
            fake_idx = sample_idx(img_dir, num_frames)
            fake_c2w = raw_item['c2w'][fake_idx]
            fake_intrinsics = raw_item['intrinsics'][fake_idx]
            fake_camera = np.concatenate([fake_c2w.reshape([16, ]), fake_intrinsics.reshape([9, ])], axis=0)
            fake_camera = convert_to_tensor(fake_camera)
        item.update({
            'ws_camera': ws_camera,
            'real_camera': real_camera,
            'fake_camera': fake_camera,
            # id,exp,euler,trans, used to generate the secc map
        })
        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        """Build a DataLoader over this dataset using our custom collater."""
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater, batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Stack per-sample camera vectors into [B, 25] batch tensors."""
        if len(samples) == 0:
            return {}
        batch = {}
        batch['ffhq_ws_cameras'] = torch.stack([s['ws_camera'] for s in samples], dim=0)
        batch['ffhq_ref_cameras'] = torch.stack([s['real_camera'] for s in samples], dim=0)
        batch['ffhq_mv_cameras'] = torch.stack([s['fake_camera'] for s in samples], dim=0)
        return batch
class Motion2Video_Dataset(Dataset):
    """Dataset yielding paired reference/motion frames for motion-to-video training.

    Each item holds, for a reference ("real") frame and a target ("fake") frame
    of the same video: the gt/head/com/inpaint_torso images, segmentation maps
    and derived head/torso masks, the EG3D camera vector, and 3DMM coefficients
    (id/exp/euler/trans) plus perturbed expressions from a neighboring frame.
    """

    def __init__(self, prefix='train', data_dir=None):
        self.db_key = prefix
        self.ds = None  # lazily-opened IndexedDataset; must be opened per worker
        self.sizes = None
        self.x_maxframes = 200  # 50 video frames
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def __len__(self):
        # Fix: the original re-opened the IndexedDataset on every len() call;
        # open lazily once and reuse, mirroring _get_item.
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(self.ds)

    def _get_item(self, index):
        """Lazily open the binary dataset; necessary to open files in multi-threads!"""
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    def __getitem__(self, idx):
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }
        # Convert per-frame euler/trans into EG3D-convention c2w + intrinsics.
        camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler': convert_to_tensor(raw_item['euler']).cpu(), 'trans': convert_to_tensor(raw_item['trans']).cpu()})
        c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
        raw_item['c2w'] = c2w
        raw_item['intrinsics'] = intrinsics
        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])
        # --- src (reference) frame ---
        real_idx = sample_idx(img_dir, num_frames)
        real_c2w = raw_item['c2w'][real_idx]
        real_intrinsics = raw_item['intrinsics'][real_idx]
        real_camera = np.concatenate([real_c2w.reshape([16, ]), real_intrinsics.reshape([9, ])], axis=0)
        real_camera = convert_to_tensor(real_camera)
        item['real_camera'] = real_camera
        gt_img_fname = find_img_name(img_dir, real_idx)
        gt_img = load_image_as_uint8_tensor(gt_img_fname)[..., :3]  # ignore alpha channel when png
        item['real_gt_img'] = gt_img.float() / 127.5 - 1  # uint8 [0,255] -> float [-1,1]
        # for key in ['head', 'torso', 'torso_with_bg', 'person']:
        for key in ['head', 'com', 'inpaint_torso']:
            key_img_dir = img_dir.replace("/gt_imgs/", f"/{key}_imgs/")
            key_img_fname = find_img_name(key_img_dir, real_idx)
            key_img = load_image_as_uint8_tensor(key_img_fname)[..., :3]  # ignore alpha channel when png
            item[f'real_{key}_img'] = key_img.float() / 127.5 - 1
        # NOTE(review): this yields '<...>/bg_img.jpg' only if img_dir ends with
        # '/gt_imgs/' — confirm the binarizer stores img_dir with trailing slash.
        bg_img_name = img_dir.replace("/gt_imgs/", "/bg_img/") + '.jpg'
        bg_img = load_image_as_uint8_tensor(bg_img_name)[..., :3]  # ignore alpha channel when png
        item['bg_img'] = bg_img.float() / 127.5 - 1
        seg_img_name = gt_img_fname.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png")
        seg_img = cv2.imread(seg_img_name)[:, :, ::-1]  # BGR -> RGB
        segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img))  # [6, H, W]
        item['real_segmap'] = segmap
        # channels [1,3,5] -> head region, [2,4] -> torso region (per decode_segmap convention)
        item['real_head_mask'] = segmap[[1, 3, 5]].sum(dim=0)
        item['real_torso_mask'] = segmap[[2, 4]].sum(dim=0)
        item.update({
            # id,exp,euler,trans, used to generate the secc map
            'real_identity': convert_to_tensor(raw_item['id']).reshape([80, ]),
            # 'real_identity': convert_to_tensor(raw_item['id'][real_idx]).reshape([80,]),
            'real_expression': convert_to_tensor(raw_item['exp'][real_idx]).reshape([64, ]),
            'real_euler': convert_to_tensor(raw_item['euler'][real_idx]).reshape([3, ]),
            'real_trans': convert_to_tensor(raw_item['trans'][real_idx]).reshape([3, ]),
        })
        # perturbed expressions from an adjacent frame (and its mirror about the current exp)
        pertube_idx_candidates = [i for i in [real_idx - 1, real_idx + 1] if (i >= 0 and i <= num_frames - 1)]  # previous frame
        # pertube_idx_candidates = [i for i in [real_idx-2, real_idx-1, real_idx+1, real_idx+2] if (i>=0 and i <= num_frames-1)]
        pertube_idx = random.choice(pertube_idx_candidates)
        item['real_pertube_expression_1'] = convert_to_tensor(raw_item['exp'][pertube_idx]).reshape([64, ])
        item['real_pertube_expression_2'] = item['real_expression'] * 2 - item['real_pertube_expression_1']
        # --- tgt (motion) frame: force it to be temporally distant from src ---
        fake_idx = sample_idx(img_dir, num_frames)
        min_offset = min(50, max((num_frames - 1 - fake_idx) // 2, (fake_idx) // 2))
        while abs(fake_idx - real_idx) < min_offset:
            fake_idx = sample_idx(img_dir, num_frames)
            min_offset = min(50, max((num_frames - 1 - fake_idx) // 2, (fake_idx) // 2))
        fake_c2w = raw_item['c2w'][fake_idx]
        fake_intrinsics = raw_item['intrinsics'][fake_idx]
        fake_camera = np.concatenate([fake_c2w.reshape([16, ]), fake_intrinsics.reshape([9, ])], axis=0)
        fake_camera = convert_to_tensor(fake_camera)
        item['fake_camera'] = fake_camera
        gt_img_fname = find_img_name(img_dir, fake_idx)
        gt_img = load_image_as_uint8_tensor(gt_img_fname)[..., :3]  # ignore alpha channel when png
        item['fake_gt_img'] = gt_img.float() / 127.5 - 1
        seg_img_name = gt_img_fname.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png")
        seg_img = cv2.imread(seg_img_name)[:, :, ::-1]  # BGR -> RGB
        segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img))  # [6, H, W]
        item['fake_segmap'] = segmap
        item['fake_head_mask'] = segmap[[1, 3, 5]].sum(dim=0)
        item['fake_torso_mask'] = segmap[[2, 4]].sum(dim=0)
        # for key in ['head', 'torso', 'torso_with_bg', 'person']:
        for key in ['head', 'com', 'inpaint_torso']:
            key_img_dir = img_dir.replace("/gt_imgs/", f"/{key}_imgs/")
            key_img_fname = find_img_name(key_img_dir, fake_idx)
            key_img = load_image_as_uint8_tensor(key_img_fname)[..., :3]  # ignore alpha channel when png
            item[f'fake_{key}_img'] = key_img.float() / 127.5 - 1
        item.update({
            # id,exp,euler,trans, used to generate the secc map
            'fake_identity': convert_to_tensor(raw_item['id']).reshape([80, ]),
            # 'fake_identity': convert_to_tensor(raw_item['id'][fake_idx]).reshape([80,]),
            'fake_expression': convert_to_tensor(raw_item['exp'][fake_idx]).reshape([64, ]),
            'fake_euler': convert_to_tensor(raw_item['euler'][fake_idx]).reshape([3, ]),
            'fake_trans': convert_to_tensor(raw_item['trans'][fake_idx]).reshape([3, ]),
        })
        # pertube_idx_candidates = [i for i in [fake_idx-2, fake_idx-1, fake_idx+1, fake_idx+2] if (i>=0 and i <= num_frames-1)]
        pertube_idx_candidates = [i for i in [fake_idx - 1, fake_idx + 1] if (i >= 0 and i <= num_frames - 1)]  # previous frame
        pertube_idx = random.choice(pertube_idx_candidates)
        item['fake_pertube_expression_1'] = convert_to_tensor(raw_item['exp'][pertube_idx]).reshape([64, ])
        item['fake_pertube_expression_2'] = item['fake_expression'] * 2 - item['fake_pertube_expression_1']
        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        """Build a DataLoader over this dataset using our custom collater."""
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater, batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Batch per-sample dicts; images go [B,H,W,3] -> [B,3,H,W]."""
        if len(samples) == 0:
            return {}
        batch = {}
        batch['th1kh_item_names'] = [s['item_name'] for s in samples]
        batch['th1kh_ref_gt_imgs'] = torch.stack([s['real_gt_img'] for s in samples]).permute(0, 3, 1, 2)
        batch['th1kh_ref_head_masks'] = torch.stack([s['real_head_mask'] for s in samples])
        batch['th1kh_ref_torso_masks'] = torch.stack([s['real_torso_mask'] for s in samples])
        batch['th1kh_ref_segmaps'] = torch.stack([s['real_segmap'] for s in samples])  # [B,6,H,W]
        for key in ['head', 'com', 'inpaint_torso']:
            batch[f'th1kh_ref_{key}_imgs'] = torch.stack([s[f'real_{key}_img'] for s in samples]).permute(0, 3, 1, 2)
        batch['th1kh_bg_imgs'] = torch.stack([s['bg_img'] for s in samples]).permute(0, 3, 1, 2)
        batch['th1kh_ref_cameras'] = torch.stack([s['real_camera'] for s in samples], dim=0)
        batch['th1kh_ref_ids'] = torch.stack([s['real_identity'] for s in samples], dim=0)
        batch['th1kh_ref_exps'] = torch.stack([s['real_expression'] for s in samples], dim=0)
        batch['th1kh_ref_eulers'] = torch.stack([s['real_euler'] for s in samples], dim=0)
        batch['th1kh_ref_trans'] = torch.stack([s['real_trans'] for s in samples], dim=0)
        batch['th1kh_mv_gt_imgs'] = torch.stack([s['fake_gt_img'] for s in samples]).permute(0, 3, 1, 2)
        for key in ['head', 'com', 'inpaint_torso']:
            batch[f'th1kh_mv_{key}_imgs'] = torch.stack([s[f'fake_{key}_img'] for s in samples]).permute(0, 3, 1, 2)
        batch['th1kh_mv_head_masks'] = torch.stack([s['fake_head_mask'] for s in samples])
        batch['th1kh_mv_torso_masks'] = torch.stack([s['fake_torso_mask'] for s in samples])
        batch['th1kh_mv_cameras'] = torch.stack([s['fake_camera'] for s in samples], dim=0)
        batch['th1kh_mv_ids'] = torch.stack([s['fake_identity'] for s in samples], dim=0)
        batch['th1kh_mv_exps'] = torch.stack([s['fake_expression'] for s in samples], dim=0)
        batch['th1kh_mv_eulers'] = torch.stack([s['fake_euler'] for s in samples], dim=0)
        batch['th1kh_mv_trans'] = torch.stack([s['fake_trans'] for s in samples], dim=0)
        batch['th1kh_ref_pertube_exps_1'] = torch.stack([s['real_pertube_expression_1'] for s in samples], dim=0)
        batch['th1kh_ref_pertube_exps_2'] = torch.stack([s['real_pertube_expression_2'] for s in samples], dim=0)
        batch['th1kh_mv_pertube_exps_1'] = torch.stack([s['fake_pertube_expression_1'] for s in samples], dim=0)
        batch['th1kh_mv_pertube_exps_2'] = torch.stack([s['fake_pertube_expression_2'] for s in samples], dim=0)
        return batch
if __name__ == '__main__':
    # Keep OpenMP to a single thread so dataloader workers don't oversubscribe CPUs.
    os.environ["OMP_NUM_THREADS"] = "1"
    dataset = Img2Plane_Dataset("train", 'data/binary/th1kh')
    # dataset = Motion2Video_Dataset("train", 'data/binary/th1kh')
    loader = dataset.get_dataloader()
    for _ in tqdm(loader):
        pass