Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
| import inspect | |
| import logging | |
| import os | |
| import os.path as osp | |
| import sys | |
| import warnings | |
| from getpass import getuser | |
| from logging import Logger, LogRecord, handlers | |
| from socket import gethostname | |
| from typing import Dict, Optional, Union | |
| from termcolor import colored | |
| from mmengine.utils import ManagerMixin | |
| from mmengine.utils.manager import _accquire_lock, _release_lock | |
class FilterDuplicateWarning(logging.Filter):
    """Filter that lets each distinct warning message through only once.

    Only records whose level is exactly ``logging.WARNING`` are
    de-duplicated; records of every other level always pass. Duplicates are
    detected from ``record.msg`` (the unformatted message object), so two
    warnings sharing a message template but differing in ``args`` are
    treated as the same message.

    Args:
        name (str): name of the filter. Defaults to 'mmengine'.
    """

    def __init__(self, name: str = 'mmengine'):
        super().__init__(name)
        # Keys of the warning messages that have already been let through.
        self.seen: set = set()

    def filter(self, record: LogRecord) -> bool:
        """Decide whether ``record`` should be emitted.

        Args:
            record (LogRecord): The log record.

        Returns:
            bool: ``True`` if the record is not a warning or is the first
            warning carrying this message, ``False`` for a repeated warning.
        """
        if record.levelno != logging.WARNING:
            return True

        key = record.msg
        try:
            hash(key)
        except TypeError:
            # ``msg`` may be any object (e.g. a list or dict). Unhashable
            # messages cannot be stored in a set, so de-duplicate them by
            # their string form instead of raising from inside the filter.
            key = str(key)

        if key in self.seen:
            return False
        self.seen.add(key)
        return True
class MMFormatter(logging.Formatter):
    """Colorful format for MMLogger.

    Each of the ``ERROR``, ``WARNING``, ``INFO`` and ``DEBUG`` levels has
    its own format string. If the log level is error, the formatter
    additionally outputs the location of the code (path, function and line
    number). Records of any other level (e.g. ``CRITICAL`` or a custom
    level) are rendered with a plain ``%(levelname)s`` format.

    Args:
        color (bool): Whether to use colorful format. filehandler is not
            allowed to use color format, otherwise it will be garbled.
        blink (bool): Whether to blink the ``INFO`` and ``DEBUG`` logging
            level.
        **kwargs: Keyword arguments passed to
            :meth:`logging.Formatter.__init__`.
    """

    _color_mapping: dict = dict(
        ERROR='red', WARNING='yellow', INFO='white', DEBUG='green')

    # Maps a log level to the name of the attribute holding its format.
    # The attribute is looked up at format time so that reassigning e.g.
    # ``formatter.err_format`` after construction still takes effect.
    _format_attrs: dict = {
        logging.ERROR: 'err_format',
        logging.WARNING: 'warn_format',
        logging.INFO: 'info_format',
        logging.DEBUG: 'debug_format',
    }

    def __init__(self, color: bool = True, blink: bool = False, **kwargs):
        super().__init__(**kwargs)
        assert not (not color and blink), (
            'blink should only be available when color is True')
        # Get prefix format according to color. ERROR and WARNING prefixes
        # always request blinking; INFO and DEBUG follow ``blink``.
        error_prefix = self._get_prefix('ERROR', color, blink=True)
        warn_prefix = self._get_prefix('WARNING', color, blink=True)
        info_prefix = self._get_prefix('INFO', color, blink)
        debug_prefix = self._get_prefix('DEBUG', color, blink)

        # Config output format.
        self.err_format = (f'%(asctime)s - %(name)s - {error_prefix} - '
                           '%(pathname)s - %(funcName)s - %(lineno)d - '
                           '%(message)s')
        self.warn_format = (f'%(asctime)s - %(name)s - {warn_prefix} - %('
                            'message)s')
        self.info_format = (f'%(asctime)s - %(name)s - {info_prefix} - %('
                            'message)s')
        self.debug_format = (f'%(asctime)s - %(name)s - {debug_prefix} - %('
                             'message)s')
        # Fallback for levels without a dedicated format. Without it such a
        # record would silently reuse the format (and level label) of
        # whichever record was formatted before it.
        self.default_format = ('%(asctime)s - %(name)s - %(levelname)s - '
                               '%(message)s')

    def _get_prefix(self, level: str, color: bool, blink=False) -> str:
        """Get the prefix of the target log level.

        Args:
            level (str): log level. Must be a key of ``_color_mapping``
                when ``color`` is True.
            color (bool): Whether to get colorful prefix.
            blink (bool): Whether the prefix will blink.

        Returns:
            str: The plain or colorful prefix.
        """
        if not color:
            return level
        attrs = ['underline']
        if blink:
            attrs.append('blink')
        return colored(level, self._color_mapping[level], attrs=attrs)

    def format(self, record: LogRecord) -> str:
        """Override the ``logging.Formatter.format`` method. Output the
        message according to the specified log level.

        Args:
            record (LogRecord): A LogRecord instance represents an event being
                logged.

        Returns:
            str: Formatted result.
        """
        attr = self._format_attrs.get(record.levelno, 'default_format')
        # The base class reads the format string from ``self._style._fmt``,
        # so swap it in before delegating.
        self._style._fmt = getattr(self, attr)
        return super().format(record)
class MMLogger(Logger, ManagerMixin):
    """Formatted logger used to record messages.

    ``MMLogger`` can create formatted logger to log message with different
    log levels and get instance in the same way as ``ManagerMixin``.
    ``MMLogger`` has the following features:

    - Distributed log storage, ``MMLogger`` can choose whether to save log of
      different ranks according to `log_file`.
    - Message with different log levels will have different colors and format
      when displayed on terminal.

    Note:
        - The `name` of logger and the ``instance_name`` of ``MMLogger`` could
          be different. We can only get ``MMLogger`` instance by
          ``MMLogger.get_instance`` but not ``logging.getLogger``. This feature
          ensures ``MMLogger`` will not be influenced by third-party logging
          config.
        - Different from ``logging.Logger``, ``MMLogger`` will not log warning
          or error message without ``Handler``.

    Examples:
        >>> logger = MMLogger.get_instance(name='MMLogger',
        >>>                                logger_name='Logger')
        >>> # Although logger has name attribute just like `logging.Logger`
        >>> # We cannot get logger instance by `logging.getLogger`.
        >>> assert logger.name == 'Logger'
        >>> assert logger.instance_name == 'MMLogger'
        >>> assert id(logger) != id(logging.getLogger('Logger'))
        >>> # Get logger that do not store logs.
        >>> logger1 = MMLogger.get_instance('logger1')
        >>> # Get logger only save rank0 logs.
        >>> logger2 = MMLogger.get_instance('logger2', log_file='out.log')
        >>> # Get logger only save multiple ranks logs.
        >>> logger3 = MMLogger.get_instance('logger3', log_file='out.log',
        >>>                                 distributed=True)

    Args:
        name (str): Global instance name.
        logger_name (str): ``name`` attribute of ``Logging.Logger`` instance.
            If `logger_name` is not defined, defaults to 'mmengine'.
        log_file (str, optional): The log filename. If specified, a
            ``FileHandler`` will be added to the logger. Defaults to None.
        log_level (str): The log level of the handler. Defaults to
            'INFO'. If log level is 'DEBUG', distributed logs will be saved
            during distributed training.
        file_mode (str): The file mode used to open log file. Defaults to 'w'.
        distributed (bool): Whether to save distributed logs, Defaults to
            false.
        file_handler_cfg (dict, optional): Configuration of file handler.
            Defaults to None. If ``file_handler_cfg`` is not specified,
            ``logging.FileHandler`` will be used by default. If it is
            specified, the ``type`` key should be set. It can be
            ``RotatingFileHandler``, ``TimedRotatingFileHandler``,
            ``WatchedFileHandler`` or other file handlers, and the remaining
            fields will be used to build the handler. The dict passed in is
            not modified.

            Examples:
                >>> file_handler_cfg = dict(
                >>>    type='TimedRotatingFileHandler',
                >>>    when='MIDNIGHT',
                >>>    interval=1,
                >>>    backupCount=365)

            `New in version 0.8.5.`
    """

    def __init__(self,
                 name: str,
                 logger_name='mmengine',
                 log_file: Optional[str] = None,
                 log_level: Union[int, str] = 'INFO',
                 file_mode: str = 'w',
                 distributed=False,
                 file_handler_cfg: Optional[dict] = None):
        Logger.__init__(self, logger_name)
        ManagerMixin.__init__(self, name)
        # Normalize a level name such as 'INFO' to its integer value.
        if isinstance(log_level, str):
            log_level = logging._nameToLevel[log_level]
        # Get rank in DDP mode.
        global_rank = _get_rank()
        device_id = _get_device_id()

        # Config stream_handler. If `rank != 0`. stream_handler can only
        # export ERROR logs.
        stream_handler = logging.StreamHandler(stream=sys.stdout)
        # `StreamHandler` record month, day, hour, minute, and second
        # timestamp.
        stream_handler.setFormatter(
            MMFormatter(color=True, datefmt='%m/%d %H:%M:%S'))
        # Only rank0 `StreamHandler` will log messages below error level.
        if global_rank == 0:
            stream_handler.setLevel(log_level)
        else:
            stream_handler.setLevel(logging.ERROR)
        stream_handler.addFilter(FilterDuplicateWarning(logger_name))
        self.handlers.append(stream_handler)

        if log_file is not None:
            world_size = _get_world_size()
            is_distributed = (log_level <= logging.DEBUG
                              or distributed) and world_size > 1
            if is_distributed:
                # Give every rank its own file by appending host, device and
                # rank information to the requested file name.
                filename, suffix = osp.splitext(osp.basename(log_file))
                hostname = _get_host_info()
                if hostname:
                    filename = (f'{filename}_{hostname}_device{device_id}_'
                                f'rank{global_rank}{suffix}')
                else:
                    # Omit hostname if it is empty
                    filename = (f'{filename}_device{device_id}_'
                                f'rank{global_rank}{suffix}')
                log_file = osp.join(osp.dirname(log_file), filename)
            # Save multi-ranks logs if distributed is True. The logs of rank0
            # will always be saved.
            if global_rank == 0 or is_distributed:
                if file_handler_cfg is not None:
                    assert 'type' in file_handler_cfg
                    # Work on a copy so that popping ``type`` and filling in
                    # ``filename`` do not alter the caller's dict (which may
                    # be reused to build another logger).
                    file_handler_cfg = dict(file_handler_cfg)
                    file_handler_type = file_handler_cfg.pop('type')
                    file_handlers_map = _get_logging_file_handlers()
                    if file_handler_type in file_handlers_map:
                        file_handler_cls = file_handlers_map[file_handler_type]
                        file_handler_cfg.setdefault('filename', log_file)
                        file_handler = file_handler_cls(**file_handler_cfg)
                    else:
                        raise ValueError('`logging.handlers` does not '
                                         f'contain {file_handler_type}')
                else:
                    # Here, the default behavior of the official
                    # logger is 'a'. Thus, we provide an interface to
                    # change the file mode to the default behavior.
                    # `FileHandler` is not supported to have colors,
                    # otherwise it will appear garbled.
                    file_handler = logging.FileHandler(log_file, file_mode)

                # The file handler records year, month, day, hour, minute,
                # and second timestamp. file_handler will only record logs
                # without color to avoid garbled code saved in files.
                file_handler.setFormatter(
                    MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
                file_handler.setLevel(log_level)
                file_handler.addFilter(FilterDuplicateWarning(logger_name))
                self.handlers.append(file_handler)
        self._log_file = log_file

    @property
    def log_file(self):
        """The log file path stored at construction time.

        This is ``None`` when no ``log_file`` was given; otherwise it is the
        given path, rewritten with the host/device/rank suffix when
        per-rank (distributed) log files are enabled.
        """
        return self._log_file

    @classmethod
    def get_current_instance(cls) -> 'MMLogger':
        """Get latest created ``MMLogger`` instance.

        :obj:`MMLogger` can call :meth:`get_current_instance` before any
        instance has been created, and return a logger with the instance name
        "mmengine".

        Returns:
            MMLogger: Configured logger instance.
        """
        if not cls._instance_dict:
            cls.get_instance('mmengine')
        return super().get_current_instance()

    def callHandlers(self, record: LogRecord) -> None:
        """Pass a record to all relevant handlers.

        Override ``callHandlers`` method in ``logging.Logger`` to avoid
        multiple warning messages in DDP mode. Only the handlers attached to
        this logger instance are visited; the record is handed to each
        handler whose level does not exceed the record's level. If no
        handler was found, the record will not be output.

        Args:
            record (LogRecord): A ``LogRecord`` instance contains logged
                message.
        """
        for handler in self.handlers:
            if record.levelno >= handler.level:
                handler.handle(record)

    def setLevel(self, level):
        """Set the logging level of this logger.

        If ``logging.Logger.setLevel`` is called, all ``logging.Logger``
        instances managed by ``logging.Manager`` will clear the cache. Since
        ``MMLogger`` is not managed by ``logging.Manager`` anymore,
        ``MMLogger`` should override this method to clear caches of all
        ``MMLogger`` instance which is managed by :obj:`ManagerMixin`.

        level must be an int or a str.
        """
        self.level = logging._checkLevel(level)
        _accquire_lock()
        try:
            # The same logic as `logging.Manager._clear_cache`.
            for logger in MMLogger._instance_dict.values():
                logger._cache.clear()
        finally:
            # Always release, even if clearing a cache raises, so the lock
            # is never left held.
            _release_lock()
def print_log(msg,
              logger: Optional[Union[Logger, str]] = None,
              level=logging.INFO) -> None:
    """Emit ``msg`` through ``print`` or through a logger.

    Args:
        msg (str): The message to be logged.
        logger (Logger or str, optional): Selects where the message goes.

            - None: the message is written with ``print()``.
            - a ``logging.Logger`` object: that logger logs the message.
            - "silent": nothing is output.
            - "current": the latest created ``MMLogger`` logs the message.
            - any other str: taken as the instance name of an ``MMLogger``;
              that logger logs the message if it has been created,
              otherwise a ``ValueError`` is raised.
        level (int): Logging level. Only used when the message is sent to a
            logger (a Logger object, "current", or a created logger
            instance name).

    Raises:
        ValueError: If ``logger`` names an ``MMLogger`` instance that has
            not been created.
        TypeError: If ``logger`` is none of the accepted kinds.
    """
    if logger is None:
        print(msg)
        return

    if isinstance(logger, logging.Logger):
        logger.log(level, msg)
        return

    if logger == 'silent':
        return

    if logger == 'current':
        MMLogger.get_current_instance().log(level, msg)
        return

    if not isinstance(logger, str):
        raise TypeError(
            '`logger` should be either a logging.Logger object, str, '
            f'"silent", "current" or None, but got {type(logger)}')

    # Any remaining string is interpreted as an ``MMLogger`` instance name,
    # which must already exist.
    if not MMLogger.check_instance_created(logger):
        raise ValueError(f'MMLogger: {logger} has not been created!')
    MMLogger.get_instance(logger).log(level, msg)
def _get_world_size():
    """Return the distributed world size, or 1 if it cannot be queried.

    ``mmengine.dist`` requires torch; when importing it fails the logging
    module still works and behaves as a single-process run.
    """
    try:
        # requires torch
        from mmengine.dist import get_world_size as query_world_size
    except ImportError:
        return 1
    return query_world_size()
def _get_rank():
    """Return the global rank of this process, or 0 if it cannot be queried.

    ``mmengine.dist`` requires torch; when importing it fails the logging
    module still works and behaves as the rank-0 process.
    """
    try:
        # requires torch
        from mmengine.dist import get_rank as query_rank
    except ImportError:
        return 0
    return query_rank()
def _get_device_id():
    """Get device id of current machine.

    Returns:
        int or str: ``0`` when torch cannot be imported. Otherwise the value
        of the ``LOCAL_RANK`` environment variable (default 0) when CUDA is
        unavailable, or the entry of the visible-device list selected by
        ``LOCAL_RANK`` when CUDA is available. That entry is returned as an
        ``int`` when it is numeric and as the raw string otherwise. If
        ``LOCAL_RANK`` does not index into the visible-device list, the
        local rank itself is returned.
    """
    try:
        import torch
    except ImportError:
        return 0

    local_rank = int(os.getenv('LOCAL_RANK', '0'))
    # TODO: return device id of npu and mlu.
    if not torch.cuda.is_available():
        return local_rank

    visible_env = os.getenv('CUDA_VISIBLE_DEVICES', None)
    if visible_env is None:
        visible_devices = list(range(torch.cuda.device_count()))
    else:
        visible_devices = visible_env.split(',')

    try:
        device = visible_devices[local_rank]
    except IndexError:
        # Fewer devices are visible than the local rank implies (e.g. each
        # process is restricted to a single device). The id is only used as
        # a label, so fall back to the local rank instead of crashing.
        return local_rank

    try:
        return int(device)
    except ValueError:
        # handle case for Multi-Instance GPUs
        # see #1148 for details
        return device
def _get_host_info() -> str:
    """Get hostname and username.

    Returns:
        str: ``'<user>@<hostname>'``, or an empty string if looking up the
        user or host raises an exception, e.g. ``getpass.getuser()`` will
        lead to error in docker container. A warning describing the failure
        is issued in that case.
    """
    # The result is returned from the ``try``/``except`` branches rather
    # than from a ``finally`` clause: a ``return`` inside ``finally`` would
    # also discard ``KeyboardInterrupt``/``SystemExit`` raised meanwhile.
    try:
        return f'{getuser()}@{gethostname()}'
    except Exception as e:
        warnings.warn(f'Host or user not found: {str(e)}')
        return ''
def _get_logging_file_handlers() -> Dict:
    """Collect the file handler classes exposed by ``logging.handlers``.

    Every attribute of the ``logging.handlers`` module whose name does not
    start with a double underscore and whose value is a class derived from
    ``logging.FileHandler`` is included.

    Returns:
        Dict: A map from attribute name to file handler class.
    """
    public_members = ((attr, getattr(handlers, attr))
                      for attr in dir(handlers)
                      if not attr.startswith('__'))
    return {
        attr: member
        for attr, member in public_members
        if inspect.isclass(member) and issubclass(member, logging.FileHandler)
    }