Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.distributed as dist | |
| import os | |
| import os.path as osp | |
| import numpy as np | |
| import copy | |
| import json | |
| from ..log_service import print_log | |
| def singleton(class_): | |
| instances = {} | |
| def getinstance(*args, **kwargs): | |
| if class_ not in instances: | |
| instances[class_] = class_(*args, **kwargs) | |
| return instances[class_] | |
| return getinstance | |
| class get_evaluator(object): | |
| def __init__(self): | |
| self.evaluator = {} | |
| def register(self, evaf, name): | |
| self.evaluator[name] = evaf | |
| def __call__(self, pipeline_cfg=None): | |
| if pipeline_cfg is None: | |
| from . import eva_null | |
| return self.evaluator['null']() | |
| if not isinstance(pipeline_cfg, list): | |
| t = pipeline_cfg.type | |
| if t == 'miou': | |
| from . import eva_miou | |
| if t == 'psnr': | |
| from . import eva_psnr | |
| if t == 'ssim': | |
| from . import eva_ssim | |
| if t == 'lpips': | |
| from . import eva_lpips | |
| if t == 'fid': | |
| from . import eva_fid | |
| return self.evaluator[t](**pipeline_cfg.args) | |
| evaluator = [] | |
| for ci in pipeline_cfg: | |
| t = ci.type | |
| if t == 'miou': | |
| from . import eva_miou | |
| if t == 'psnr': | |
| from . import eva_psnr | |
| if t == 'ssim': | |
| from . import eva_ssim | |
| if t == 'lpips': | |
| from . import eva_lpips | |
| if t == 'fid': | |
| from . import eva_fid | |
| evaluator.append( | |
| self.evaluator[t](**ci.args)) | |
| if len(evaluator) == 0: | |
| return None | |
| else: | |
| return compose(evaluator) | |
| def register(name): | |
| def wrapper(class_): | |
| get_evaluator().register(class_, name) | |
| return class_ | |
| return wrapper | |
| class base_evaluator(object): | |
| def __init__(self, | |
| **args): | |
| ''' | |
| Args: | |
| sample_n, int, | |
| the total number of sample. used in | |
| distributed sync | |
| ''' | |
| if not dist.is_available(): | |
| raise ValueError | |
| self.world_size = dist.get_world_size() | |
| self.rank = dist.get_rank() | |
| self.sample_n = None | |
| self.final = {} | |
| def sync(self, data): | |
| """ | |
| Args: | |
| data: any, | |
| the data needs to be broadcasted | |
| """ | |
| if data is None: | |
| return None | |
| if isinstance(data, tuple): | |
| data = list(data) | |
| if isinstance(data, list): | |
| data_list = [] | |
| for datai in data: | |
| data_list.append(self.sync(datai)) | |
| data = [[*i] for i in zip(*data_list)] | |
| return data | |
| data = [ | |
| self.sync_(data, ranki) | |
| for ranki in range(self.world_size) | |
| ] | |
| return data | |
| def sync_(self, data, rank): | |
| t = type(data) | |
| is_broadcast = rank == self.rank | |
| if t is np.ndarray: | |
| dtrans = data | |
| dt = data.dtype | |
| if dt in [ | |
| int, | |
| np.bool, | |
| np.uint8, | |
| np.int8, | |
| np.int16, | |
| np.int32, | |
| np.int64,]: | |
| dtt = torch.int64 | |
| elif dt in [ | |
| float, | |
| np.float16, | |
| np.float32, | |
| np.float64,]: | |
| dtt = torch.float64 | |
| elif t is str: | |
| dtrans = np.array( | |
| [ord(c) for c in data], | |
| dtype = np.int64 | |
| ) | |
| dt = np.int64 | |
| dtt = torch.int64 | |
| else: | |
| raise ValueError | |
| if is_broadcast: | |
| n = len(dtrans.shape) | |
| n = torch.tensor(n).long() | |
| n = n.to(self.rank) | |
| dist.broadcast(n, src=rank) | |
| n = list(dtrans.shape) | |
| n = torch.tensor(n).long() | |
| n = n.to(self.rank) | |
| dist.broadcast(n, src=rank) | |
| n = torch.tensor(dtrans, dtype=dtt) | |
| n = n.to(self.rank) | |
| dist.broadcast(n, src=rank) | |
| return data | |
| n = torch.tensor(0).long() | |
| n = n.to(self.rank) | |
| dist.broadcast(n, src=rank) | |
| n = n.item() | |
| n = torch.zeros(n).long() | |
| n = n.to(self.rank) | |
| dist.broadcast(n, src=rank) | |
| n = list(n.to('cpu').numpy()) | |
| n = torch.zeros(n, dtype=dtt) | |
| n = n.to(self.rank) | |
| dist.broadcast(n, src=rank) | |
| n = n.to('cpu').numpy().astype(dt) | |
| if t is np.ndarray: | |
| return n | |
| elif t is str: | |
| n = ''.join([chr(c) for c in n]) | |
| return n | |
| def zipzap_arrange(self, data): | |
| ''' | |
| Order the data so it range like this: | |
| input [[0, 2, 4, 6], [1, 3, 5, 7]] -> output [0, 1, 2, 3, 4, 5, ...] | |
| ''' | |
| if isinstance(data[0], list): | |
| data_new = [] | |
| maxlen = max([len(i) for i in data]) | |
| totlen = sum([len(i) for i in data]) | |
| cnt = 0 | |
| for idx in range(maxlen): | |
| for datai in data: | |
| data_new += [datai[idx]] | |
| cnt += 1 | |
| if cnt >= totlen: | |
| break | |
| return data_new | |
| elif isinstance(data[0], np.ndarray): | |
| maxlen = max([i.shape[0] for i in data]) | |
| totlen = sum([i.shape[0] for i in data]) | |
| datai_shape = data[0].shape[1:] | |
| data = [ | |
| np.concatenate(datai, np.zeros(maxlen-datai.shape[0], *datai_shape), axis=0) | |
| if datai.shape[0] < maxlen else datai | |
| for datai in data | |
| ] # even the array | |
| data = np.stack(data, axis=1).reshape(-1, *datai_shape) | |
| data = data[:totlen] | |
| return data | |
| else: | |
| raise NotImplementedError | |
| def add_batch(self, **args): | |
| raise NotImplementedError | |
| def set_sample_n(self, sample_n): | |
| self.sample_n = sample_n | |
| def compute(self): | |
| raise NotImplementedError | |
| # Function needed in training to judge which | |
| # evaluated number is better | |
| def isbetter(self, old, new): | |
| return new>old | |
| def one_line_summary(self): | |
| print_log('Evaluator display') | |
| def save(self, path): | |
| if not osp.exists(path): | |
| os.makedirs(path) | |
| ofile = osp.join(path, 'result.json') | |
| with open(ofile, 'w') as f: | |
| json.dump(self.final, f, indent=4) | |
| def clear_data(self): | |
| raise NotImplementedError | |
| class compose(object): | |
| def __init__(self, pipeline): | |
| self.pipeline = pipeline | |
| self.sample_n = None | |
| self.final = {} | |
| def add_batch(self, *args, **kwargs): | |
| for pi in self.pipeline: | |
| pi.add_batch(*args, **kwargs) | |
| def set_sample_n(self, sample_n): | |
| self.sample_n = sample_n | |
| for pi in self.pipeline: | |
| pi.set_sample_n(sample_n) | |
| def compute(self): | |
| rv = {} | |
| for pi in self.pipeline: | |
| rv[pi.symbol] = pi.compute() | |
| self.final[pi.symbol] = pi.final | |
| return rv | |
| def isbetter(self, old, new): | |
| check = 0 | |
| for pi in self.pipeline: | |
| if pi.isbetter(old, new): | |
| check+=1 | |
| if check/len(self.pipeline)>0.5: | |
| return True | |
| else: | |
| return False | |
| def one_line_summary(self): | |
| for pi in self.pipeline: | |
| pi.one_line_summary() | |
| def save(self, path): | |
| if not osp.exists(path): | |
| os.makedirs(path) | |
| ofile = osp.join(path, 'result.json') | |
| with open(ofile, 'w') as f: | |
| json.dump(self.final, f, indent=4) | |
| def clear_data(self): | |
| for pi in self.pipeline: | |
| pi.clear_data() | |