# NOTE(review): removed page-scrape artifact ("Spaces: / Sleeping / Sleeping") that preceded the source.
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import pytorch3d
import torch
from omegaconf import DictConfig
from pytorch3d.datasets import R2N2, collate_batched_meshes
from pytorch3d.implicitron.dataset.data_loader_map_provider import \
    SequenceDataLoaderMapProvider
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2, registry)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import CamerasBase
from torch.utils.data import DataLoader, SequentialSampler

from configs.structured import (CO3DConfig, DataloaderConfig, Optional,
                                ProjectConfig, ShapeNetR2N2Config)
from .utils import DatasetMap
def get_dataset(cfg: ProjectConfig):
    """Construct the train/val/vis dataloaders selected by ``cfg.dataset.type``.

    Supported types: ``'co3dv2'``, ``'shapenet_r2n2'``, the ``'behave*'``
    family, and ``'shape'``.

    Args:
        cfg: Full project configuration; this reads ``cfg.dataset``,
            ``cfg.dataloader``, ``cfg.run``, ``cfg.model`` and
            ``cfg.augmentations`` (the latter only for R2N2).

    Returns:
        Tuple ``(dataloader_train, dataloader_val, dataloader_vis)``.

    Raises:
        NotImplementedError: If ``cfg.dataset.type`` is not recognized.
    """
    if cfg.dataset.type == 'co3dv2':
        from .exclude_sequence import EXCLUDE_SEQUENCE, LOW_QUALITY_SEQUENCE
        dataset_cfg: CO3DConfig = cfg.dataset
        dataloader_cfg: DataloaderConfig = cfg.dataloader

        # Exclude bad and low-quality sequences for this category.
        # NOTE(review): original author asked "XH: why this is needed?" — unresolved.
        exclude_sequence = []
        exclude_sequence.extend(EXCLUDE_SEQUENCE.get(dataset_cfg.category, []))
        exclude_sequence.extend(LOW_QUALITY_SEQUENCE.get(dataset_cfg.category, []))

        # Arguments forwarded to JsonIndexDataset (point clouds are always loaded).
        kwargs = dict(
            remove_empty_masks=True,
            n_frames_per_sequence=1,
            load_point_clouds=True,
            max_points=dataset_cfg.max_points,
            image_height=dataset_cfg.image_size,
            image_width=dataset_cfg.image_size,
            mask_images=dataset_cfg.mask_images,
            exclude_sequence=exclude_sequence,
            pick_sequence=() if dataset_cfg.restrict_model_ids is None else dataset_cfg.restrict_model_ids,
        )

        # Build the dataset map provider.
        dataset_map_provider_type = registry.get(JsonIndexDatasetMapProviderV2, "JsonIndexDatasetMapProviderV2")
        expand_args_fields(dataset_map_provider_type)
        dataset_map_provider = dataset_map_provider_type(
            category=dataset_cfg.category,
            subset_name=dataset_cfg.subset_name,
            dataset_root=dataset_cfg.root,
            test_on_train=False,
            only_test_set=False,
            load_eval_batches=True,
            dataset_JsonIndexDataset_args=DictConfig(kwargs),
        )

        # Get datasets (how to select specific frames??)
        datasets = dataset_map_provider.get_dataset_map()

        # PATCH BUG WITH POINT CLOUD LOCATIONS: re-root each annotation's
        # point-cloud path under dataset_root, keeping its last three components.
        for dataset in (datasets["train"], datasets["val"]):
            for key, ann in dataset.seq_annots.items():
                correct_point_cloud_path = Path(dataset.dataset_root) / Path(*Path(ann.point_cloud.path).parts[-3:])
                assert correct_point_cloud_path.is_file(), correct_point_cloud_path
                ann.point_cloud.path = str(correct_point_cloud_path)

        # Build the dataloader map provider.
        data_loader_map_provider_type = registry.get(SequenceDataLoaderMapProvider, "SequenceDataLoaderMapProvider")
        expand_args_fields(data_loader_map_provider_type)
        data_loader_map_provider = data_loader_map_provider_type(
            batch_size=dataloader_cfg.batch_size,
            num_workers=dataloader_cfg.num_workers,
        )

        # QUICK HACK: Patch the train dataset because it is not used but it throws an error.
        # NOTE(review): all three splits get the eval split here ("XH: why all eval split?").
        if (len(datasets['train']) == 0 and len(datasets[dataset_cfg.eval_split]) > 0 and
                dataset_cfg.restrict_model_ids is not None and cfg.run.job == 'sample'):
            datasets = DatasetMap(train=datasets[dataset_cfg.eval_split], val=datasets[dataset_cfg.eval_split],
                                  test=datasets[dataset_cfg.eval_split])
            print('Note: You used restrict_model_ids and there were no ids in the train set.')

        # Get dataloaders.
        dataloaders = data_loader_map_provider.get_data_loader_map(datasets)
        dataloader_train = dataloaders['train']
        dataloader_val = dataloader_vis = dataloaders[dataset_cfg.eval_split]

        # Replace the validation sampler with a SequentialSampler so evaluation
        # order is deterministic (the provider's default appears to sample
        # randomly with a fixed seed — TODO confirm).
        dataloader_val.batch_sampler.sampler = SequentialSampler(dataloader_val.batch_sampler.sampler.data_source)

        # Modify for accelerate: train drops the ragged last batch, val keeps it.
        dataloader_train.batch_sampler.drop_last = True
        dataloader_val.batch_sampler.drop_last = False

    elif cfg.dataset.type == 'shapenet_r2n2':
        from .r2n2_my import R2N2Sample
        # fix: the annotation below referenced ShapeNetR2N2Config whose import
        # had been commented out, leaving an unresolved name.
        from configs.structured import ShapeNetR2N2Config

        dataset_cfg: ShapeNetR2N2Config = cfg.dataset
        datasets = [R2N2Sample(dataset_cfg.max_points, dataset_cfg.fix_sample,
                               dataset_cfg.image_size, cfg.augmentations,
                               s, dataset_cfg.shapenet_dir,
                               dataset_cfg.r2n2_dir, dataset_cfg.splits_file,
                               load_textures=False, return_all_views=True) for s in ['train', 'val', 'test']]
        dataloader_train = DataLoader(datasets[0], batch_size=cfg.dataloader.batch_size,
                                      collate_fn=collate_batched_meshes,
                                      num_workers=cfg.dataloader.num_workers, shuffle=True)
        dataloader_val = DataLoader(datasets[1], batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=False)
        dataloader_vis = DataLoader(datasets[2], batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=False)

    elif cfg.dataset.type in ['behave', 'behave-objonly', 'behave-humonly', 'behave-dtransl',
                              'behave-objonly-segm', 'behave-humonly-segm', 'behave-attn',
                              'behave-test', 'behave-attn-test', 'behave-hum-pe', 'behave-hum-noscale',
                              'behave-hum-surf', 'behave-objv2v']:
        from .behave_dataset import BehaveDataset, NTUDataset, BehaveObjOnly, BehaveHumanOnly, BehaveHumanOnlyPosEnc
        from .behave_dataset import BehaveHumanOnlySegmInput, BehaveObjOnlySegmInput, BehaveTestOnly, BehaveHumNoscale
        from .behave_dataset import BehaveHumanOnlySurfSample
        from .behave_dataset import BehaveObjOnlyV2V
        from .dtransl_dataset import DirectTranslDataset
        from .behave_paths import DataPaths
        from .behave_crossattn import BehaveCrossAttnDataset, BehaveCrossAttnTest
        from configs.structured import BehaveDatasetConfig

        dataset_cfg: BehaveDatasetConfig = cfg.dataset
        train_paths, val_paths = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir)

        # Restrict validation paths to the selected batch window
        # [run.batch_start, run.batch_end), measured in batches of size bs.
        bs = cfg.dataloader.batch_size
        num_batches_total = int(np.ceil(len(val_paths) / cfg.dataloader.batch_size))
        end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total
        val_paths = val_paths[cfg.run.batch_start * bs:end_idx * bs]

        # Pick the train/val dataset classes for the requested behave variant.
        if cfg.dataset.type == 'behave':
            train_type = BehaveDataset
            val_datatype = BehaveDataset if 'ntu' not in dataset_cfg.split_file else NTUDataset
        elif cfg.dataset.type == 'behave-test':
            train_type = BehaveDataset
            val_datatype = BehaveTestOnly
        elif cfg.dataset.type == 'behave-objonly':
            train_type = BehaveObjOnly
            val_datatype = BehaveObjOnly
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-humonly':
            train_type = BehaveHumanOnly
            val_datatype = BehaveHumanOnly
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-hum-noscale':
            train_type = BehaveHumNoscale
            val_datatype = BehaveHumNoscale
        elif cfg.dataset.type == 'behave-hum-pe':
            train_type = BehaveHumanOnlyPosEnc
            val_datatype = BehaveHumanOnlyPosEnc
        elif cfg.dataset.type == 'behave-hum-surf':
            train_type = BehaveHumanOnlySurfSample
            val_datatype = BehaveHumanOnlySurfSample
        elif cfg.dataset.type == 'behave-humonly-segm':
            assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!'
            train_type = BehaveHumanOnly
            val_datatype = BehaveHumanOnlySegmInput
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-objonly-segm':
            assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!'
            train_type = BehaveObjOnly
            val_datatype = BehaveObjOnlySegmInput
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-dtransl':
            train_type = DirectTranslDataset
            val_datatype = DirectTranslDataset
        elif cfg.dataset.type == 'behave-attn':
            train_type = BehaveCrossAttnDataset
            val_datatype = BehaveCrossAttnDataset
        elif cfg.dataset.type == 'behave-attn-test':
            train_type = BehaveCrossAttnDataset
            val_datatype = BehaveCrossAttnTest
        elif cfg.dataset.type == 'behave-objv2v':
            train_type = BehaveObjOnlyV2V
            val_datatype = BehaveObjOnlyV2V
        else:
            raise NotImplementedError

        # Training always uses ground-truth SMPL and takes the train-only
        # augmentation arguments (camera noise, blur).
        dataset_train = train_type(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                   (dataset_cfg.image_size, dataset_cfg.image_size),
                                   split='train', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
                                   normalize_type=dataset_cfg.normalize_type, smpl_type='gt',
                                   load_corr_points=dataset_cfg.load_corr_points,
                                   uniform_obj_sample=dataset_cfg.uniform_obj_sample,
                                   bkg_type=dataset_cfg.bkg_type,
                                   bbox_params=dataset_cfg.bbox_params,
                                   pred_binary=cfg.model.predict_binary,
                                   ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
                                   compute_closest_points=cfg.model.model_name == 'pc2-diff-ho-tune-newloss',
                                   use_gt_transl=cfg.dataset.use_gt_transl,
                                   cam_noise_std=cfg.dataset.cam_noise_std,
                                   sep_same_crop=cfg.dataset.sep_same_crop,
                                   aug_blur=cfg.dataset.aug_blur,
                                   std_coverage=cfg.dataset.std_coverage,
                                   v2v_path=cfg.dataset.v2v_path)
        # Validation uses the configured SMPL/translation sources instead of GT.
        dataset_val = val_datatype(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                   (dataset_cfg.image_size, dataset_cfg.image_size),
                                   split='val', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
                                   normalize_type=dataset_cfg.normalize_type, smpl_type=dataset_cfg.smpl_type,
                                   load_corr_points=dataset_cfg.load_corr_points,
                                   test_transl_type=dataset_cfg.test_transl_type,
                                   uniform_obj_sample=dataset_cfg.uniform_obj_sample,
                                   bkg_type=dataset_cfg.bkg_type,
                                   bbox_params=dataset_cfg.bbox_params,
                                   pred_binary=cfg.model.predict_binary,
                                   ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
                                   compute_closest_points=cfg.model.model_name == 'pc2-diff-ho-tune-newloss',
                                   use_gt_transl=cfg.dataset.use_gt_transl,
                                   sep_same_crop=cfg.dataset.sep_same_crop,
                                   std_coverage=cfg.dataset.std_coverage,
                                   v2v_path=cfg.dataset.v2v_path)
        dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size,
                                      collate_fn=collate_batched_meshes,
                                      num_workers=cfg.dataloader.num_workers, shuffle=True)
        # Shuffle val/vis only during training; sampling/eval jobs keep a
        # deterministic order so batch windows are reproducible.
        shuffle = cfg.run.job == 'train'
        dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
        dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)

    elif cfg.dataset.type in ['shape']:
        from .shape_dataset import ShapeDataset
        from .behave_paths import DataPaths
        from configs.structured import ShapeDatasetConfig

        dataset_cfg: ShapeDatasetConfig = cfg.dataset
        train_paths, _ = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir)
        val_paths = train_paths  # same as training, this is for overfitting

        # Restrict validation paths to the selected batch window (see behave branch).
        bs = cfg.dataloader.batch_size
        num_batches_total = int(np.ceil(len(val_paths) / cfg.dataloader.batch_size))
        end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total
        val_paths = val_paths[cfg.run.batch_start * bs:end_idx * bs]

        # Both datasets use split='train' intentionally (overfitting setup).
        dataset_train = ShapeDataset(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                     (dataset_cfg.image_size, dataset_cfg.image_size),
                                     split='train', )
        dataset_val = ShapeDataset(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                   (dataset_cfg.image_size, dataset_cfg.image_size),
                                   split='train', )
        dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size,
                                      collate_fn=collate_batched_meshes,
                                      num_workers=cfg.dataloader.num_workers, shuffle=True)
        shuffle = cfg.run.job == 'train'
        dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
        dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
    else:
        raise NotImplementedError(cfg.dataset.type)

    return dataloader_train, dataloader_val, dataloader_vis