# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import logging
import os.path as osp
import pickle
from collections import deque
from math import inf
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union

from mmengine.dist import is_main_process, master_only
from mmengine.fileio import FileClient, get_file_backend
from mmengine.logging import print_log
from mmengine.registry import HOOKS
from mmengine.utils import is_list_of, is_seq_of

from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
class CheckpointHook(Hook):
| """Save checkpoints periodically. | |
| Args: | |
| interval (int): The saving period. If ``by_epoch=True``, interval | |
| indicates epochs, otherwise it indicates iterations. | |
| Defaults to -1, which means "never". | |
| by_epoch (bool): Saving checkpoints by epoch or by iteration. | |
| Defaults to True. | |
| save_optimizer (bool): Whether to save optimizer state_dict in the | |
| checkpoint. It is usually used for resuming experiments. | |
| Defaults to True. | |
| save_param_scheduler (bool): Whether to save param_scheduler state_dict | |
| in the checkpoint. It is usually used for resuming experiments. | |
| Defaults to True. | |
| out_dir (str, Path, Optional): The root directory to save checkpoints. | |
| If not specified, ``runner.work_dir`` will be used by default. If | |
| specified, the ``out_dir`` will be the concatenation of ``out_dir`` | |
| and the last level directory of ``runner.work_dir``. For example, | |
| if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is | |
| ``./work_dir/cur_exp``, then the ckpt will be saved in | |
| ``./tmp/cur_exp``. Defaults to None. | |
| max_keep_ckpts (int): The maximum checkpoints to keep. | |
| In some cases we want only the latest few checkpoints and would | |
| like to delete old ones to save the disk space. | |
| Defaults to -1, which means unlimited. | |
| save_last (bool): Whether to force the last checkpoint to be | |
| saved regardless of interval. Defaults to True. | |
| save_best (str, List[str], optional): If a metric is specified, it | |
| would measure the best checkpoint during evaluation. If a list of | |
| metrics is passed, it would measure a group of best checkpoints | |
| corresponding to the passed metrics. The information about best | |
| checkpoint(s) would be saved in ``runner.message_hub`` to keep | |
| best score value and best checkpoint path, which will be also | |
| loaded when resuming checkpoint. Options are the evaluation metrics | |
| on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox | |
| detection and instance segmentation. ``AR@100`` for proposal | |
| recall. If ``save_best`` is ``auto``, the first key of the returned | |
| ``OrderedDict`` result will be used. Defaults to None. | |
| rule (str, List[str], optional): Comparison rule for best score. If | |
| set to None, it will infer a reasonable rule. Keys such as 'acc', | |
| 'top' .etc will be inferred by 'greater' rule. Keys contain 'loss' | |
| will be inferred by 'less' rule. If ``save_best`` is a list of | |
| metrics and ``rule`` is a str, all metrics in ``save_best`` will | |
| share the comparison rule. If ``save_best`` and ``rule`` are both | |
| lists, their length must be the same, and metrics in ``save_best`` | |
| will use the corresponding comparison rule in ``rule``. Options | |
| are 'greater', 'less', None and list which contains 'greater' and | |
| 'less'. Defaults to None. | |
| greater_keys (List[str], optional): Metric keys that will be | |
| inferred by 'greater' comparison rule. If ``None``, | |
| _default_greater_keys will be used. Defaults to None. | |
| less_keys (List[str], optional): Metric keys that will be | |
| inferred by 'less' comparison rule. If ``None``, _default_less_keys | |
| will be used. Defaults to None. | |
| file_client_args (dict, optional): Arguments to instantiate a | |
| FileClient. See :class:`mmengine.fileio.FileClient` for details. | |
| Defaults to None. It will be deprecated in future. Please use | |
| ``backend_args`` instead. | |
| filename_tmpl (str, optional): String template to indicate checkpoint | |
| name. If specified, must contain one and only one "{}", which will | |
| be replaced with ``epoch + 1`` if ``by_epoch=True`` else | |
| ``iteration + 1``. | |
| Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth" | |
| accordingly. | |
| backend_args (dict, optional): Arguments to instantiate the | |
| prefix of uri corresponding backend. Defaults to None. | |
| `New in version 0.2.0.` | |
| published_keys (str, List[str], optional): If ``save_last`` is ``True`` | |
| or ``save_best`` is not ``None``, it will automatically | |
| publish model with keys in the list after training. | |
| Defaults to None. | |
| `New in version 0.7.1.` | |
| save_begin (int): Control the epoch number or iteration number | |
| at which checkpoint saving begins. Defaults to 0, which means | |
| saving at the beginning. | |
| `New in version 0.8.3.` | |
| Examples: | |
| >>> # Save best based on single metric | |
| >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc', | |
| >>> rule='less') | |
| >>> # Save best based on multi metrics with the same comparison rule | |
| >>> CheckpointHook(interval=2, by_epoch=True, | |
| >>> save_best=['acc', 'mIoU'], rule='greater') | |
| >>> # Save best based on multi metrics with different comparison rule | |
| >>> CheckpointHook(interval=2, by_epoch=True, | |
| >>> save_best=['FID', 'IS'], rule=['less', 'greater']) | |
| >>> # Save best based on single metric and publish model after training | |
| >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc', | |
| >>> rule='less', published_keys=['meta', 'state_dict']) | |
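        >>> # Illustrative (argument values are arbitrary): keep only the 3
        >>> # most recent checkpoints and start saving from epoch 5, as
        >>> # described by ``max_keep_ckpts`` and ``save_begin`` above
        >>> CheckpointHook(interval=1, by_epoch=True, max_keep_ckpts=3,
        >>>                save_begin=5)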
| """ | |
    out_dir: str

    priority = 'VERY_LOW'

    # logic to save best checkpoints
    # Since the key for determining greater or less is related to the
    # downstream tasks, downstream repositories may need to overwrite
    # the following inner variables accordingly.
    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    init_value_map = {'greater': -inf, 'less': inf}
    _default_greater_keys = [
        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
        'mAcc', 'aAcc'
    ]
    _default_less_keys = ['loss']

    def __init__(self,
                 interval: int = -1,
                 by_epoch: bool = True,
                 save_optimizer: bool = True,
                 save_param_scheduler: bool = True,
                 out_dir: Optional[Union[str, Path]] = None,
                 max_keep_ckpts: int = -1,
                 save_last: bool = True,
                 save_best: Union[str, List[str], None] = None,
                 rule: Union[str, List[str], None] = None,
                 greater_keys: Optional[Sequence[str]] = None,
                 less_keys: Optional[Sequence[str]] = None,
                 file_client_args: Optional[dict] = None,
                 filename_tmpl: Optional[str] = None,
                 backend_args: Optional[dict] = None,
                 published_keys: Union[str, List[str], None] = None,
                 save_begin: int = 0,
                 **kwargs) -> None:
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.save_param_scheduler = save_param_scheduler
        self.out_dir = out_dir  # type: ignore
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        self.args = kwargs

        if file_client_args is not None:
            print_log(
                '"file_client_args" will be deprecated in future. '
                'Please use "backend_args" instead',
                logger='current',
                level=logging.WARNING)
            if backend_args is not None:
                raise ValueError(
                    '"file_client_args" and "backend_args" cannot be set '
                    'at the same time.')

        self.file_client_args = file_client_args
        self.backend_args = backend_args

        if filename_tmpl is None:
            if self.by_epoch:
                self.filename_tmpl = 'epoch_{}.pth'
            else:
                self.filename_tmpl = 'iter_{}.pth'
        else:
            self.filename_tmpl = filename_tmpl

        # save best logic
        assert (isinstance(save_best, str) or is_list_of(save_best, str)
                or (save_best is None)), (
                    '"save_best" should be a str or list of str or None, '
                    f'but got {type(save_best)}')

        if isinstance(save_best, list):
            if 'auto' in save_best:
                assert len(save_best) == 1, (
                    'Only support one "auto" in "save_best" list.')
            assert len(save_best) == len(
                set(save_best)), ('Find duplicate element in "save_best".')
        else:
            # convert str to list[str]
            if save_best is not None:
                save_best = [save_best]  # type: ignore # noqa: F401
        self.save_best = save_best

        # rule logic
        assert (isinstance(rule, str) or is_list_of(rule, str)
                or (rule is None)), (
                    '"rule" should be a str or list of str or None, '
                    f'but got {type(rule)}')
        if isinstance(rule, list):
            # check the length of rule list
            assert len(rule) in [
                1,
                len(self.save_best)  # type: ignore
            ], ('Number of "rule" must be 1 or the same as number of '
                f'"save_best", but got {len(rule)}.')
        else:
            # convert str/None to list
            rule = [rule]  # type: ignore # noqa: F401

        if greater_keys is None:
            self.greater_keys = self._default_greater_keys
        else:
            if not isinstance(greater_keys, (list, tuple)):
                greater_keys = (greater_keys, )  # type: ignore
            assert is_seq_of(greater_keys, str)
            self.greater_keys = greater_keys  # type: ignore

        if less_keys is None:
            self.less_keys = self._default_less_keys
        else:
            if not isinstance(less_keys, (list, tuple)):
                less_keys = (less_keys, )  # type: ignore
            assert is_seq_of(less_keys, str)
            self.less_keys = less_keys  # type: ignore

        if self.save_best is not None:
            self.is_better_than: Dict[str, Callable] = dict()
            self._init_rule(rule, self.save_best)
            if len(self.key_indicators) == 1:
                self.best_ckpt_path: Optional[str] = None
            else:
                self.best_ckpt_path_dict: Dict = dict()

        # published keys
        if not (isinstance(published_keys, str)
                or is_seq_of(published_keys, str) or published_keys is None):
            raise TypeError(
                '"published_keys" should be a str or a sequence of str or '
                f'None, but got {type(published_keys)}')
        if isinstance(published_keys, str):
            published_keys = [published_keys]
        elif isinstance(published_keys, (list, tuple)):
            assert len(published_keys) == len(set(published_keys)), (
                'Find duplicate elements in "published_keys".')
        self.published_keys = published_keys

        self.last_ckpt = None
        if save_begin < 0:
            raise ValueError(
                'save_begin should not be less than 0, '
                f'but got {save_begin}')
        self.save_begin = save_begin

    def before_train(self, runner) -> None:
        """Finish all operations related to checkpoints.

        This function will get the appropriate file backend and the directory
        in which checkpoints of the model are saved.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.out_dir is None:
            self.out_dir = runner.work_dir

        # If self.file_client_args is None, self.file_client will not be
        # used in CheckpointHook. To avoid breaking backward compatibility,
        # it will not be removed until the release of MMEngine 1.0
        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        if self.file_client_args is None:
            self.file_backend = get_file_backend(
                self.out_dir, backend_args=self.backend_args)
        else:
            self.file_backend = self.file_client

        # If `self.out_dir` is not equal to `runner.work_dir`, it means that
        # `self.out_dir` is set, so the final `self.out_dir` is the
        # concatenation of `self.out_dir` and the last-level directory of
        # `runner.work_dir`.
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_backend.join_path(
                self.out_dir, basename)  # type: ignore  # noqa: E501

        runner.logger.info(f'Checkpoints will be saved to {self.out_dir}.')

        if self.save_best is not None:
            if len(self.key_indicators) == 1:
                if 'best_ckpt' not in runner.message_hub.runtime_info:
                    self.best_ckpt_path = None
                else:
                    self.best_ckpt_path = runner.message_hub.get_info(
                        'best_ckpt')
            else:
                for key_indicator in self.key_indicators:
                    best_ckpt_name = f'best_ckpt_{key_indicator}'
                    if best_ckpt_name not in runner.message_hub.runtime_info:
                        self.best_ckpt_path_dict[key_indicator] = None
                    else:
                        self.best_ckpt_path_dict[
                            key_indicator] = runner.message_hub.get_info(
                                best_ckpt_name)

        if self.max_keep_ckpts > 0:
            keep_ckpt_ids = []
            if 'keep_ckpt_ids' in runner.message_hub.runtime_info:
                keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids')

                while len(keep_ckpt_ids) > self.max_keep_ckpts:
                    step = keep_ckpt_ids.pop(0)
                    if is_main_process():
                        path = self.file_backend.join_path(
                            self.out_dir, self.filename_tmpl.format(step))
                        if self.file_backend.isfile(path):
                            self.file_backend.remove(path)
                        elif self.file_backend.isdir(path):
                            # checkpoints saved by deepspeed are directories
                            self.file_backend.rmtree(path)

            self.keep_ckpt_ids: deque = deque(keep_ckpt_ids,
                                              self.max_keep_ckpts)

    def after_train_epoch(self, runner) -> None:
        """Save the checkpoint and synchronize buffers after each epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.by_epoch:
            return

        # save checkpoint for the following cases:
        # 1. every ``self.interval`` epochs, starting at ``self.save_begin``
        # 2. the last epoch of training is reached
        if self.every_n_epochs(runner, self.interval, self.save_begin) or (
                self.save_last and self.is_last_train_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            self._save_checkpoint(runner)

    def after_val_epoch(self, runner, metrics):
        """Save the checkpoint and synchronize buffers after each evaluation
        epoch.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if len(metrics) == 0:
            runner.logger.warning(
                'Since `metrics` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return

        self._save_best_checkpoint(runner, metrics)

    def after_train(self, runner) -> None:
        """Publish the checkpoint after training.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.published_keys is None:
            return

        if self.save_last and self.last_ckpt is not None:
            self._publish_model(runner, self.last_ckpt)

        if getattr(self, 'best_ckpt_path', None) is not None:
            self._publish_model(runner, str(self.best_ckpt_path))
        if getattr(self, 'best_ckpt_path_dict', None) is not None:
            for best_ckpt in self.best_ckpt_path_dict.values():
                self._publish_model(runner, best_ckpt)

    def _publish_model(self, runner, ckpt_path: str) -> None:
        """Remove unnecessary keys from ckpt_path and save the new checkpoint.

        Args:
            runner (Runner): The runner of the training process.
            ckpt_path (str): The checkpoint path that ought to be published.
        """
        from mmengine.runner import save_checkpoint
        from mmengine.runner.checkpoint import _load_checkpoint
        checkpoint = _load_checkpoint(ckpt_path)
        assert self.published_keys is not None
        removed_keys = []
        for key in list(checkpoint.keys()):
            if key not in self.published_keys:
                removed_keys.append(key)
                checkpoint.pop(key)
        if removed_keys:
            print_log(
                f'Key {removed_keys} will be removed because they are not '
                'found in published_keys. If you want to keep them, '
                f'please set `{removed_keys}` in published_keys',
                logger='current')
        checkpoint_data = pickle.dumps(checkpoint)
        sha = hashlib.sha256(checkpoint_data).hexdigest()
        final_path = osp.splitext(ckpt_path)[0] + f'-{sha[:8]}.pth'
        save_checkpoint(checkpoint, final_path)
        print_log(
            f'The checkpoint ({ckpt_path}) is published to '
            f'{final_path}.',
            logger='current')

    def _save_checkpoint_with_step(self, runner, step, meta):
        # Remove other checkpoints before saving the checkpoint so that
        # self.keep_ckpt_ids is saved as expected.
        if self.max_keep_ckpts > 0:
            # _save_checkpoint and _save_best_checkpoint may call this
            # _save_checkpoint_with_step in one epoch
            if len(self.keep_ckpt_ids) > 0 and self.keep_ckpt_ids[-1] == step:
                pass
            else:
                if len(self.keep_ckpt_ids) == self.max_keep_ckpts:
                    _step = self.keep_ckpt_ids.popleft()
                    if is_main_process():
                        ckpt_path = self.file_backend.join_path(
                            self.out_dir, self.filename_tmpl.format(_step))

                        if self.file_backend.isfile(ckpt_path):
                            self.file_backend.remove(ckpt_path)
                        elif self.file_backend.isdir(ckpt_path):
                            # checkpoints saved by deepspeed are directories
                            self.file_backend.rmtree(ckpt_path)

                self.keep_ckpt_ids.append(step)
                runner.message_hub.update_info('keep_ckpt_ids',
                                               list(self.keep_ckpt_ids))

        ckpt_filename = self.filename_tmpl.format(step)
        self.last_ckpt = self.file_backend.join_path(self.out_dir,
                                                     ckpt_filename)
        runner.message_hub.update_info('last_ckpt', self.last_ckpt)

        runner.save_checkpoint(
            self.out_dir,
            ckpt_filename,
            self.file_client_args,
            save_optimizer=self.save_optimizer,
            save_param_scheduler=self.save_param_scheduler,
            meta=meta,
            by_epoch=self.by_epoch,
            backend_args=self.backend_args,
            **self.args)

        # Model-parallel-like training should involve pulling sharded states
        # from all ranks (handled by ``runner.save_checkpoint`` above), but
        # non-main ranks skip the following procedure.
        if not is_main_process():
            return

        save_file = osp.join(runner.work_dir, 'last_checkpoint')
        with open(save_file, 'w') as f:
            f.write(self.last_ckpt)  # type: ignore

    def _save_checkpoint(self, runner) -> None:
        """Save the current checkpoint and delete outdated checkpoints.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.by_epoch:
            step = runner.epoch + 1
            meta = dict(epoch=step, iter=runner.iter)
        else:
            step = runner.iter + 1
            meta = dict(epoch=runner.epoch, iter=step)

        self._save_checkpoint_with_step(runner, step, meta=meta)

    def _save_best_checkpoint(self, runner, metrics) -> None:
        """Save the best checkpoint and delete the outdated best checkpoint.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if not self.save_best:
            return

        if self.by_epoch:
            ckpt_filename = self.filename_tmpl.format(runner.epoch)
            cur_type, cur_time = 'epoch', runner.epoch
        else:
            ckpt_filename = self.filename_tmpl.format(runner.iter)
            cur_type, cur_time = 'iter', runner.iter

        meta = dict(epoch=runner.epoch, iter=runner.iter)

        # handle auto in self.key_indicators and self.rules before the loop
        if 'auto' in self.key_indicators:
            self._init_rule(self.rules, [list(metrics.keys())[0]])

        best_ckpt_updated = False
        # save best logic
        # get score from the message hub
        for key_indicator, rule in zip(self.key_indicators, self.rules):
            key_score = metrics[key_indicator]

            if len(self.key_indicators) == 1:
                best_score_key = 'best_score'
                runtime_best_ckpt_key = 'best_ckpt'
                best_ckpt_path = self.best_ckpt_path
            else:
                best_score_key = f'best_score_{key_indicator}'
                runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
                best_ckpt_path = self.best_ckpt_path_dict[key_indicator]

            if best_score_key not in runner.message_hub.runtime_info:
                best_score = self.init_value_map[rule]
            else:
                best_score = runner.message_hub.get_info(best_score_key)

            if key_score is None or not self.is_better_than[key_indicator](
                    key_score, best_score):
                continue

            best_ckpt_updated = True

            best_score = key_score
            runner.message_hub.update_info(best_score_key, best_score)

            if best_ckpt_path and is_main_process():
                is_removed = False
                if self.file_backend.isfile(best_ckpt_path):
                    self.file_backend.remove(best_ckpt_path)
                    is_removed = True
                elif self.file_backend.isdir(best_ckpt_path):
                    # checkpoints saved by deepspeed are directories
                    self.file_backend.rmtree(best_ckpt_path)
                    is_removed = True

                if is_removed:
                    runner.logger.info(
                        f'The previous best checkpoint {best_ckpt_path} '
                        'is removed')

            best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
            # Replace illegal characters for filename with `_`
            best_ckpt_name = best_ckpt_name.replace('/', '_')
            if len(self.key_indicators) == 1:
                self.best_ckpt_path = self.file_backend.join_path(  # type: ignore # noqa: E501
                    self.out_dir, best_ckpt_name)
                runner.message_hub.update_info(runtime_best_ckpt_key,
                                               self.best_ckpt_path)
            else:
                self.best_ckpt_path_dict[
                    key_indicator] = self.file_backend.join_path(  # type: ignore # noqa: E501
                        self.out_dir, best_ckpt_name)
                runner.message_hub.update_info(
                    runtime_best_ckpt_key,
                    self.best_ckpt_path_dict[key_indicator])
            runner.save_checkpoint(
                self.out_dir,
                filename=best_ckpt_name,
                file_client_args=self.file_client_args,
                save_optimizer=False,
                save_param_scheduler=False,
                meta=meta,
                by_epoch=False,
                backend_args=self.backend_args)
            runner.logger.info(
                f'The best checkpoint with {best_score:0.4f} {key_indicator} '
                f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')

        # Save the checkpoint again to update the best_score and best_ckpt
        # stored in message_hub, because the checkpoint saved in the
        # `after_train_epoch` or `after_train_iter` stage only keeps the
        # previous best checkpoint, not the current one, so the current best
        # checkpoint could not be removed when resuming training.
        if best_ckpt_updated and self.last_ckpt is not None:
            self._save_checkpoint_with_step(runner, cur_time, meta)

    def _init_rule(self, rules, key_indicators) -> None:
        """Initialize rule, key_indicator, comparison_func, and best score.

        If key_indicator is a list of strings and rule is a string, all
        metrics in key_indicator will share the same rule.

        Here is how the rule is determined for a key indicator when the rule
        is not specified (note that the key indicator matching is
        case-insensitive):

        1. If the key indicator is in ``self.greater_keys``, the rule
           will be specified as 'greater'.
        2. Or if the key indicator is in ``self.less_keys``, the rule
           will be specified as 'less'.
        3. Or if any one item in ``self.greater_keys`` is a substring of
           key_indicator, the rule will be specified as 'greater'.
        4. Or if any one item in ``self.less_keys`` is a substring of
           key_indicator, the rule will be specified as 'less'.

        Args:
            rules (List[Optional[str]]): Comparison rules for best scores.
            key_indicators (List[str]): Key indicators to determine
                the comparison rule.
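
        Example (illustrative only; the metric names below are arbitrary and
        simply exercise the substring matching described in rules 3 and 4):

        >>> hook = CheckpointHook(interval=1,
        >>>                       save_best=['coco/bbox_mAP', 'val_loss'])
        >>> hook.rules  # 'mAP' implies 'greater', 'loss' implies 'less'
        ['greater', 'less']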
| """ | |
        if len(rules) == 1:
            rules = rules * len(key_indicators)

        self.rules = []
        for rule, key_indicator in zip(rules, key_indicators):

            if rule not in self.rule_map and rule is not None:
                raise KeyError('rule must be greater, less or None, '
                               f'but got {rule}.')

            if rule is None and key_indicator != 'auto':
                # `_lc` here means we use the lower case of keys for
                # case-insensitive matching
                key_indicator_lc = key_indicator.lower()
                greater_keys = {key.lower() for key in self.greater_keys}
                less_keys = {key.lower() for key in self.less_keys}

                if key_indicator_lc in greater_keys:
                    rule = 'greater'
                elif key_indicator_lc in less_keys:
                    rule = 'less'
                elif any(key in key_indicator_lc for key in greater_keys):
                    rule = 'greater'
                elif any(key in key_indicator_lc for key in less_keys):
                    rule = 'less'
                else:
                    raise ValueError('Cannot infer the rule for key '
                                     f'{key_indicator}, thus a specific rule '
                                     'must be specified.')
            if rule is not None:
                self.is_better_than[key_indicator] = self.rule_map[rule]
            self.rules.append(rule)

        self.key_indicators = key_indicators

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Save the checkpoint and synchronize buffers after each iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.by_epoch:
            return

        # save checkpoint for the following cases:
        # 1. every ``self.interval`` iterations,
        #    starting at ``self.save_begin``
        # 2. the last iteration of training is reached
        if self.every_n_train_iters(runner, self.interval,
                                    self.save_begin) or \
                (self.save_last and
                 self.is_last_train_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            self._save_checkpoint(runner)
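

# A minimal usage sketch (illustrative only, not part of this module): in an
# MMEngine config, this hook is typically enabled through ``default_hooks``.
# The argument values below are arbitrary examples.
#
#     default_hooks = dict(
#         checkpoint=dict(
#             type='CheckpointHook',
#             interval=1,
#             max_keep_ckpts=3,
#             save_best='auto'))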