import logging
import random

import torch
import torch.distributed as dist

from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process

logger = logging.getLogger(__name__)
class MetaLoader(object):
    """Wraps multiple data loaders."""

    def __init__(self, name2loader):
        """Iterates over multiple dataloaders. The iteration order is
        broadcast from rank 0, so all processes work on data from the same
        dataloader at each step; one full pass yields every batch of every
        dataloader exactly once, in a shuffled order.

        Args:
            name2loader (dict): {name: dataloader}
        """
        self.name2loader = name2loader
        self.name2iter = {name: iter(l) for name, l in name2loader.items()}
        name2index = {name: idx for idx, name in enumerate(name2loader)}
        index2name = {v: k for k, v in name2index.items()}

        # Build the global iteration order: each dataloader's index appears
        # once per batch it holds, then the order is shuffled.
        iter_order = []
        for n, l in name2loader.items():
            iter_order.extend([name2index[n]] * len(l))
        random.shuffle(iter_order)
        iter_order = torch.tensor(iter_order, dtype=torch.uint8, device="cuda")

        # Sync: make sure all processes have the same order, so that at each
        # step they consume data from the same dataloader.
        if is_dist_avail_and_initialized():
            dist.broadcast(iter_order, src=0)
        self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]

        logger.info(str(self))
    def __str__(self):
        output = [
            f"MetaLoader has {len(self.name2loader)} dataloaders, "
            f"{len(self)} batches in total"
        ]
        for idx, (name, loader) in enumerate(self.name2loader.items()):
            output.append(
                f"dataloader index={idx} name={name}, "
                f"batch-size={loader.batch_size} length(#batches)={len(loader)}"
            )
        return "\n".join(output)

    def __len__(self):
        return len(self.iter_order)
    def __iter__(self):
        """Yields (name, batch) pairs, one per entry in the shared order."""
        for name in self.iter_order:
            _iter = self.name2iter[name]
            batch = next(_iter)
            yield name, batch
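
# A minimal usage sketch (illustration only, not part of the original file):
# wrapping two tiny in-memory DataLoaders. `image_loader` / `video_loader`
# are made-up names, and the demo assumes a CUDA device is available because
# MetaLoader.__init__ moves the iteration order to the GPU.
def _metaloader_demo():
    from torch.utils.data import DataLoader, TensorDataset

    image_loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)  # 4 batches
    video_loader = DataLoader(TensorDataset(torch.randn(4, 4)), batch_size=2)  # 2 batches
    meta_loader = MetaLoader({"image": image_loader, "video": video_loader})
    assert len(meta_loader) == 6  # 4 + 2 batches in total
    for media_type, (batch,) in meta_loader:
        print(media_type, batch.shape)
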
class MetaLoader_rs(object):
    """Wraps multiple data loaders, with support for resuming mid-epoch."""

    def __init__(self, name2loader, skip_num=0):
        """Same as MetaLoader, except that the first `skip_num` steps of the
        shared iteration order are dropped, and each dataloader's sampler is
        told (via `set_start_iter`) how many of its batches were already
        consumed before the resume.

        Args:
            name2loader (dict): {name: dataloader}
            skip_num (int): number of global steps already consumed
        """
        self.name2loader = name2loader
        name2index = {name: idx for idx, name in enumerate(name2loader)}
        index2name = {v: k for k, v in name2index.items()}

        iter_order = []
        for n, l in name2loader.items():
            iter_order.extend([name2index[n]] * len(l))
        random.shuffle(iter_order)
        iter_order = torch.tensor(iter_order, dtype=torch.uint8, device="cuda")

        # Sync: make sure all processes have the same order, so that at each
        # step they consume data from the same dataloader.
        if is_dist_avail_and_initialized():
            dist.broadcast(iter_order, src=0)

        if skip_num > 0:
            # Count how many of the skipped steps belong to each dataloader
            # and let its sampler resume from that batch.
            iter_order_skip = iter_order[:skip_num]
            for k, v in index2name.items():
                media_step = (iter_order_skip == k).sum().item()
                name2loader[v].sampler.set_start_iter(media_step)
                logger.info(f"{v} dataloader skipped steps: {media_step}")
            iter_order = iter_order[skip_num:]
            self.name2loader = name2loader
        else:
            logger.info("Do not skip steps for any dataloader!")
            for k, v in index2name.items():
                name2loader[v].sampler.set_start_iter(0)

        self.name2iter = {name: iter(l) for name, l in name2loader.items()}
        self.iter_idx = iter_order
        self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]

        logger.info(str(self))
    def __str__(self):
        output = [
            f"MetaLoader has {len(self.name2loader)} dataloaders, "
            f"{len(self)} batches in total"
        ]
        for idx, (name, loader) in enumerate(self.name2loader.items()):
            # Remaining (post-skip) batches for this dataloader.
            length = (self.iter_idx == idx).sum().item()
            output.append(
                f"dataloader index={idx} name={name}, "
                f"batch-size={loader.batch_size} length(#batches)={length}"
            )
        return "\n".join(output)

    def __len__(self):
        return len(self.iter_order)
    def __iter__(self):
        """Yields (name, batch) pairs, one per entry in the post-skip order."""
        for name in self.iter_order:
            _iter = self.name2iter[name]
            batch = next(_iter)
            yield name, batch
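
# A minimal resume sketch (illustration only). MetaLoader_rs calls
# `loader.sampler.set_start_iter(...)`, a hook that standard torch samplers
# do not provide; the toy sampler below merely mimics that interface, and
# the real project presumably ships its own stateful sampler. As above, a
# CUDA device is assumed.
class _ToyStatefulSampler(torch.utils.data.Sampler):
    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size
        self.start_iter = 0

    def set_start_iter(self, start_iter):
        # MetaLoader_rs passes the number of batches already consumed from
        # this loader; skip the corresponding leading samples.
        self.start_iter = start_iter

    def __iter__(self):
        return iter(range(self.start_iter * self.batch_size, len(self.data_source)))

    def __len__(self):
        return len(self.data_source) - self.start_iter * self.batch_size


def _metaloader_rs_demo():
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(8, 4))
    loader = DataLoader(ds, batch_size=2, sampler=_ToyStatefulSampler(ds, batch_size=2))
    # Resume as if 2 of the 4 global steps were already consumed.
    meta_loader = MetaLoader_rs({"toy": loader}, skip_num=2)
    assert len(meta_loader) == 2  # 4 original batches - 2 skipped
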