Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import warnings | |
| from typing import Any, Callable, Optional, Sequence, Tuple, Union | |
| import numpy as np | |
class HistoryBuffer:
    """Unified storage format for different log types.

    ``HistoryBuffer`` records the history of log for further statistics.

    Examples:
        >>> history_buffer = HistoryBuffer()
        >>> # Update history_buffer.
        >>> history_buffer.update(1)
        >>> history_buffer.update(2)
        >>> history_buffer.min()  # minimum of (1, 2)
        1
        >>> history_buffer.max()  # maximum of (1, 2)
        2
        >>> history_buffer.mean()  # mean of (1, 2)
        1.5
        >>> history_buffer.statistics('mean')  # access method by string.
        1.5

    Args:
        log_history (Sequence): History logs. Defaults to an empty sequence.
        count_history (Sequence): Counts of history logs. Must have the same
            length as ``log_history``. Defaults to an empty sequence.
        max_length (int): The max length of history logs. Defaults to 1000000.
    """

    # Registry of statistic methods, keyed by method name. It is a class
    # attribute, so it is shared by every instance; defaults are added by
    # ``_set_default_statistics`` and custom ones by ``register_statistics``.
    _statistics_methods: dict = dict()

    def __init__(self,
                 log_history: Sequence = (),
                 count_history: Sequence = (),
                 max_length: int = 1000000):
        # Immutable ``()`` defaults avoid the shared-mutable-default pitfall;
        # ``np.array(())`` is the same empty float array as ``np.array([])``.
        self.max_length = max_length
        self._set_default_statistics()
        assert len(log_history) == len(count_history), \
            'The lengths of log_history and count_history should be equal'
        if len(log_history) > max_length:
            warnings.warn(f'The length of history buffer({len(log_history)}) '
                          f'exceeds the max_length({max_length}), the first '
                          'few elements will be ignored.')
            # Keep only the newest ``max_length`` entries.
            self._log_history = np.array(log_history[-max_length:])
            self._count_history = np.array(count_history[-max_length:])
        else:
            self._log_history = np.array(log_history)
            self._count_history = np.array(count_history)

    def _set_default_statistics(self) -> None:
        """Register default statistic methods: min, max, current and mean.

        ``setdefault`` is used so that re-registering on every construction
        (or unpickling) never overwrites an existing entry.
        """
        self._statistics_methods.setdefault('min', HistoryBuffer.min)
        self._statistics_methods.setdefault('max', HistoryBuffer.max)
        self._statistics_methods.setdefault('current', HistoryBuffer.current)
        self._statistics_methods.setdefault('mean', HistoryBuffer.mean)

    def update(self, log_val: Union[int, float], count: int = 1) -> None:
        """Update the log history.

        If the length of the buffer exceeds ``self.max_length``, the oldest
        element will be removed from the buffer.

        Args:
            log_val (int or float): The value of log.
            count (int): The accumulation times of log, defaults to 1.
                ``count`` will be used in smooth statistics.

        Raises:
            TypeError: If ``log_val`` or ``count`` is not an int or float.
        """
        if (not isinstance(log_val, (int, float))
                or not isinstance(count, (int, float))):
            raise TypeError(f'log_val must be int or float but got '
                            f'{type(log_val)}, count must be int but got '
                            f'{type(count)}')
        # ``np.append`` returns a new array (it does not append in place).
        self._log_history = np.append(self._log_history, log_val)
        self._count_history = np.append(self._count_history, count)
        if len(self._log_history) > self.max_length:
            self._log_history = self._log_history[-self.max_length:]
            self._count_history = self._count_history[-self.max_length:]

    def data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Get the ``_log_history`` and ``_count_history``.

        Returns:
            Tuple[np.ndarray, np.ndarray]: History logs and the counts of
            the history logs.
        """
        return self._log_history, self._count_history

    @classmethod
    def register_statistics(cls, method: Callable) -> Callable:
        """Register custom statistics method to ``_statistics_methods``.

        The registered method can be called by ``history_buffer.statistics``
        with corresponding method name and arguments.

        Examples:
            >>> @HistoryBuffer.register_statistics
            >>> def weighted_mean(self, window_size, weight):
            >>>     assert len(weight) == window_size
            >>>     return (self._log_history[-window_size:] *
            >>>             np.array(weight)).sum() / \
            >>>             self._count_history[-window_size:].sum()
            >>> log_buffer = HistoryBuffer([1, 2], [1, 1])
            >>> log_buffer.statistics('weighted_mean', 2, [2, 1])
            2.0

        Args:
            method (Callable): Custom statistics method. It is called as
                ``method(buffer, *args, **kwargs)``.

        Returns:
            Callable: Original custom statistics method.

        Raises:
            AssertionError: If a method with the same ``__name__`` has
                already been registered.
        """
        method_name = method.__name__
        assert method_name not in cls._statistics_methods, \
            'method_name cannot be registered twice!'
        cls._statistics_methods[method_name] = method
        return method

    def statistics(self, method_name: str, *arg, **kwargs) -> Any:
        """Access statistics method by name.

        Args:
            method_name (str): Name of method.

        Returns:
            Any: Depends on corresponding method.

        Raises:
            KeyError: If ``method_name`` has not been registered.
        """
        if method_name not in self._statistics_methods:
            raise KeyError(f'{method_name} has not been registered in '
                           'HistoryBuffer._statistics_methods')
        method = self._statistics_methods[method_name]
        # Registered entries are plain functions, so pass ``self`` explicitly.
        return method(self, *arg, **kwargs)

    def mean(self, window_size: Optional[int] = None) -> np.ndarray:
        """Return the mean of the latest ``window_size`` values in log
        histories.

        The mean is ``sum(logs) / sum(counts)`` over the window. If
        ``window_size is None`` or ``window_size > len(self._log_history)``,
        return the global mean value of history logs.

        Args:
            window_size (int, optional): Size of statistics window.

        Returns:
            np.ndarray: Mean value within the window.
        """
        if window_size is not None:
            assert isinstance(window_size, int), \
                'The type of window size should be int, but got ' \
                f'{type(window_size)}'
        else:
            window_size = len(self._log_history)
        logs_sum = self._log_history[-window_size:].sum()
        counts_sum = self._count_history[-window_size:].sum()
        return logs_sum / counts_sum

    def max(self, window_size: Optional[int] = None) -> np.ndarray:
        """Return the maximum value of the latest ``window_size`` values in log
        histories.

        If ``window_size is None`` or ``window_size > len(self._log_history)``,
        return the global maximum value of history logs.

        Args:
            window_size (int, optional): Size of statistics window.

        Returns:
            np.ndarray: The maximum value within the window.
        """
        if window_size is not None:
            assert isinstance(window_size, int), \
                'The type of window size should be int, but got ' \
                f'{type(window_size)}'
        else:
            window_size = len(self._log_history)
        return self._log_history[-window_size:].max()

    def min(self, window_size: Optional[int] = None) -> np.ndarray:
        """Return the minimum value of the latest ``window_size`` values in log
        histories.

        If ``window_size is None`` or ``window_size > len(self._log_history)``,
        return the global minimum value of history logs.

        Args:
            window_size (int, optional): Size of statistics window.

        Returns:
            np.ndarray: The minimum value within the window.
        """
        if window_size is not None:
            assert isinstance(window_size, int), \
                'The type of window size should be int, but got ' \
                f'{type(window_size)}'
        else:
            window_size = len(self._log_history)
        return self._log_history[-window_size:].min()

    def current(self) -> np.ndarray:
        """Return the recently updated values in log histories.

        Returns:
            np.ndarray: Recently updated values in log histories.

        Raises:
            ValueError: If no value has been recorded yet.
        """
        if len(self._log_history) == 0:
            raise ValueError('HistoryBuffer._log_history is an empty array! '
                             'please call update first')
        return self._log_history[-1]

    def __getstate__(self) -> dict:
        """Make ``_statistics_methods`` can be resumed.

        A shallow copy of ``__dict__`` is returned so that pickling (or
        deep-copying) does not leave a stray ``statistics_methods`` attribute
        on the live object.

        Returns:
            dict: State dict including statistics_methods.
        """
        state = self.__dict__.copy()
        state['statistics_methods'] = self._statistics_methods
        return state

    def __setstate__(self, state):
        """Try to load ``_statistics_methods`` from state.

        Args:
            state (dict): State dict.
        """
        statistics_methods = state.pop('statistics_methods', {})
        self._set_default_statistics()
        # Merge into the shared class-level registry.
        self._statistics_methods.update(statistics_methods)
        self.__dict__.update(state)