Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
import functools
import time
from typing import Optional, Union

import torch
from mmengine.dist.utils import master_only
from mmengine.logging import MMLogger, print_log
class TimeCounter:
    """A tool that counts the average running time of a function or a method.

    Users can use it as a decorator or context manager to calculate the
    average running time of code blocks.

    Instances are cached per tag in ``instance_dict``: constructing
    ``TimeCounter`` again with an already-registered, non-``None`` tag returns
    the cached instance and ignores the other arguments.

    Args:
        log_interval (int): The interval of logging. A message is emitted
            whenever the call count is a multiple of it. Defaults to 1.
        warmup_interval (int): The interval of warmup. Calls before the
            ``warmup_interval``-th one are counted but their time is not
            accumulated. Must be >= 1. Defaults to 1.
        with_sync (bool): Whether to synchronize cuda. Defaults to True.
        tag (str, optional): Function tag. Used to distinguish between
            different functions or methods being called. Defaults to None.
        logger (MMLogger, optional): Formatted logger used to record messages.
            Defaults to None.

    Examples:
        >>> import time
        >>> from mmengine.utils.dl_utils import TimeCounter
        >>> @TimeCounter()
        ... def fun1():
        ...     time.sleep(0.1)
        >>> fun1()
        [fun1]-time per run averaged in the past 1 runs: 100.0 ms

        >>> @TimeCounter(log_interval=2, tag='fun')
        ... def fun2():
        ...     time.sleep(0.2)
        >>> for _ in range(3):
        ...     fun2()
        [fun]-time per run averaged in the past 2 runs: 200.0 ms

        >>> with TimeCounter(tag='fun3'):
        ...     time.sleep(0.3)
        [fun3]-time per run averaged in the past 1 runs: 300.0 ms
    """

    # Registry shared by the class: maps tag -> TimeCounter instance.
    instance_dict: dict = dict()

    log_interval: int
    warmup_interval: int
    with_sync: bool
    tag: Optional[str]
    logger: Optional[MMLogger]
    __count: int
    __pure_inf_time: float
    __start_time: float

    def __new__(cls,
                log_interval: int = 1,
                warmup_interval: int = 1,
                with_sync: bool = True,
                tag: Optional[str] = None,
                logger: Optional[MMLogger] = None):
        """Create a counter, or return the cached one registered for ``tag``.

        Raises:
            AssertionError: If ``warmup_interval`` is smaller than 1.
        """
        assert warmup_interval >= 1
        # Reuse the counter already registered under this tag so that
        # repeated ``with TimeCounter(tag=...)`` blocks share their state.
        if tag is not None and tag in cls.instance_dict:
            return cls.instance_dict[tag]

        instance = super().__new__(cls)
        # Untagged instances are stored under the ``None`` key; they are
        # never looked up again because of the ``tag is not None`` guard.
        cls.instance_dict[tag] = instance

        instance.log_interval = log_interval
        instance.warmup_interval = warmup_interval
        instance.with_sync = with_sync
        instance.tag = tag
        instance.logger = logger

        instance.__count = 0
        instance.__pure_inf_time = 0.
        instance.__start_time = 0.

        return instance

    def _sync_cuda(self) -> None:
        """Call ``torch.cuda.synchronize()`` if syncing is enabled and CUDA
        is available."""
        if self.with_sync and torch.cuda.is_available():
            torch.cuda.synchronize()

    def __call__(self, fn):
        """Decorate ``fn`` so that every call is timed.

        If no tag was given, ``fn.__name__`` is used as the tag.

        Args:
            fn (Callable): The function or method to time.

        Returns:
            Callable: A wrapper with the same signature and metadata as
            ``fn`` that returns whatever ``fn`` returns.
        """
        if self.tag is None:
            self.tag = fn.__name__

        # Preserve ``__name__``/``__doc__``/``__wrapped__`` of the timed
        # function so introspection and stacked decorators keep working.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            self.__count += 1

            self._sync_cuda()
            start_time = time.perf_counter()

            try:
                return fn(*args, **kwargs)
            finally:
                # Record the run even when ``fn`` raises, mirroring
                # ``__exit__``; otherwise the call would be counted without
                # its time and skew the reported average.
                self._sync_cuda()
                self.print_time(time.perf_counter() - start_time)

        return wrapper

    def __enter__(self):
        """Start timing a ``with`` block.

        Returns:
            TimeCounter: This counter, so ``with ... as counter`` is usable.

        Raises:
            AssertionError: If the counter has no tag.
        """
        assert self.tag is not None, 'In order to clearly distinguish ' \
                                     'printing information in different ' \
                                     'contexts, please specify the ' \
                                     'tag parameter'

        self.__count += 1

        self._sync_cuda()
        self.__start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop timing the ``with`` block and record the elapsed time.

        Returns ``None``, so an exception raised in the block propagates.
        """
        self._sync_cuda()
        elapsed = time.perf_counter() - self.__start_time
        self.print_time(elapsed)

    def print_time(self, elapsed: Union[int, float]) -> None:
        """Accumulate one measurement and log the running average.

        Args:
            elapsed (int | float): Duration of the latest run, in seconds.
        """
        # Calls before the ``warmup_interval``-th one are ignored entirely.
        if self.__count >= self.warmup_interval:
            self.__pure_inf_time += elapsed

            if self.__count % self.log_interval == 0:
                # Seconds -> milliseconds, averaged over the timed runs.
                times_per_count = 1000 * self.__pure_inf_time / (
                    self.__count - self.warmup_interval + 1)
                # NOTE(review): the message reports the total call count,
                # while the average above only covers the runs from the
                # ``warmup_interval``-th call on; they differ when
                # ``warmup_interval`` > 1.
                print_log(
                    f'[{self.tag}]-time per run averaged in the past '
                    f'{self.__count} runs: {times_per_count:.1f} ms',
                    self.logger)