import os
import os.path as osp
import numpy as np
import numpy.random as npr
import torch
import torch.distributed as dist
import torchvision
import copy
import itertools
from ... import sync
from ...cfg_holder import cfg_unique_holder as cfguh
from ...log_service import print_log
from multiprocessing import shared_memory
import pickle
import hashlib
import random

class ds_base(torch.utils.data.Dataset):
    def __init__(self,
                 cfg,
                 loader=None,
                 estimator=None,
                 transforms=None,
                 formatter=None):
        self.cfg = cfg
        self.load_info = None
        self.init_load_info()
        self.loader = loader
        self.transforms = transforms
        self.formatter = formatter
        if self.load_info is not None:
            load_info_order_by = getattr(self.cfg, 'load_info_order_by', 'default')
            if load_info_order_by == 'default':
                self.load_info = sorted(self.load_info, key=lambda x: x['unique_id'])
            else:
                try:
                    load_info_order_by, reverse = load_info_order_by.split('|')
                    reverse = reverse == 'reverse'
                except ValueError:
                    # no '|reverse' suffix: sort ascending by the given field
                    reverse = False
                self.load_info = sorted(
                    self.load_info, key=lambda x: x[load_info_order_by], reverse=reverse)

        load_info_add_idx = getattr(self.cfg, 'load_info_add_idx', True)
        if (self.load_info is not None) and load_info_add_idx:
            for idx, info in enumerate(self.load_info):
                info['idx'] = idx
        if estimator is not None:
            self.load_info = estimator(self.load_info)

        self.try_sample = getattr(self.cfg, 'try_sample', None)
        if self.try_sample is not None:
            try:
                start, end = self.try_sample
            except (TypeError, ValueError):
                # try_sample is a single number: keep the first N samples
                start, end = 0, self.try_sample
            self.load_info = self.load_info[start:end]

        self.repeat = getattr(self.cfg, 'repeat', 1)

        pick = getattr(self.cfg, 'pick', None)
        if pick is not None:
            self.load_info = [i for i in self.load_info if i['filename'] in pick]
        #########
        # cache #
        #########

        self.cache_sm = getattr(self.cfg, 'cache_sm', False)
        self.cache_cnt = 0
        if self.cache_sm:
            self.cache_pct = getattr(self.cfg, 'cache_pct', 0)
            cache_unique_id = sync.nodewise_sync().random_sync_id()
            self.cache_unique_id = hashlib.sha256(pickle.dumps(cache_unique_id)).hexdigest()
            self.__cache__(self.cache_pct)

        #######
        # log #
        #######

        if self.load_info is not None:
            console_info = '{}: '.format(self.__class__.__name__)
            console_info += 'total {} unique images, '.format(len(self.load_info))
            console_info += 'total {} unique samples. Cached {}. Repeat {} times.'.format(
                len(self.load_info), self.cache_cnt, self.repeat)
        else:
            console_info = '{}: load_info not ready.'.format(self.__class__.__name__)
        print_log(console_info)
    def init_load_info(self):
        # implemented by subclasses: fill self.load_info with a list of dicts,
        # each carrying at least 'unique_id' and 'filename'
        pass

    def __len__(self):
        return len(self.load_info) * self.repeat
    def __cache__(self, pct):
        if pct == 0:
            self.cache_cnt = 0
            return
        self.cache_cnt = int(len(self.load_info) * pct)
        if not self.cache_sm:
            # in-process cache: replace the info dict with the loaded element
            for i in range(self.cache_cnt):
                self.load_info[i] = self.loader(self.load_info[i])
            return
        # shared-memory cache: each local rank loads its share of the data and
        # publishes it under a name every rank can reconstruct.
        # self.local_rank / self.local_world_size are expected to be provided
        # elsewhere (e.g. by the sync helpers or a subclass).
        for i in range(self.cache_cnt):
            shm_name = str(self.load_info[i]['unique_id']) + '_' + self.cache_unique_id
            if i % self.local_world_size == self.local_rank:
                data = pickle.dumps(self.loader(self.load_info[i]))
                datan = len(data)
                # self.print_smname_to_file(shm_name)
                shm = shared_memory.SharedMemory(
                    name=shm_name, create=True, size=datan)
                shm.buf[0:datan] = data[0:datan]
                shm.close()
            self.load_info[i] = shm_name
        dist.barrier()
    def __getitem__(self, idx):
        idx = idx % len(self.load_info)
        element = self.load_info[idx]
        if isinstance(element, str):
            # cached in shared memory: the string is the shared-memory block name
            shm = shared_memory.SharedMemory(name=element)
            element = pickle.loads(shm.buf)
            shm.close()
        else:
            element = copy.deepcopy(element)
        element['load_info_ptr'] = self.load_info
        if idx >= self.cache_cnt:
            element = self.loader(element)
        if self.transforms is not None:
            element = self.transforms(element)
        if self.formatter is not None:
            return self.formatter(element)
        else:
            return element
    def __del__(self):
        # clean up the shared memory blocks created by __cache__
        if self.load_info is None:
            return
        for infoi in self.load_info:
            if isinstance(infoi, str) and (self.local_rank == 0):
                shm = shared_memory.SharedMemory(name=infoi)
                shm.close()
                shm.unlink()
    def print_smname_to_file(self, smname):
        try:
            log_file = cfguh().cfg.train.log_file
        except Exception:
            try:
                log_file = cfguh().cfg.eval.log_file
            except Exception:
                raise ValueError('no log_file found in either train or eval config')
        # a trick that reuses the log_file path for the shared-memory name list
        sm_file = log_file.replace('.log', '.smname')
        with open(sm_file, 'a') as f:
            f.write(smname + '\n')
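
# A minimal sketch of a concrete dataset (the class name, directory layout and
# cfg field below are hypothetical): init_load_info() is the only required
# override and must fill self.load_info with dicts carrying at least
# 'unique_id' and 'filename', the keys the base class uses for sorting and
# picking.
class _example_folder_dataset(ds_base):
    def init_load_info(self):
        root = getattr(self.cfg, 'root_dir', '.')  # hypothetical cfg field
        self.load_info = [
            {'unique_id': i, 'filename': f, 'image_path': osp.join(root, f)}
            for i, f in enumerate(sorted(os.listdir(root)))]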

def singleton(class_):
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance
from .ds_loader import get_loader
from .ds_transform import get_transform
from .ds_estimator import get_estimator
from .ds_formatter import get_formatter

@singleton
class get_dataset(object):
    # singleton so that register() below and every later lookup share one registry
    def __init__(self):
        self.dataset = {}

    def register(self, ds):
        self.dataset[ds.__name__] = ds

    def __call__(self, cfg):
        if cfg is None:
            return None
        t = cfg.type
        if t is None:
            return None
        elif t in ['laion2b', 'laion2b_dummy',
                   'laion2b_webdataset',
                   'laion2b_webdataset_sdofficial', ]:
            from .. import ds_laion2b
        elif t in ['coyo', 'coyo_dummy',
                   'coyo_webdataset', ]:
            from .. import ds_coyo_webdataset
        elif t in ['laionart', 'laionart_dummy',
                   'laionart_webdataset', ]:
            from .. import ds_laionart
        elif t in ['celeba']:
            from .. import ds_celeba
        elif t in ['div2k']:
            from .. import ds_div2k
        elif t in ['pafc']:
            from .. import ds_pafc
        elif t in ['coco_caption']:
            from .. import ds_coco
        else:
            raise ValueError('unknown dataset type: {}'.format(t))
        loader = get_loader()(cfg.get('loader', None))
        transform = get_transform()(cfg.get('transform', None))
        estimator = get_estimator()(cfg.get('estimator', None))
        formatter = get_formatter()(cfg.get('formatter', None))
        return self.dataset[t](
            cfg, loader, estimator,
            transform, formatter)

def register():
    def wrapper(class_):
        get_dataset().register(class_)
        return class_
    return wrapper
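
# Hedged usage sketch (illustrative names; nothing below is exercised at import
# time beyond adding one entry to the registry). A dataset declares itself with
# @register(), and get_dataset() -- the same singleton registry everywhere --
# builds it from a config object exposing .type and .get(), as assumed by
# get_dataset.__call__ above. A genuinely new type would also need an entry in
# the if/elif import table there.
@register()
class _example_registered_dataset(ds_base):
    def init_load_info(self):
        self.load_info = [{'unique_id': 0, 'filename': 'dummy.png'}]

def _example_build_dataset(cfg):
    return get_dataset()(cfg)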

# some other helpers

class collate(object):
    """
    Modified from torch.utils.data._utils.collate.
    It handles lists differently from the default: list fields are collated by
    simple concatenation instead of being stacked.
    """
    def __init__(self):
        self.default_collate = \
            torch.utils.data._utils.collate.default_collate

    def __call__(self, batch):
        """
        Args:
            batch: [data, data] -or- [(data1, data2, ...), (data1, data2, ...)]
        This function is not applied recursively; it only handles the top level.
        """
        elem = batch[0]
        if not isinstance(elem, (tuple, list)):
            return self.default_collate(batch)
        rv = []
        # transpose the batch so each field is collated across samples
        for i in zip(*batch):
            if isinstance(i[0], list):
                if len(i[0]) != 1:
                    raise ValueError
                try:
                    i = [[self.default_collate(ii).squeeze(0)] for ii in i]
                except Exception:
                    pass
                rvi = list(itertools.chain.from_iterable(i))
                rv.append(rvi)  # list concat
            else:
                rv.append(self.default_collate(i))
        return rv
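
# Hedged usage sketch: collate() is meant to be handed to a torch DataLoader as
# collate_fn so that list-valued fields are concatenated instead of stacked.
# The batch size and shuffle flag below are placeholders.
def _example_dataloader(dataset, batch_size=4):
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate())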