import torch
import numpy as np
import numpy.random as npr
import torch.distributed as dist
import math
from ...log_service import print_log
from ... import sync

def singleton(class_):
    # Cache one instance per class so repeated calls return the same object.
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance

@singleton
class get_sampler(object):
    def __init__(self):
        self.sampler = {}

    def register(self, sampler):
        self.sampler[sampler.__name__] = sampler

    def __call__(self, dataset, cfg):
        if cfg == 'default_train':
            return GlobalDistributedSampler(dataset, shuffle=True, extend=False)
        elif cfg == 'default_eval':
            return GlobalDistributedSampler(dataset, shuffle=False, extend=True)
        else:
            t = cfg.type
            return self.sampler[t](dataset=dataset, **cfg.args)

def register():
    def wrapper(class_):
        # Adding @register() to a sampler class records it in the shared
        # (singleton) get_sampler registry so it can be resolved by name.
        get_sampler().register(class_)
        return class_
    return wrapper

######################
# DistributedSampler #
######################

class GlobalDistributedSampler(torch.utils.data.Sampler):
    """
    A distributed sampler that syncs the sampling order across all GPUs and nodes.
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        """
        Arguments:
            dataset: Dataset used for sampling.
            shuffle: If True, the sampler shuffles the indices.
            extend: If True, the sampler repeats leading indices so the total count
                divides evenly across ranks; otherwise it truncates the indices to
                make the count even.
        """
        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('global')
        self.world_size = sync.get_world_size('global')
        self.dataset = dataset
        self.shuffle = shuffle
        self.extend = extend
        num_samples = len(dataset) // self.world_size
        if extend and (len(dataset) % self.world_size != 0):
            num_samples += 1
        self.num_samples = num_samples
        self.total_size = num_samples * self.world_size

    def __iter__(self):
        indices = self.get_sync_order()
        if self.extend:
            # extend using the front indices
            indices = indices + indices[0:self.total_size-len(indices)]
        else:
            # truncate
            indices = indices[0:self.total_size]
        # subsample: each rank takes every world_size-th index starting at its rank
        indices = indices[self.rank : len(indices) : self.world_size]
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def get_sync_order(self):
        if self.shuffle:
            # rank 0 draws the permutation and broadcasts it so all ranks agree
            indices = torch.randperm(len(self.dataset)).to(self.rank)
            if self.ddp:
                dist.broadcast(indices, src=0)
            indices = indices.to('cpu').tolist()
        else:
            indices = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(indices[0:5])))
        return indices
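
# The function below is an illustrative usage sketch, not part of the original
# module: it shows how GlobalDistributedSampler would typically be handed to a
# standard PyTorch DataLoader. The helper name and batch size are hypothetical.
def _demo_dataloader_with_global_sampler(dataset, batch_size=4):
    # The sampler decides the per-rank ordering, so the DataLoader must not
    # shuffle on its own (shuffle defaults to False when a sampler is given).
    sampler = GlobalDistributedSampler(dataset, shuffle=True, extend=False)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         sampler=sampler,
                                         num_workers=0)
    return loader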

class LocalDistributedSampler(GlobalDistributedSampler):
    """
    A distributed sampler that syncs the sampling order across GPUs within a node,
    but not across nodes.
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        super().__init__(dataset, shuffle, extend)
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')
        # The parent computed num_samples/total_size with the global world size;
        # recompute them with the local world size used for subsampling here.
        num_samples = len(dataset) // self.world_size
        if extend and (len(dataset) % self.world_size != 0):
            num_samples += 1
        self.num_samples = num_samples
        self.total_size = num_samples * self.world_size

    def get_sync_order(self):
        if self.shuffle:
            if self.rank == 0:
                indices = list(npr.permutation(len(self.dataset)))
                sync.nodewise_sync().broadcast_r0(indices)
            else:
                indices = sync.nodewise_sync().broadcast_r0(None)
        else:
            indices = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(indices[0:5])))
        return indices
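
# The function below is an illustrative sketch, not part of the original module:
# it shows the two ways get_sampler resolves a sampler, via the built-in string
# shortcuts and via a class registered by name. `DemoLocalSampler` and the
# SimpleNamespace config are hypothetical stand-ins for a project config object
# that exposes `.type` and `.args`.
def _demo_registry_usage(dataset):
    from types import SimpleNamespace

    @register()
    class DemoLocalSampler(LocalDistributedSampler):
        pass

    # string shortcut handled explicitly in get_sampler.__call__
    train_sampler = get_sampler()(dataset, 'default_train')
    # lookup by registered class name, with keyword arguments taken from cfg.args
    cfg = SimpleNamespace(type='DemoLocalSampler',
                          args={'shuffle': True, 'extend': False})
    custom_sampler = get_sampler()(dataset, cfg)
    return train_sampler, custom_sampler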

############################
# random sample with group #
############################

# Deprecated
class GroupSampler(torch.utils.data.Sampler):
    """
    A DistributedSampler that samples indices group by group.
    i.e.
    if group_size=3, num_replicas=2, train mode:
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ==> (group)          [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]
        ==> (distribute)     process0: [3, 4, 5], (leftover [6, 7, 8, 9, 10])
                             process1: [0, 1, 2]
        ==> (group leftover) process0: [3, 4, 5], (leftover [6, 7], [8, 9], 10)
                             process1: [0, 1, 2]
        ==> (distribute)     process0: [3, 4, 5], [6, 7] (remove 10)
                             process1: [0, 1, 2], [8, 9]
    it also avoids batches of size 1:
        0, 1, 2, 3, 4, 5, 6, 7, 8,
        ==> (group)          [0, 1, 2], [3, 4, 5], [6, 7, 8]
        ==> (distribute)     process0: [3, 4, 5], (leftover [6, 7, 8])
                             process1: [0, 1, 2]
        ==> (group leftover) process0: [3, 4, 5], (leftover [6], [7], [8])
                             process1: [0, 1, 2]
        ==> (distribute)     process0: [3, 4, 5], (remove 6, 7, 8, because distributing them would make batch size 1)
                             process1: [0, 1, 2]
    if group_size=3, num_replicas=2, eval mode:
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ==> (extend)         0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10
        ==> (group)          [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 10]
        ==> (distribute)     process0: [0, 1, 2], [6, 7, 8]
                             process1: [3, 4, 5], [9, 10, 10]
    """
    def __init__(self,
                 dataset,
                 group_size,
                 num_replicas=None,
                 rank=None,
                 mode='train',):
        if num_replicas is None:
            if not dist.is_available():
                raise ValueError
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise ValueError
            rank = dist.get_rank()
        self.dataset = dataset
        self.len_dataset = len(dataset)
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.mode = mode

        len_dataset = self.len_dataset
        if (len_dataset % num_replicas != 0) and (mode == 'train'):
            # drop the non-aligned tail
            aligned_indices = np.arange(len_dataset)[:-(len_dataset % num_replicas)]
            aligned_len_dataset = aligned_indices.shape[0]
        elif (len_dataset % num_replicas != 0) and (mode == 'eval'):
            # repeat the last index so the total divides evenly across replicas
            extend = np.array([len_dataset-1 for _ in range(num_replicas - len_dataset % num_replicas)])
            aligned_indices = np.concatenate([range(len_dataset), extend])
            aligned_len_dataset = aligned_indices.shape[0]
        else:
            aligned_indices = np.arange(len_dataset)
            aligned_len_dataset = len_dataset

        num_even_distributed_groups = aligned_len_dataset // (group_size * num_replicas)
        num_even = num_even_distributed_groups * group_size * num_replicas
        self.regular_groups = aligned_indices[0:num_even].reshape(-1, group_size)
        self.leftover_groups = aligned_indices[num_even:].reshape(num_replicas, -1)
        if self.leftover_groups.size == 0:
            self.leftover_groups = None
        elif (self.leftover_groups.shape[-1] == 1) and (mode == 'train'):
            # avoid batch size 1
            self.leftover_groups = None

        # per-rank sample count reported by __len__
        num_samples = self.regular_groups.size // num_replicas
        if self.leftover_groups is not None:
            num_samples += self.leftover_groups.shape[-1]
        self.num_samples = num_samples

        # an ugly way to modify dataset.load_info according to the grouping
        for groupi in self.regular_groups:
            for idx in groupi:
                idx_lowerbd = groupi[0]
                idx_upperbd = groupi[-1]
                idx_reference = (idx_lowerbd+idx_upperbd)//2
                dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
        if self.leftover_groups is not None:
            for groupi in self.leftover_groups:
                for idx in groupi:
                    idx_lowerbd = groupi[0]
                    idx_upperbd = groupi[-1]
                    idx_reference = (idx_lowerbd+idx_upperbd)//2
                    dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']

    def concat(self, nparrays, axis=0):
        # a helper for safe concatenation that skips empty arrays
        nparrays = [i for i in nparrays if i.size > 0]
        return np.concatenate(nparrays, axis=axis)

    def __iter__(self):
        indices = self.get_sync_order()
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def get_sync_order(self):
        # g = torch.Generator()
        # g.manual_seed(self.epoch)
        mode = self.mode
        rank = self.rank
        num_replicas = self.num_replicas
        group_size = self.group_size
        num_groups = len(self.regular_groups)
        if mode == 'train':
            g_indices = torch.randperm(num_groups).to(rank)
            dist.broadcast(g_indices, src=0)
            g_indices = g_indices.to('cpu').tolist()
            num_groups_per_rank = num_groups // num_replicas
            groups = self.regular_groups[g_indices][num_groups_per_rank*rank : num_groups_per_rank*(rank+1)]
            indices = groups.flatten()
            if self.leftover_groups is not None:
                leftg_indices = torch.randperm(len(self.leftover_groups)).to(rank)
                dist.broadcast(leftg_indices, src=0)
                leftg_indices = leftg_indices.to('cpu').tolist()
                last = self.leftover_groups[leftg_indices][rank]
                indices = np.concatenate([indices, last], axis=0)
        elif mode == 'eval':
            groups = self.regular_groups.reshape(-1, num_replicas, group_size)[:, rank, :]
            indices = groups.flatten()
            if self.leftover_groups is not None:
                last = self.leftover_groups[rank]
                indices = np.concatenate([indices, last], axis=0)
        else:
            raise ValueError
        print_log('Sampler RANK {} : {}'.format(rank, str(indices[0:group_size+1])))
        return indices
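
# The function below is an illustrative sketch, not part of the original module:
# it reproduces the regular/leftover grouping that GroupSampler.__init__ builds
# for the train-mode example in the docstring (11 samples, group_size=3,
# num_replicas=2), without the shuffling, broadcasting, or dataset.load_info
# bookkeeping. The helper name and default arguments are hypothetical.
def _demo_group_split(len_dataset=11, group_size=3, num_replicas=2):
    aligned = np.arange(len_dataset)
    if len_dataset % num_replicas != 0:
        # train mode drops the tail so indices split evenly across replicas
        aligned = aligned[:-(len_dataset % num_replicas)]
    num_even = (aligned.shape[0] // (group_size * num_replicas)) * group_size * num_replicas
    regular_groups = aligned[:num_even].reshape(-1, group_size)
    leftover_groups = aligned[num_even:].reshape(num_replicas, -1)
    # with the defaults: regular_groups -> [[0, 1, 2], [3, 4, 5]],
    #                    leftover_groups -> [[6, 7], [8, 9]]
    return regular_groups, leftover_groups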