Spaces:
Runtime error
Runtime error
| import datetime | |
| import logging | |
| import time | |
class MessageLogger():
    """Message logger for printing.

    Formats one line of training progress (epoch, iter, learning rates,
    timing/ETA and any remaining scalar values), sends it to the logger
    returned by ``get_root_logger()`` and optionally mirrors the scalar
    values to a tensorboard logger.

    Args:
        opt (dict): Config. The following top-level keys are read:
            name (str): Exp name.
            print_freq: Logger interval (stored as ``self.interval``).
            max_iters (int): Total iters, used for the ETA estimate.
            use_tb_logger (bool): Use tensorboard logger.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """

    def __init__(self, opt, start_iter=1, tb_logger=None):
        self.exp_name = opt['name']
        self.interval = opt['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['max_iters']
        self.use_tb_logger = opt['use_tb_logger']
        self.tb_logger = tb_logger
        self.start_time = time.time()
        self.logger = get_root_logger()

    def __call__(self, log_vars):
        """Format logging message.

        Args:
            log_vars (dict): It contains the following keys:
                epoch (int): Epoch number.
                iter (int): Current iter.
                lrs (list): List for learning rates.
                time (float): Iter time. Optional; when present,
                    ``data_time`` must be present as well.
                data_time (float): Data time for each iter.
                Every other entry is treated as a scalar (e.g. a loss)
                that is appended to the message and, if enabled, written
                to the tensorboard logger.

        Raises:
            KeyError: If ``epoch``, ``iter`` or ``lrs`` is missing, or if
                ``time`` is given without ``data_time``.
        """
        # Work on a shallow copy so the caller's dict is left untouched.
        log_vars = dict(log_vars)

        # epoch, iter, learning rates
        epoch = log_vars.pop('epoch')
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        lr_str = ''.join(f'{v:.3e},' for v in lrs)
        parts = [f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
                 f'iter:{current_iter:8,d}, lr:({lr_str})] ']

        # time and estimated time
        if 'time' in log_vars:
            iter_time = log_vars.pop('time')
            data_time = log_vars.pop('data_time')
            total_time = time.time() - self.start_time
            # At least one elapsed iter, so the average never divides by 0.
            elapsed_iters = max(current_iter - self.start_iter + 1, 1)
            time_sec_avg = total_time / elapsed_iters
            # Clamp at zero: a negative timedelta prints as "-1 day, ...".
            remaining_iters = max(self.max_iters - current_iter - 1, 0)
            eta_sec = time_sec_avg * remaining_iters
            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
            parts.append(f'[eta: {eta_str}, ')
            parts.append(
                f'time: {iter_time:.3f}, data_time: {data_time:.3f}] ')

        # Tensorboard output needs an actual writer and is skipped for
        # experiments whose name contains 'debug'.
        write_tb = (self.use_tb_logger and self.tb_logger is not None
                    and 'debug' not in self.exp_name)

        # other items, especially losses
        for k, v in log_vars.items():
            parts.append(f'{k}: {v:.4e} ')
            if write_tb:
                self.tb_logger.add_scalar(k, v, current_iter)

        self.logger.info(''.join(parts))
def init_tb_logger(log_dir):
    """Create a tensorboard writer for ``log_dir``.

    Args:
        log_dir (str): Directory passed to ``SummaryWriter`` as ``log_dir``.

    Returns:
        SummaryWriter: The newly created writer.
    """
    # Imported inside the function, so importing this module does not
    # itself require torch's tensorboard support.
    from torch.utils.tensorboard import SummaryWriter

    return SummaryWriter(log_dir=log_dir)
def get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added (via ``logging.basicConfig``). If `log_file` is
    specified, a FileHandler will also be added, even when the logger was
    already initialized by an earlier call; a file that is already attached
    to this logger is not attached (or truncated) a second time.

    Args:
        logger_name (str): root logger name. Default: base.
        log_level (int): Level passed to ``logging.basicConfig`` and set on
            the FileHandler. Default: logging.INFO.
        log_file (str | None): The log filename. If specified, a FileHandler
            (opened in 'w' mode) will be added to the logger.

    Returns:
        logging.Logger: The logger named ``logger_name``.
    """
    import os  # local import: only needed to normalise ``log_file``

    logger = logging.getLogger(logger_name)
    # NOTE(review): the default ``asctime`` already ends in ",mmm", so the
    # extra ``%(msecs)03d`` prints milliseconds twice. Left unchanged to keep
    # the emitted log format stable.
    format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s'

    # Console output is only configured once. hasHandlers() also consults
    # ancestor loggers, so it must not gate the file handler below: the root
    # logger has a handler as soon as basicConfig has run once.
    if not logger.hasHandlers():
        logging.basicConfig(format=format_str, level=log_level)

    if log_file is not None:
        target = os.path.abspath(os.fspath(log_file))
        already_attached = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == target
            for handler in logger.handlers)
        if not already_attached:
            file_handler = logging.FileHandler(log_file, 'w')
            file_handler.setFormatter(logging.Formatter(format_str))
            file_handler.setLevel(log_level)
            logger.addHandler(file_handler)
    return logger