import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler


# stolen from https://github.com/facebookresearch/vissl/blob/94def58538d3c7037f5e093196494331eea1a2a2/vissl/data/data_helper.py#L93
class StatefulDistributedSampler(DistributedSampler):
    """
    A DistributedSampler with more fine-grained state that uses both the
    epoch and the training iteration: the epoch seeds the shuffle, and the
    iteration marks where sampling resumes. PyTorch's DistributedSampler
    only uses the epoch for shuffling and always starts sampling from the
    beginning. When training on very large data, we train for only one
    epoch, and when we resume training we want the data sampler to resume
    from the saved training iteration. See the usage sketch after the
    class definition.
    """

    def __init__(self, dataset, batch_size=None, seed: int = 0):
        """
        Initializes the StatefulDistributedSampler. The random seed is set
        from the epoch, and the data is shuffled. To start the sampling, use
        start_iter (0 by default, or set when resuming from a checkpoint) to
        sample from the remaining images.

        Args:
            dataset (Dataset): PyTorch dataset that the sampler will shuffle
            batch_size (int): batch size we want the sampler to sample
            seed (int): seed for the torch generator
        """
        super().__init__(dataset, shuffle=False, seed=seed)

        self.start_iter = 0
        self.batch_size = batch_size
        # drop the tail of the dataset so it divides evenly across replicas
        self.total_size = len(dataset) - (len(dataset) % self.num_replicas)
        self.num_samples = self.total_size // self.num_replicas
        print(f"rank: {self.rank}: Sampler created...")

    def __iter__(self):
        # partition the data into num_replicas contiguous blocks and shuffle
        # within this rank's block; all ranks use the same epoch-derived seed,
        # so the permutation is reproducible
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)
        shuffling = torch.randperm(self.num_samples, generator=g).tolist()
        indices = np.array(
            list(
                range(
                    (self.rank * self.num_samples), (self.rank + 1) * self.num_samples
                )
            )
        )[shuffling].tolist()

        # make sure we have the correct number of samples per replica
        assert len(indices) == self.num_samples
        assert self.batch_size is not None and self.batch_size > 0, (
            "batch_size not set for the sampler"
        )

        # resume the sampler: skip the batches consumed before the checkpoint
        start_index = self.start_iter * self.batch_size
        indices = indices[start_index:]
        return iter(indices)

    def set_start_iter(self, start_iter: int):
        """
        Set the iteration number from which the sampling should start. This is
        used to find the marker in the data permutation order from where the
        sampler should start sampling.
        """
        self.start_iter = start_iter
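

# A minimal usage sketch (an illustrative addition, not part of the VISSL
# source): it shows how the sampler might be driven when resuming mid-epoch
# from a checkpoint. It assumes a single-process "gloo" group and a toy
# TensorDataset; the epoch/iteration values below stand in for fields you
# would restore from a real checkpoint.
if __name__ == "__main__":
    import os

    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset

    # DistributedSampler needs an initialized process group, even for a
    # single-process demo
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    dataset = TensorDataset(torch.arange(16).float())
    sampler = StatefulDistributedSampler(dataset, batch_size=4)

    # pretend the run was checkpointed after 2 iterations of epoch 0
    sampler.set_epoch(0)       # reproduce the interrupted epoch's shuffle
    sampler.set_start_iter(2)  # skip the 2 * batch_size samples already seen

    loader = DataLoader(dataset, batch_size=4, sampler=sampler)
    for (batch,) in loader:
        print(batch)  # only the remaining 8 of 16 samples arrive, in 2 batches

    dist.destroy_process_group()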