# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from math import inf, isfinite
from typing import Optional, Tuple, Union

from mmengine.registry import HOOKS

from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
class EarlyStoppingHook(Hook):
    """Early stop the training when the monitored metric reaches a plateau.

    Args:
        monitor (str): The monitored metric key to decide early stopping.
        rule (str, optional): Comparison rule. Options are 'greater',
            'less'. Defaults to None.
        min_delta (float, optional): Minimum difference to continue the
            training. Defaults to 0.1.
        strict (bool, optional): Whether to crash the training when `monitor`
            is not found in the `metrics`. Defaults to False.
        check_finite (bool, optional): Whether to stop training when the
            monitored metric becomes NaN or infinite. Defaults to True.
        patience (int, optional): The number of validations with no
            improvement after which training will be stopped. Defaults to 5.
        stopping_threshold (float, optional): Stop training immediately once
            the monitored quantity reaches this threshold. Defaults to None.
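
    Examples:
        >>> # A minimal usage sketch: enable the hook through a runner
        >>> # config. The metric key 'coco/bbox_mAP' is only illustrative;
        >>> # use a key your ``val_evaluator`` actually reports.
        >>> custom_hooks = [
        ...     dict(
        ...         type='EarlyStoppingHook',
        ...         monitor='coco/bbox_mAP',
        ...         rule='greater',
        ...         min_delta=0.005,
        ...         patience=5)
        ... ]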

    Note:
        `New in version 0.7.0.`
    """

    priority = 'LOWEST'
    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    _default_greater_keys = [
        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
        'mAcc', 'aAcc'
    ]
    _default_less_keys = ['loss']

    def __init__(
        self,
        monitor: str,
        rule: Optional[str] = None,
        min_delta: float = 0.1,
        strict: bool = False,
        check_finite: bool = True,
        patience: int = 5,
        stopping_threshold: Optional[float] = None,
    ):
        self.monitor = monitor
        if rule is not None:
            if rule not in ['greater', 'less']:
                raise ValueError(
                    '`rule` should be either "greater" or "less", '
                    f'but got {rule}')
        else:
            rule = self._init_rule(monitor)
        self.rule = rule
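        # Flip the sign of ``min_delta`` for the 'less' rule so that
        # ``best_score + min_delta`` is always the value the current
        # score must beat to count as an improvement.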
        self.min_delta = min_delta if rule == 'greater' else -1 * min_delta
        self.strict = strict
        self.check_finite = check_finite
        self.patience = patience
        self.stopping_threshold = stopping_threshold

        self.wait_count = 0
        self.best_score = -inf if rule == 'greater' else inf

    def _init_rule(self, monitor: str) -> str:
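        """Infer the comparison rule from the monitored metric name.

        The name is matched case-insensitively against
        ``_default_greater_keys`` and ``_default_less_keys``, first as an
        exact key and then as a substring, so e.g. 'coco/bbox_mAP'
        contains 'mAP' and is inferred as 'greater'.

        Args:
            monitor (str): The monitored metric key.

        Returns:
            str: The inferred rule, either 'greater' or 'less'.

        Raises:
            ValueError: If no rule can be inferred from ``monitor``.
        """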
        greater_keys = {key.lower() for key in self._default_greater_keys}
        less_keys = {key.lower() for key in self._default_less_keys}
        monitor_lc = monitor.lower()
        if monitor_lc in greater_keys:
            rule = 'greater'
        elif monitor_lc in less_keys:
            rule = 'less'
        elif any(key in monitor_lc for key in greater_keys):
            rule = 'greater'
        elif any(key in monitor_lc for key in less_keys):
            rule = 'less'
        else:
            raise ValueError(f'Cannot infer the rule for {monitor}, thus '
                             'rule must be specified.')
        return rule

    def _check_stop_condition(self, current_score: float) -> Tuple[bool, str]:
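        """Decide whether to stop, checking conditions in priority order:
        a non-finite score (when ``check_finite`` is set), the
        ``stopping_threshold``, and finally the patience counter, which is
        incremented each time ``current_score`` fails to improve on
        ``best_score`` by at least ``min_delta``.

        Args:
            current_score (float): The latest value of the monitored metric.

        Returns:
            Tuple[bool, str]: Whether to stop training, and the reason why.
        """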
        compare = self.rule_map[self.rule]
        stop_training = False
        reason_message = ''

        if self.check_finite and not isfinite(current_score):
            stop_training = True
            reason_message = (f'Monitored metric {self.monitor} = '
                              f'{current_score} is not finite. '
                              f'Previous best value was '
                              f'{self.best_score:.3f}.')
        elif self.stopping_threshold is not None and compare(
                current_score, self.stopping_threshold):
            stop_training = True
            self.best_score = current_score
            reason_message = (f'Stopping threshold reached: '
                              f'`{self.monitor}` = {current_score} is '
                              f'{self.rule} than {self.stopping_threshold}.')
        elif compare(self.best_score + self.min_delta, current_score):
            self.wait_count += 1
            if self.wait_count >= self.patience:
                reason_message = (f'The monitored metric did not improve '
                                  f'in the last {self.wait_count} '
                                  f'validations. '
                                  f'Best score: {self.best_score:.3f}.')
                stop_training = True
        else:
            self.best_score = current_score
            self.wait_count = 0

        return stop_training, reason_message

    def before_run(self, runner) -> None:
        """Check `stop_training` variable in `runner.train_loop`.

        Args:
            runner (Runner): The runner of the training process.
        """
        assert hasattr(runner.train_loop, 'stop_training'), \
            '`train_loop` should contain `stop_training` variable.'

    def after_val_epoch(self, runner, metrics) -> None:
        """Decide whether to stop the training process.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if self.monitor not in metrics:
            if self.strict:
                raise RuntimeError(
                    'Early stopping conditioned on metric '
                    f'`{self.monitor}` is not available. Please check '
                    f'available metrics {metrics}, or set `strict=False` '
                    'in `EarlyStoppingHook`.')
            warnings.warn(
                'Skip early stopping process since the evaluation '
                f'results ({metrics.keys()}) do not include `monitor` '
                f'({self.monitor}).')
            return

        current_score = metrics[self.monitor]
        stop_training, message = self._check_stop_condition(current_score)
        if stop_training:
            runner.train_loop.stop_training = True
            runner.logger.info(message)
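

if __name__ == '__main__':
    # A minimal sanity sketch of the patience logic in isolation; not part
    # of the hook's public API. Because of the relative import above, run
    # it as a module, e.g. ``python -m mmengine.hooks.early_stopping_hook``
    # (the module path is an assumption about where this file lives). The
    # score sequence is made up purely for illustration.
    hook = EarlyStoppingHook(monitor='loss', patience=2, min_delta=0.0)
    for score in (1.0, 0.9, 0.95, 0.93):
        stop, reason = hook._check_stop_condition(score)
        print(f'score={score:.2f} stop={stop} {reason}')
    # Expected: stop turns True on the second consecutive score that fails
    # to improve on the best value (0.9), since patience=2.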