Spaces:
Configuration error
Configuration error
| from .transforms import make_transforms | |
| from . import samplers | |
| import torch | |
| import torch.utils.data | |
| import imp | |
| import os | |
| from .collate_batch import make_collator | |
| import numpy as np | |
| import time | |
| from lib.config.config import cfg | |
def _dataset_factory(is_train):
    """Instantiate the dataset described by the global ``cfg``.

    The module name, source-file path and constructor keyword arguments are
    read from the ``train_dataset_*`` entries of ``cfg`` when ``is_train`` is
    true and from the ``test_dataset_*`` entries otherwise. The source file is
    loaded with ``imp.load_source`` and its ``Dataset`` class is called with
    the keyword arguments.

    Args:
        is_train: Selects the training (True) or test (False) settings.

    Returns:
        The object produced by ``Dataset(**args)`` from the loaded module.
    """
    if is_train:
        module_name, source_path, dataset_kwargs = (
            cfg.train_dataset_module,
            cfg.train_dataset_path,
            cfg.train_dataset,
        )
    else:
        module_name, source_path, dataset_kwargs = (
            cfg.test_dataset_module,
            cfg.test_dataset_path,
            cfg.test_dataset,
        )

    # NOTE(review): ``imp`` is deprecated and was removed in Python 3.12;
    # migrating this file to ``importlib`` is worth considering.
    loaded_module = imp.load_source(module_name, source_path)
    return loaded_module.Dataset(**dataset_kwargs)
def make_dataset(cfg, dataset_name, transforms, is_train=True):
    """Return the dataset for the requested split.

    Only ``is_train`` is used: it is forwarded to ``_dataset_factory``. The
    ``cfg``, ``dataset_name`` and ``transforms`` parameters are accepted but
    not read by this function.

    Args:
        cfg: Unused here.
        dataset_name: Unused here.
        transforms: Unused here.
        is_train: Selects the training (True) or test (False) dataset.

    Returns:
        The dataset built by ``_dataset_factory``.
    """
    return _dataset_factory(is_train)
def make_data_sampler(dataset, shuffle, is_distributed, is_train):
    """Choose the index sampler for ``dataset``.

    Selection order:
      1. Not training and ``cfg.test.sampler == 'FrameSampler'``:
         ``samplers.FrameSampler``.
      2. ``is_distributed``: ``samplers.DistributedSampler`` (receives
         ``shuffle``).
      3. Otherwise torch's ``RandomSampler`` when ``shuffle`` is truthy,
         else ``SequentialSampler``.

    Args:
        dataset: Dataset the sampler draws indices for.
        shuffle: Whether indices should be shuffled.
        is_distributed: Whether a distributed sampler is required.
        is_train: Whether the sampler is for training; the global ``cfg`` is
            consulted only when this is false.

    Returns:
        The constructed sampler.
    """
    # ``cfg`` is only touched for evaluation, thanks to short-circuiting.
    wants_frame_sampler = (not is_train) and cfg.test.sampler == 'FrameSampler'
    if wants_frame_sampler:
        return samplers.FrameSampler(dataset)

    if is_distributed:
        return samplers.DistributedSampler(dataset, shuffle=shuffle)

    torch_samplers = torch.utils.data.sampler
    sampler_cls = (torch_samplers.RandomSampler
                   if shuffle else torch_samplers.SequentialSampler)
    return sampler_cls(dataset)
def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter,
                            is_train):
    """Build the batch sampler named in the train/test configuration.

    Args:
        cfg: Configuration object. ``cfg.train`` (when ``is_train``) or
            ``cfg.test`` must expose ``batch_sampler`` and ``sampler_meta``.
        sampler: Index sampler whose output is grouped into batches.
        batch_size: Forwarded to the batch sampler constructor.
        drop_last: Forwarded to the batch sampler constructor.
        max_iter: When not ``-1``, the batch sampler is wrapped in
            ``samplers.IterationBasedBatchSampler`` with this value.
        is_train: Selects ``cfg.train`` (True) or ``cfg.test`` (False).

    Returns:
        A ``torch.utils.data.sampler.BatchSampler`` for ``'default'``, a
        ``samplers.ImageSizeBatchSampler`` for ``'image_size'``, optionally
        wrapped as described for ``max_iter``. A non-string configured value
        is passed through unchanged.

    Raises:
        ValueError: If the configured ``batch_sampler`` is a string other
            than ``'default'`` or ``'image_size'``.
    """
    split_cfg = cfg.train if is_train else cfg.test
    batch_sampler = split_cfg.batch_sampler
    sampler_meta = split_cfg.sampler_meta

    if batch_sampler == 'default':
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, batch_size, drop_last)
    elif batch_sampler == 'image_size':
        batch_sampler = samplers.ImageSizeBatchSampler(
            sampler, batch_size, drop_last, sampler_meta)
    elif isinstance(batch_sampler, str):
        # Previously an unknown name fell through and the raw string was
        # handed to the DataLoader, failing far from the real cause.
        raise ValueError(
            "Unknown batch_sampler {!r}; expected 'default' or "
            "'image_size'".format(batch_sampler))

    if max_iter != -1:
        batch_sampler = samplers.IterationBasedBatchSampler(
            batch_sampler, max_iter)
    return batch_sampler
def worker_init_fn(worker_id):
    """Seed NumPy's global RNG for a data-loading worker.

    The seed is ``worker_id`` plus the current wall-clock time in
    milliseconds reduced modulo ``2**16``.

    Args:
        worker_id: Integer worker index added to the time-based offset.
    """
    now_ms = round(time.time() * 1000)
    time_offset = int(now_ms % (2 ** 16))
    np.random.seed(worker_id + time_offset)
def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1):
    """Assemble a ``torch.utils.data.DataLoader`` for training or testing.

    The batch size and dataset name come from ``cfg.train`` or ``cfg.test``
    according to ``is_train``. Training shuffles according to
    ``cfg.train.shuffle``; testing shuffles only when ``is_distributed`` is
    truthy. Incomplete final batches are never dropped (``drop_last`` is
    always ``False``).

    Args:
        cfg: Configuration object exposing ``train`` and ``test`` sections.
        is_train: Build the training loader (True) or the test loader.
        is_distributed: Passed to ``make_data_sampler``; also enables
            shuffling for the test loader.
        max_iter: Passed to ``make_batch_data_sampler``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    split_cfg = cfg.train if is_train else cfg.test

    batch_size = split_cfg.batch_size
    if is_train:
        shuffle = cfg.train.shuffle
    else:
        shuffle = bool(is_distributed)
    drop_last = False

    dataset_name = split_cfg.dataset
    transforms = make_transforms(cfg, is_train)
    dataset = make_dataset(cfg, dataset_name, transforms, is_train)

    sampler = make_data_sampler(dataset, shuffle, is_distributed, is_train)
    batch_sampler = make_batch_data_sampler(
        cfg, sampler, batch_size, drop_last, max_iter, is_train)

    # NOTE(review): the worker count is taken from the training section even
    # for the test loader -- confirm this is intended.
    num_workers = cfg.train.num_workers
    collator = make_collator(cfg, is_train)

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=collator,
        worker_init_fn=worker_init_fn,
    )