Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import Dict, Optional, Sequence, Union | |
| from mmengine import is_method_overridden | |
# Type alias for a batch handed to the per-iteration hook methods: a dict,
# a tuple, a list, or ``None`` when no batch is supplied.
DATA_BATCH = Optional[Union[dict, tuple, list]]


class Hook:
    """Base hook class.

    All hooks should inherit from this class.

    The run-level stage methods (``before_run``, ``after_train``, ...) and
    the checkpoint stage methods are no-ops here. The per-epoch and
    per-iteration stage methods forward to the shared private methods
    ``_before_epoch``, ``_after_epoch``, ``_before_iter`` and
    ``_after_iter`` with ``mode`` set to ``'train'``, ``'val'`` or
    ``'test'``, so a subclass can override either one specific stage or
    one shared method that covers all three modes.
    """

    # Default priority label of the hook.
    priority = 'NORMAL'

    # Names of every stage method. ``get_triggered_stages`` checks each of
    # them for being overridden and reports stages in this order.
    stages = ('before_run', 'after_load_checkpoint', 'before_train',
              'before_train_epoch', 'before_train_iter', 'after_train_iter',
              'after_train_epoch', 'before_val', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_val', 'before_save_checkpoint', 'after_train',
              'before_test', 'before_test_epoch', 'before_test_iter',
              'after_test_iter', 'after_test_epoch', 'after_test', 'after_run')

    def before_run(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before the training, validation or testing process.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
        """

    def after_run(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations after the training, validation or testing process.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
        """

    def before_train(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before train.

        Args:
            runner (Runner): The runner of the training process.
        """

    def after_train(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations after train.

        Args:
            runner (Runner): The runner of the training process.
        """

    def before_val(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before validation.

        Args:
            runner (Runner): The runner of the validation process.
        """

    def after_val(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations after validation.

        Args:
            runner (Runner): The runner of the validation process.
        """

    def before_test(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before testing.

        Args:
            runner (Runner): The runner of the testing process.
        """

    def after_test(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations after testing.

        Args:
            runner (Runner): The runner of the testing process.
        """

    def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
        """All subclasses should override this method, if they need any
        operations before saving the checkpoint.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
            checkpoint (dict): Model's checkpoint.
        """

    def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
        """All subclasses should override this method, if they need any
        operations after loading the checkpoint.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
            checkpoint (dict): Model's checkpoint.
        """

    def before_train_epoch(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before each training epoch.

        The base implementation forwards to ``_before_epoch`` with
        ``mode='train'``.

        Args:
            runner (Runner): The runner of the training process.
        """
        self._before_epoch(runner, mode='train')

    def before_val_epoch(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before each validation epoch.

        The base implementation forwards to ``_before_epoch`` with
        ``mode='val'``.

        Args:
            runner (Runner): The runner of the validation process.
        """
        self._before_epoch(runner, mode='val')

    def before_test_epoch(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before each test epoch.

        The base implementation forwards to ``_before_epoch`` with
        ``mode='test'``.

        Args:
            runner (Runner): The runner of the testing process.
        """
        self._before_epoch(runner, mode='test')

    def after_train_epoch(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations after each training epoch.

        The base implementation forwards to ``_after_epoch`` with
        ``mode='train'``.

        Args:
            runner (Runner): The runner of the training process.
        """
        self._after_epoch(runner, mode='train')

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each validation epoch.

        The base implementation forwards to ``_after_epoch`` with
        ``mode='val'``; ``metrics`` is not passed on.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        self._after_epoch(runner, mode='val')

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each test epoch.

        The base implementation forwards to ``_after_epoch`` with
        ``mode='test'``; ``metrics`` is not passed on.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        self._after_epoch(runner, mode='test')

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """All subclasses should override this method, if they need any
        operations before each training iteration.

        The base implementation forwards to ``_before_iter`` with
        ``mode='train'``.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train
                loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader. Defaults to None.
        """
        self._before_iter(
            runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')

    def before_val_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: DATA_BATCH = None) -> None:
        """All subclasses should override this method, if they need any
        operations before each validation iteration.

        The base implementation forwards to ``_before_iter`` with
        ``mode='val'``.

        Args:
            runner (Runner): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader. Defaults to None.
        """
        self._before_iter(
            runner, batch_idx=batch_idx, data_batch=data_batch, mode='val')

    def before_test_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None) -> None:
        """All subclasses should override this method, if they need any
        operations before each test iteration.

        The base implementation forwards to ``_before_iter`` with
        ``mode='test'``.

        Args:
            runner (Runner): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test
                loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader. Defaults to None.
        """
        self._before_iter(
            runner, batch_idx=batch_idx, data_batch=data_batch, mode='test')

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each training iteration.

        The base implementation forwards to ``_after_iter`` with
        ``mode='train'``.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train
                loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader. Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        self._after_iter(
            runner,
            batch_idx=batch_idx,
            data_batch=data_batch,
            outputs=outputs,
            mode='train')

    def after_val_iter(self,
                       runner,
                       batch_idx: int,
                       data_batch: DATA_BATCH = None,
                       outputs: Optional[Sequence] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each validation iteration.

        The base implementation forwards to ``_after_iter`` with
        ``mode='val'``.

        Args:
            runner (Runner): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader. Defaults to None.
            outputs (Sequence, optional): Outputs from model. Defaults to
                None.
        """
        self._after_iter(
            runner,
            batch_idx=batch_idx,
            data_batch=data_batch,
            outputs=outputs,
            mode='val')

    def after_test_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: DATA_BATCH = None,
                        outputs: Optional[Sequence] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each test iteration.

        The base implementation forwards to ``_after_iter`` with
        ``mode='test'``.

        Args:
            runner (Runner): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test
                loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader. Defaults to None.
            outputs (Sequence, optional): Outputs from model. Defaults to
                None.
        """
        self._after_iter(
            runner,
            batch_idx=batch_idx,
            data_batch=data_batch,
            outputs=outputs,
            mode='test')

    def _before_epoch(self, runner, mode: str = 'train') -> None:
        """All subclasses should override this method, if they need any
        operations before each epoch.

        Called by ``before_train_epoch``, ``before_val_epoch`` and
        ``before_test_epoch``.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
            mode (str): Current mode of runner. Defaults to 'train'.
        """

    def _after_epoch(self, runner, mode: str = 'train') -> None:
        """All subclasses should override this method, if they need any
        operations after each epoch.

        Called by ``after_train_epoch``, ``after_val_epoch`` and
        ``after_test_epoch``.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
            mode (str): Current mode of runner. Defaults to 'train'.
        """

    def _before_iter(self,
                     runner,
                     batch_idx: int,
                     data_batch: DATA_BATCH = None,
                     mode: str = 'train') -> None:
        """All subclasses should override this method, if they need any
        operations before each iter.

        Called by ``before_train_iter``, ``before_val_iter`` and
        ``before_test_iter``.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader. Defaults to None.
            mode (str): Current mode of runner. Defaults to 'train'.
        """

    def _after_iter(self,
                    runner,
                    batch_idx: int,
                    data_batch: DATA_BATCH = None,
                    outputs: Optional[Union[Sequence, dict]] = None,
                    mode: str = 'train') -> None:
        """All subclasses should override this method, if they need any
        operations after each iter.

        Called by ``after_train_iter``, ``after_val_iter`` and
        ``after_test_iter``.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader. Defaults to None.
            outputs (dict or Sequence, optional): Outputs from model.
                Defaults to None.
            mode (str): Current mode of runner. Defaults to 'train'.
        """

    def every_n_epochs(self, runner, n: int, start: int = 0) -> bool:
        """Test whether current epoch can be evenly divided by n.

        The value tested is ``runner.epoch + 1 - start``.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
            n (int): The interval, in epochs. A non-positive ``n`` always
                yields False.
            start (int): Starting from `start` to check the logic for
                every n epochs. Defaults to 0.

        Returns:
            bool: True when ``runner.epoch + 1 - start`` is non-negative
            and evenly divisible by ``n`` (zero counts), otherwise False.
        """
        dividend = runner.epoch + 1 - start
        # Guard first so a zero ``n`` never reaches the modulo below.
        if n <= 0 or dividend < 0:
            return False
        return dividend % n == 0

    def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
        """Test whether current inner iteration can be evenly divided by n.

        Args:
            batch_idx (int): Current batch index of the training, validation
                or testing loop.
            n (int): The interval, in inner iterations. A non-positive ``n``
                always yields False.

        Returns:
            bool: True when ``batch_idx + 1`` is evenly divisible by ``n``,
            otherwise False.
        """
        # Guard first so a zero ``n`` never reaches the modulo below.
        if n <= 0:
            return False
        return (batch_idx + 1) % n == 0

    def every_n_train_iters(self, runner, n: int, start: int = 0) -> bool:
        """Test whether current training iteration can be evenly divided by
        n.

        The value tested is ``runner.iter + 1 - start``.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
            n (int): The interval, in iterations. A non-positive ``n``
                always yields False.
            start (int): Starting from `start` to check the logic for
                every n iterations. Defaults to 0.

        Returns:
            bool: True when ``runner.iter + 1 - start`` is non-negative and
            evenly divisible by ``n`` (zero counts), otherwise False.
        """
        dividend = runner.iter + 1 - start
        # Guard first so a zero ``n`` never reaches the modulo below.
        if n <= 0 or dividend < 0:
            return False
        return dividend % n == 0

    def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
        """Check whether the current iteration reaches the last iteration of
        the dataloader.

        Args:
            dataloader (Dataloader): The dataloader of the training,
                validation or testing process. Only ``len(dataloader)`` is
                used.
            batch_idx (int): The index of the current batch in the loop.

        Returns:
            bool: Whether ``batch_idx + 1`` equals ``len(dataloader)``.
        """
        return batch_idx + 1 == len(dataloader)

    def is_last_train_epoch(self, runner) -> bool:
        """Test whether current epoch is the last train epoch.

        Args:
            runner (Runner): The runner of the training process.

        Returns:
            bool: Whether ``runner.epoch + 1`` equals ``runner.max_epochs``.
        """
        return runner.epoch + 1 == runner.max_epochs

    def is_last_train_iter(self, runner) -> bool:
        """Test whether current iteration is the last train iteration.

        Args:
            runner (Runner): The runner of the training process.

        Returns:
            bool: Whether ``runner.iter + 1`` equals ``runner.max_iters``.
        """
        return runner.iter + 1 == runner.max_iters

    def get_triggered_stages(self) -> list:
        """Get all triggered stages with method name of the hook.

        A stage counts as triggered when this hook overrides the stage
        method itself, or overrides one of the shared private methods
        (``_before_epoch``, ``_after_epoch``, ``_before_iter``,
        ``_after_iter``) that the base stage methods forward to.

        Returns:
            list: Triggered stage names without duplicates, in the order
            they appear in ``Hook.stages``.
        """
        triggered = {
            stage for stage in Hook.stages
            if is_method_overridden(stage, Hook, self)
        }

        # some methods will be triggered in multi stages
        # use this dict to map method to stages.
        method_stages_map = {
            '_before_epoch':
            ['before_train_epoch', 'before_val_epoch', 'before_test_epoch'],
            '_after_epoch':
            ['after_train_epoch', 'after_val_epoch', 'after_test_epoch'],
            '_before_iter':
            ['before_train_iter', 'before_val_iter', 'before_test_iter'],
            '_after_iter':
            ['after_train_iter', 'after_val_iter', 'after_test_iter'],
        }
        for method, map_stages in method_stages_map.items():
            if is_method_overridden(method, Hook, self):
                triggered.update(map_stages)

        # Every collected name comes from ``Hook.stages``, so filtering that
        # tuple yields the same elements in a stable order instead of the
        # process-dependent iteration order of a set of strings.
        return [stage for stage in Hook.stages if stage in triggered]