Spaces:
Sleeping
Sleeping
| import torch | |
| import cv2 | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
| from torchvision.transforms import Normalize | |
| from common import constants | |
def mask_split(img, num_parts):
    """Convert an integer label image into a stack of binary masks.

    Args:
        img: Array of label ids, either ``(H, W)`` or ``(H, W, C)``. For a
            multi-channel array only channel 0 is read.
        num_parts: Number of output channels; valid label ids are
            ``0 .. num_parts - 1``.

    Returns:
        Float array of shape ``(num_parts, H, W)`` where channel ``k`` is 1.0
        wherever ``img == k`` and 0.0 elsewhere.

    Raises:
        IndexError: If ``img`` contains a label id outside
            ``[0, num_parts)``.
    """
    if img.ndim != 2:
        # Multi-channel label image: the label id is taken from channel 0.
        img = img[:, :, 0]

    # np.unique returns the ids sorted, so the ends give the min and max.
    labels = np.unique(img)
    if labels.size and (labels[0] < 0 or labels[-1] >= num_parts):
        # Without this check a negative id would silently wrap around and
        # fill a channel counted from the end instead of failing.
        raise IndexError(
            f'label ids must lie in [0, {num_parts}), '
            f'got values from {labels[0]} to {labels[-1]}'
        )

    mask = np.zeros((img.shape[0], img.shape[1], num_parts))
    for label in labels:
        mask[:, :, label] = np.where(img == label, 1., 0.)
    # Channels-first layout: (H, W, num_parts) -> (num_parts, H, W).
    return np.transpose(mask, (2, 0, 1))
class BaseDataset(Dataset):
    """Image dataset with optional SMPL, contact and segmentation labels.

    Annotations are read from the archive at
    ``constants.DATASET_FILES[mode][dataset]``. The keys ``imgname``,
    ``scene_seg`` and ``part_seg`` are required. ``contact_label``,
    ``polygon_2d_contact``, ``cam_k`` and the ``pose``/``shape``/``transl``
    triple are optional; when one is missing, zero placeholders are served
    instead (and the matching ``has_*`` flag, where one exists, is 0).
    """

    # Side length in pixels that every image and mask is resized to.
    _IMG_RES = 256
    # Channel counts passed to mask_split for the two segmentation maps.
    _NUM_SCENE_CLASSES = 133
    _NUM_PART_CLASSES = 26

    def __init__(self, dataset, mode, model_type='smpl', normalize=False):
        """Load the annotation archive and set up per-sample lookup arrays.

        Args:
            dataset: Key of the dataset inside ``constants.DATASET_FILES[mode]``.
            mode: Key selecting the entry of ``constants.DATASET_FILES``.
            model_type: ``'smpl'`` or ``'smplx'``. Selects the length of the
                placeholder contact vector and the value of ``is_smplx``.
            normalize: When true, ``__getitem__`` normalizes the image with
                ``constants.IMG_NORM_MEAN`` / ``constants.IMG_NORM_STD``.

        Raises:
            NotImplementedError: If ``model_type`` is not one of the two
                supported values.
        """
        self.dataset = dataset
        self.mode = mode

        dataset_file = constants.DATASET_FILES[mode][dataset]
        print(f'Loading dataset: {dataset_file} for mode: {mode}')
        # NOTE: allow_pickle=True lets the archive execute arbitrary pickled
        # objects on load -- only point this at trusted files.
        self.data = np.load(dataset_file, allow_pickle=True)

        self.images = self.data['imgname']
        num_samples = len(self.images)

        # 3D contact labels, if available. A sample counts as labelled only
        # when its label array is non-empty.
        try:
            self.contact_labels_3d = self.data['contact_label']
        except KeyError:
            self.has_contact_3d = np.zeros(num_samples)
        else:
            self.has_contact_3d = np.array(
                [1 if len(x) > 0 else 0 for x in self.contact_labels_3d]
            )

        # 2D polygon contact labels, if available.
        try:
            self.polygon_contacts_2d = self.data['polygon_2d_contact']
        except KeyError:
            self.has_polygon_contact_2d = np.zeros(num_samples)
        else:
            self.has_polygon_contact_2d = np.ones(num_samples)

        # Camera parameters - only intrinsics for now.
        try:
            self.cam_k = self.data['cam_k']
        except KeyError:
            self.cam_k = np.zeros((num_samples, 3, 3))

        self.sem_masks = self.data['scene_seg']
        self.part_masks = self.data['part_seg']

        # Ground-truth body model parameters, if available. All three arrays
        # must be present for any sample to be treated as having them.
        try:
            self.pose = self.data['pose'].astype(float)
            self.betas = self.data['shape'].astype(float)
            self.transl = self.data['transl'].astype(float)
        except KeyError:
            self.has_smpl = np.zeros(num_samples)
            self.is_smplx = np.zeros(num_samples)
        else:
            if 'has_smpl' in self.data:
                self.has_smpl = self.data['has_smpl']
            else:
                self.has_smpl = np.ones(num_samples)
            if model_type == 'smplx':
                self.is_smplx = np.ones(num_samples)
            else:
                self.is_smplx = np.zeros(num_samples)

        # Length of the placeholder per-vertex contact vector.
        if model_type == 'smpl':
            self.n_vertices = 6890
        elif model_type == 'smplx':
            self.n_vertices = 10475
        else:
            raise NotImplementedError

        self.normalize = normalize
        self.normalize_img = Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)

    @staticmethod
    def _read_image(path, label):
        """Read ``path`` with OpenCV; raise OSError naming ``label`` on failure."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports a missing/undecodable file by returning None.
            raise OSError(f'{label}: could not read image file {path}')
        return image

    @classmethod
    def _resize(cls, image, interpolation):
        """Resize ``image`` to the square working resolution."""
        # The flag must be passed by keyword: the third positional parameter
        # of cv2.resize is ``dst``, so a positional flag is silently ignored.
        return cv2.resize(image, (cls._IMG_RES, cls._IMG_RES), interpolation=interpolation)

    def _load_label_mask(self, path, label, num_parts):
        """Load a label image and split it into ``num_parts`` binary channels."""
        # Nearest-neighbour keeps label ids intact; any blending interpolation
        # would invent ids that belong to unrelated classes along boundaries.
        label_img = self._resize(self._read_image(path, label), cv2.INTER_NEAREST)
        try:
            return mask_split(label_img, num_parts)
        except IndexError as err:
            raise ValueError(
                f'{label}: label ids in {path} do not fit into {num_parts} channels'
            ) from err

    def __getitem__(self, index):
        """Assemble the sample at ``index``.

        Returns:
            dict with the keys ``img``, ``img_path``, ``pose``, ``betas``,
            ``transl``, ``cam_k``, ``img_scale_factor``, ``contact_label_3d``,
            ``sem_mask``, ``part_mask``, ``polygon_contact_2d``, ``has_smpl``,
            ``is_smplx``, ``has_contact_3d`` and ``has_polygon_contact_2d``.

        Raises:
            OSError: If the image or one of the mask files cannot be read.
            ValueError: If a segmentation mask holds label ids that do not
                fit into its channel count.
        """
        item = {}
        res = self._IMG_RES

        # Load image.
        # NOTE(review): cv2.imread yields BGR channel order -- confirm that
        # constants.IMG_NORM_MEAN / IMG_NORM_STD are given in the same order.
        img_path = self.images[index]
        raw_img = self._read_image(img_path, 'Img')
        img_h, img_w = raw_img.shape[:2]
        img = self._resize(raw_img, cv2.INTER_CUBIC)
        # HWC in [0, 255] -> CHW in [0, 1].
        img = img.transpose(2, 0, 1) / 255.0
        # Per-axis (x, y) factor from original pixels to resized pixels.
        img_scale_factor = np.array([res / img_w, res / img_h])

        # Body model parameters, if available.
        if self.has_smpl[index]:
            pose = self.pose[index].copy()
            betas = self.betas[index].copy()
            transl = self.transl[index].copy()
            if self.is_smplx[index]:
                # Add 6 zeros to the end of the pose vector to match with smpl.
                # Only loaded poses are padded: the zero placeholder below
                # already has the SMPL length.
                pose = np.concatenate((pose, np.zeros(6)))
        else:
            pose = np.zeros(72)
            betas = np.zeros(10)
            transl = np.zeros(3)

        # Per-vertex contact labels, or an all-zero placeholder.
        if self.has_contact_3d[index]:
            contact_label_3d = self.contact_labels_3d[index]
        else:
            contact_label_3d = np.zeros(self.n_vertices)

        sem_mask = self._load_label_mask(
            self.sem_masks[index], 'Scene seg', self._NUM_SCENE_CLASSES
        )
        part_mask = self._load_label_mask(
            self.part_masks[index], 'Part seg', self._NUM_PART_CLASSES
        )

        if self.has_polygon_contact_2d[index]:
            polygon_img = self._resize(
                self._read_image(self.polygon_contacts_2d[index], '2D polygon contact'),
                cv2.INTER_NEAREST,
            )
            # Binarize: any non-zero pixel counts as contact.
            polygon_contact_2d = np.where(polygon_img > 0, 1, 0)
        else:
            polygon_contact_2d = np.zeros((res, res, 3))

        img = torch.tensor(img, dtype=torch.float32)
        item['img'] = self.normalize_img(img) if self.normalize else img

        item['img_path'] = img_path
        item['pose'] = torch.tensor(pose, dtype=torch.float32)
        item['betas'] = torch.tensor(betas, dtype=torch.float32)
        item['transl'] = torch.tensor(transl, dtype=torch.float32)
        item['cam_k'] = self.cam_k[index]
        item['img_scale_factor'] = torch.tensor(img_scale_factor, dtype=torch.float32)
        item['contact_label_3d'] = torch.tensor(contact_label_3d, dtype=torch.float32)
        item['sem_mask'] = torch.tensor(sem_mask, dtype=torch.float32)
        item['part_mask'] = torch.tensor(part_mask, dtype=torch.float32)
        item['polygon_contact_2d'] = torch.tensor(polygon_contact_2d, dtype=torch.float32)
        item['has_smpl'] = self.has_smpl[index]
        item['is_smplx'] = self.is_smplx[index]
        item['has_contact_3d'] = self.has_contact_3d[index]
        item['has_polygon_contact_2d'] = self.has_polygon_contact_2d[index]
        return item

    def __len__(self):
        """Return the number of samples (one per entry of ``imgname``)."""
        return len(self.images)