Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| from copy import deepcopy | |
| from functools import partial | |
| from glob import glob | |
| from hashlib import sha1 | |
| from typing import Callable, Iterable, Optional, Tuple | |
| import cv2 | |
| import numpy as np | |
| from glog import logger | |
| from joblib import Parallel, cpu_count, delayed | |
| from skimage.io import imread | |
| from torch.utils.data import Dataset | |
| from tqdm import tqdm | |
| import aug | |
def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
    """Keep the samples whose hash bucket falls inside a fractional window.

    Every sample is assigned a bucket in ``[0, n_buckets)`` by
    ``split_into_buckets``; a sample survives when its bucket lies in the
    half-open interval ``[bounds[0] * n_buckets, bounds[1] * n_buckets)``.

    :param data: samples to filter; materialized into a list first.
    :param bounds: ``(lower, upper)`` fractions of the bucket range to keep.
    :param hash_fn: callable forwarded to ``split_into_buckets``.
    :param n_buckets: total number of buckets.
    :param salt: forwarded to the hash function; mentioned in the log when set.
    :param verbose: when true, the selected bucket range is logged.
    :return: ``np.array`` built from the surviving samples, in input order.
    """
    samples = list(data)
    bucket_ids = split_into_buckets(samples, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)

    # Turn the fractional bounds into bucket-index thresholds.
    lower_bound, upper_bound = (fraction * n_buckets for fraction in bounds)

    msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
    if salt:
        msg += f'; salt is {salt}'
    if verbose:
        logger.info(msg)

    kept = []
    for bucket, sample in zip(bucket_ids, samples):
        if lower_bound <= bucket < upper_bound:
            kept.append(sample)
    return np.array(kept)
def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
    """Return a SHA-1 hex digest derived from the base names of a path pair.

    Only the file names (not their directories) take part in the hash, so the
    same pair of names hashes identically wherever the files live.

    :param x: pair of file paths.
    :param salt: string appended after an underscore before hashing.
    :return: hexadecimal SHA-1 digest of ``'<name_a><name_b>_<salt>'``.
    """
    first, second = x
    joined = os.path.basename(first) + os.path.basename(second)
    digest = sha1(f'{joined}_{salt}'.encode())
    return digest.hexdigest()
def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
    """Assign each item a bucket index derived from its hexadecimal hash.

    :param data: items to distribute.
    :param n_buckets: number of buckets; indices are taken modulo this value.
    :param hash_fn: callable ``hash_fn(item, salt=...)`` returning a hex string.
    :param salt: passed through to ``hash_fn`` as the ``salt`` keyword.
    :return: ``np.array`` of bucket indices, one per item, in input order.
    """
    salted_hash = partial(hash_fn, salt=salt)
    bucket_ids = []
    for item in data:
        # The hash is a base-16 string; fold it into the bucket range.
        bucket_ids.append(int(salted_hash(item), 16) % n_buckets)
    return np.array(bucket_ids)
def _read_img(x: str):
    """Read the image at path ``x``, preferring OpenCV and falling back to scikit-image.

    ``cv2.imread`` returns ``None`` instead of raising when it cannot decode a
    file; in that case a warning is logged and ``skimage.io.imread`` is tried.

    OpenCV's default read yields a 3-channel BGR array, whereas scikit-image
    yields RGB (or RGBA). To keep every image in the same channel layout no
    matter which reader succeeded, a colour result from the fallback is
    converted to BGR and any alpha channel is dropped.

    :param x: path of the image file.
    :return: the decoded image array.
    """
    img = cv2.imread(x)
    if img is None:
        logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
        img = imread(x)
        if img.ndim == 3 and img.shape[2] in (3, 4):
            # Channels 2, 1, 0 -> BGR; a fourth (alpha) channel is discarded.
            # Copy into a contiguous buffer so OpenCV routines accept the array.
            img = np.ascontiguousarray(img[..., 2::-1])
        # NOTE(review): single-channel or multi-frame fallback results are
        # returned unchanged - confirm whether downstream code needs them
        # expanded to three channels as well.
    return img
class PairedDataset(Dataset):
    """Dataset of aligned ``(a, b)`` image pairs.

    Each item is fetched (from disk, or from memory when preloading is on),
    passed jointly through ``transform_fn``, optionally corrupted on the ``a``
    side only, normalized, and returned with the last axis moved to the front
    (presumably HWC -> CHW; confirm against ``normalize_fn``'s output).
    """

    def __init__(self,
                 files_a: Tuple[str, ...],
                 files_b: Tuple[str, ...],
                 transform_fn: Callable,
                 normalize_fn: Callable,
                 corrupt_fn: Optional[Callable] = None,
                 preload: bool = True,
                 preload_size: Optional[int] = 0,
                 verbose=True):
        """Store the file lists and callables, optionally preloading all images.

        :param files_a: paths of the ``a`` images.
        :param files_b: paths of the ``b`` images; must have the same length.
        :param transform_fn: called as ``transform_fn(a, b)``, returns a pair.
        :param normalize_fn: called as ``normalize_fn(a, b)``, returns a pair.
        :param corrupt_fn: optional callable applied to ``a`` after the transform.
        :param preload: when truthy, every image is read into memory up front.
        :param preload_size: if truthy, preloaded images are rescaled so that
            their smaller side is at least this many pixels.
        :param verbose: enables the preloading progress bar.
        """
        assert len(files_a) == len(files_b)

        self.preload = preload
        self.data_a = files_a
        self.data_b = files_b
        self.verbose = verbose
        self.corrupt_fn = corrupt_fn
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn
        logger.info(f'Dataset has been created with {len(self.data_a)} samples')

        if preload:
            preload_fn = partial(self._bulk_preload, preload_size=preload_size)
            if files_a == files_b:
                # Identical file lists: load once and share the result.
                self.data_a = self.data_b = preload_fn(self.data_a)
            else:
                self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
            self.preload = True

    def _bulk_preload(self, data: Iterable[str], preload_size: int):
        """Load every path in ``data`` using a thread pool, one job per image."""
        jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
        jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
        return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)

    @staticmethod
    def _preload(x: str, preload_size: int):
        """Read one image and, if ``preload_size`` is truthy, rescale it.

        Declared static because it takes no ``self``: ``_bulk_preload`` calls it
        through the instance, which would otherwise bind the instance to ``x``
        and fail with a ``TypeError``.
        """
        img = _read_img(x)
        if preload_size:
            h, w, *_ = img.shape
            h_scale = preload_size / h
            w_scale = preload_size / w
            # The larger factor brings the smaller side up to preload_size.
            scale = max(h_scale, w_scale)
            img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
            assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
        return img

    def _preprocess(self, img, res):
        """Normalize the pair and move each array's last axis to the front."""
        def transpose(x):
            return np.transpose(x, (2, 0, 1))

        # A tuple (rather than a one-shot map object) can be unpacked or reused.
        return tuple(map(transpose, self.normalize_fn(img, res)))

    def __len__(self):
        """Return the number of pairs."""
        return len(self.data_a)

    def __getitem__(self, idx):
        """Return ``{'a': ..., 'b': ...}`` for the pair at ``idx``."""
        a, b = self.data_a[idx], self.data_b[idx]
        if not self.preload:
            # Without preloading the stored entries are paths, not arrays.
            a, b = map(_read_img, (a, b))
        a, b = self.transform_fn(a, b)
        if self.corrupt_fn is not None:
            a = self.corrupt_fn(a)
        a, b = self._preprocess(a, b)
        return {'a': a, 'b': b}

    @staticmethod
    def from_config(config):
        """Build a ``PairedDataset`` from a configuration mapping.

        Keys read: ``files_a`` and ``files_b`` (glob patterns, expanded
        recursively and sorted), ``size``, ``scope``, ``crop``, ``preload``,
        ``preload_size``, and the optional ``verbose`` (default ``True``) and
        ``bounds`` (default ``(0, 1)``).

        :raises ValueError: if no pair remains after globbing and subsampling.
        """
        config = deepcopy(config)
        files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
        if len(files_a) != len(files_b):
            # zip() below pairs by position and drops the surplus, so differing
            # counts most likely mean mismatched pairs.
            logger.warning(f'files_a matched {len(files_a)} files but files_b matched {len(files_b)}; '
                           f'extra files are ignored and pairs may be misaligned')

        transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
        normalize_fn = aug.get_normalize()

        hash_fn = hash_from_paths
        # TODO: add more hash functions
        verbose = config.get('verbose', True)
        data = subsample(data=zip(files_a, files_b),
                         bounds=config.get('bounds', (0, 1)),
                         hash_fn=hash_fn,
                         verbose=verbose)
        if len(data) == 0:
            # Unzipping an empty selection would fail with an opaque unpacking
            # error; report the actual cause instead.
            raise ValueError(f"No image pairs selected: files_a={config['files_a']!r}, "
                             f"files_b={config['files_b']!r}, bounds={config.get('bounds', (0, 1))!r}")

        files_a, files_b = map(list, zip(*data))

        return PairedDataset(files_a=files_a,
                             files_b=files_b,
                             preload=config['preload'],
                             preload_size=config['preload_size'],
                             normalize_fn=normalize_fn,
                             transform_fn=transform_fn,
                             verbose=verbose)