# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional, Union

import numpy as np
import torch

from mmengine.registry import HOOKS
from mmengine.utils import get_git_hash
from mmengine.version import __version__

from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]

def _is_scalar(value: Any) -> bool:
    """Determine whether the given value is a scalar.

    Args:
        value (Any): The value to be logged.

    Returns:
        bool: Whether the value is a scalar.
    """
    if isinstance(value, np.ndarray):
        return value.size == 1
    elif isinstance(value, (int, float, np.number)):
        return True
    elif isinstance(value, torch.Tensor):
        return value.numel() == 1
    return False
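
# A minimal sanity check of ``_is_scalar`` (illustrative, not part of the
# original module): single-element arrays and tensors count as scalars,
# multi-element ones do not.
#
#   >>> _is_scalar(3.14)
#   True
#   >>> _is_scalar(np.ones(1))
#   True
#   >>> _is_scalar(torch.zeros(2))
#   False
#   >>> _is_scalar('loss')
#   False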

@HOOKS.register_module()
class RuntimeInfoHook(Hook):
    """A hook that updates runtime information into the message hub.

    E.g. ``epoch``, ``iter``, ``max_epochs``, and ``max_iters`` for the
    training state. Components that cannot access the runner can get this
    runtime information through the message hub.
    """

    priority = 'VERY_HIGH'
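
    # A hedged usage sketch (an illustration, not part of the original
    # module): components that cannot see the runner can read the state
    # this hook publishes via the global message hub, e.g.:
    #
    #   from mmengine.logging import MessageHub
    #   message_hub = MessageHub.get_current_instance()
    #   cur_iter = message_hub.get_info('iter')
    #   max_iters = message_hub.get_info('max_iters')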

    def before_run(self, runner) -> None:
        """Update metainfo.

        Args:
            runner (Runner): The runner of the training process.
        """
        metainfo = dict(
            cfg=runner.cfg.pretty_text,
            seed=runner.seed,
            experiment_name=runner.experiment_name,
            mmengine_version=__version__ + get_git_hash())
        runner.message_hub.update_info_dict(metainfo)
        self.last_loop_stage = None

    def before_train(self, runner) -> None:
        """Update resumed training state.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info('loop_stage', 'train')
        runner.message_hub.update_info('epoch', runner.epoch)
        runner.message_hub.update_info('iter', runner.iter)
        runner.message_hub.update_info('max_epochs', runner.max_epochs)
        runner.message_hub.update_info('max_iters', runner.max_iters)
        if hasattr(runner.train_dataloader.dataset, 'metainfo'):
            runner.message_hub.update_info(
                'dataset_meta', runner.train_dataloader.dataset.metainfo)

    def after_train(self, runner) -> None:
        """Pop the ``loop_stage`` recorded in ``before_train``."""
        runner.message_hub.pop_info('loop_stage')

    def before_train_epoch(self, runner) -> None:
        """Update current epoch information before every epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info('epoch', runner.epoch)

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """Update current iter and learning rate information before every
        iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from the
                dataloader. Defaults to None.
        """
        runner.message_hub.update_info('iter', runner.iter)
        lr_dict = runner.optim_wrapper.get_lr()
        assert isinstance(lr_dict, dict), (
            '`runner.optim_wrapper.get_lr()` should return a dict of '
            'learning rates when training with OptimWrapper (single '
            'optimizer) or OptimWrapperDict (multiple optimizers), but got '
            f'{type(lr_dict)}. Please check that your optimizer constructor '
            'returns an `OptimWrapper` or `OptimWrapperDict` instance.')
        for name, lr in lr_dict.items():
            runner.message_hub.update_scalar(f'train/{name}', lr[0])
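
    # Illustrative note (an assumption about typical values, not from the
    # original source): a plain ``OptimWrapper`` usually yields something
    # like ``{'lr': [0.01]}``, so the scalar above is logged as
    # ``train/lr``; an ``OptimWrapperDict`` prefixes each optimizer's name,
    # e.g. ``train/generator.lr``.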

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Update ``log_vars`` in model outputs every iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from the
                dataloader. Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        if outputs is not None:
            for key, value in outputs.items():
                # Skip visualization outputs; only numeric log variables
                # are recorded as training scalars.
                if key.startswith('vis_'):
                    continue
                runner.message_hub.update_scalar(f'train/{key}', value)
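
    # Illustrative example (an assumption about a typical ``train_step``
    # return value): ``outputs = {'loss': tensor(0.35), 'loss_cls':
    # tensor(0.21)}`` would be recorded as the scalars ``train/loss`` and
    # ``train/loss_cls``.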

    def before_val(self, runner) -> None:
        """Remember the stage we came from and switch to validation."""
        self.last_loop_stage = runner.message_hub.get_info('loop_stage')
        runner.message_hub.update_info('loop_stage', 'val')

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """All subclasses should override this method if they need any
        operations after each validation epoch.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on the validation dataset. The keys are the names of
                the metrics, and the values are the corresponding results.
        """
        if metrics is not None:
            for key, value in metrics.items():
                if _is_scalar(value):
                    runner.message_hub.update_scalar(f'val/{key}', value)
                else:
                    runner.message_hub.update_info(f'val/{key}', value)
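
    # Illustrative example (an assumption about metric shapes): a scalar
    # metric such as ``{'accuracy': 0.91}`` goes through ``update_scalar``
    # and shows up as ``val/accuracy``, while a non-scalar result such as a
    # per-class confusion matrix is stored as-is via ``update_info``.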

    def after_val(self, runner) -> None:
        """Restore or clear the ``loop_stage`` after validation."""
        # ValLoop may be called within the TrainLoop, so the loop_stage
        # needs to be reset.
        # Workflow: before_train -> before_val -> after_val -> after_train
        if self.last_loop_stage == 'train':
            runner.message_hub.update_info('loop_stage', self.last_loop_stage)
            self.last_loop_stage = None
        else:
            runner.message_hub.pop_info('loop_stage')

    def before_test(self, runner) -> None:
        """Mark the current loop stage as ``test``."""
        runner.message_hub.update_info('loop_stage', 'test')

    def after_test(self, runner) -> None:
        """Pop the ``loop_stage`` recorded in ``before_test``."""
        runner.message_hub.pop_info('loop_stage')

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """All subclasses should override this method if they need any
        operations after each test epoch.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on the test dataset. The keys are the names of the
                metrics, and the values are the corresponding results.
        """
        if metrics is not None:
            for key, value in metrics.items():
                if _is_scalar(value):
                    runner.message_hub.update_scalar(f'test/{key}', value)
                else:
                    runner.message_hub.update_info(f'test/{key}', value)