# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import time
from typing import Callable, Dict, List, Optional, Union

import torch.nn as nn

import mmengine
from mmengine.device import get_device
from mmengine.model import revert_sync_batchnorm
from mmengine.optim import BaseOptimWrapper, _ParamScheduler
from mmengine.registry import STRATEGIES
from mmengine.utils import get_git_hash

from .base import BaseStrategy


@STRATEGIES.register_module()
class SingleDeviceStrategy(BaseStrategy):
    """Strategy for single-device training."""
    def prepare(
        self,
        model: Union[nn.Module, dict],
        *,
        optim_wrapper: Union[BaseOptimWrapper, dict, None] = None,
        param_scheduler: Union[_ParamScheduler, Dict, List, None] = None,
        compile: Union[dict, bool] = False,
        dispatch_kwargs: Optional[dict] = None,
    ):
        """Prepare the model and some other components.

        Args:
            model (:obj:`torch.nn.Module` or dict): The model to be run. It
                can be a dict used for building a model.

        Keyword Args:
            optim_wrapper (BaseOptimWrapper or dict, optional): Computes the
                gradients of the model parameters and updates them.
                Defaults to None.
                See :meth:`build_optim_wrapper` for examples.
            param_scheduler (_ParamScheduler or dict or list, optional):
                Parameter scheduler for updating optimizer parameters. If
                specified, :attr:`optim_wrapper` should also be specified.
                Defaults to None.
                See :meth:`build_param_scheduler` for examples.
            compile (dict or bool): Config used to compile the model.
                Defaults to False. Requires PyTorch >= 2.0.
            dispatch_kwargs (dict, optional): Kwargs to be passed to other
                methods of the strategy. Defaults to None.
                If ``accumulative_counts`` is set in ``optim_wrapper``, you
                need to provide ``max_iters`` in ``dispatch_kwargs``.
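
        The following is an illustrative sketch rather than a canonical
        recipe: ``MyModel`` is a placeholder ``nn.Module`` and the optimizer
        and ``dispatch_kwargs`` values are arbitrary.

        Examples:
            >>> strategy = SingleDeviceStrategy()
            >>> prepared = strategy.prepare(
            ...     MyModel(),
            ...     optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
            ...     dispatch_kwargs=dict(max_iters=1000))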
| """ | |
        if self._prepared:
            return self._prepared_components()
        if dispatch_kwargs is not None:
            self.dispatch_kwargs.update(dispatch_kwargs)

        model = self.build_model(model)
        model = self._init_model_weights(model)
        model = self._wrap_model(model)
        model = self.compile_model(model, compile=compile)
        self.model = model

        if optim_wrapper is not None:
            self.optim_wrapper = self.build_optim_wrapper(
                optim_wrapper, model)

        if param_scheduler is not None:
            self.param_schedulers = self.build_param_scheduler(
                param_scheduler, self.optim_wrapper)

        if optim_wrapper is not None:
            self._scale_lr()

            accumulative_counts = getattr(self.optim_wrapper,
                                          '_accumulative_counts', 1)
            if accumulative_counts > 1:
                if 'max_iters' not in self.dispatch_kwargs:
                    raise ValueError(
                        '"max_iters" must be specified because '
                        '"accumulative_counts" was set as '
                        f'{accumulative_counts} which is greater than 1.')

                self.optim_wrapper.initialize_count_status(  # type: ignore
                    self.model, 0, self.dispatch_kwargs['max_iters'])

        self._prepared = True
        return self._prepared_components()

    def _wrap_model(self, model: nn.Module) -> nn.Module:
        model = self.convert_model(model)
        current_device = get_device()
        return model.to(current_device)

    def convert_model(self, model: nn.Module) -> nn.Module:
        """Convert layers of the model.

        Converts all ``SyncBatchNorm`` (SyncBN) and
        ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers in the model to
        ``BatchNormXd`` layers.

        Args:
            model (nn.Module): Model to convert.
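
        A small sketch of the effect, assuming the model contains a
        ``torch.nn.SyncBatchNorm`` layer:

        Examples:
            >>> import torch.nn as nn
            >>> strategy = SingleDeviceStrategy()
            >>> model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
            >>> model = strategy.convert_model(model)
            >>> # the SyncBatchNorm layer is now a regular BatchNormXd layer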
| """ | |
        self.logger.info(
            'Distributed training is not used, all SyncBatchNorm (SyncBN) '
            'layers in the model will be automatically reverted to '
            'BatchNormXd layers if they are used.')
        model = revert_sync_batchnorm(model)
        return model

    def load_checkpoint(
        self,
        filename: str,
        *,
        map_location: Union[str, Callable] = 'cpu',
        strict: bool = False,
        revise_keys: list = [(r'^module.', '')],
        callback: Optional[Callable] = None,
    ) -> dict:
        """Load checkpoint from the given ``filename``.

        Args:
            filename (str): Accepts a local filepath, URL,
                ``torchvision://xxx`` or ``open-mmlab://xxx``.

        Keyword Args:
            map_location (str or callable): A string or a callable function
                specifying how to remap storage locations.
                Defaults to 'cpu'.
            strict (bool): Whether to allow different params for the model
                and checkpoint.
            revise_keys (list): A list of customized keywords to modify the
                state_dict in the checkpoint. Each item is a
                (pattern, replacement) pair of regular expression operations.
                Defaults to stripping the prefix 'module.' with
                [(r'^module\\.', '')].
            callback (callable, optional): Callback function to modify the
                checkpoint after loading it.
                Defaults to None.
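
        The following is an illustrative sketch; the checkpoint path is a
        placeholder and ``drop_author`` is a hypothetical callback.

        Examples:
            >>> strategy = SingleDeviceStrategy()
            >>> # assume ``prepare`` has already been called
            >>> def drop_author(checkpoint):
            ...     checkpoint.get('meta', {}).pop('author', None)
            >>> checkpoint = strategy.load_checkpoint(
            ...     'work_dirs/epoch_1.pth',
            ...     map_location='cpu',
            ...     callback=drop_author)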
| """ | |
        from mmengine.runner.checkpoint import _load_checkpoint

        self.logger.info(f'Load checkpoint from {filename}')

        if map_location == 'default':
            device = get_device()
            checkpoint = _load_checkpoint(filename, map_location=device)
        else:
            checkpoint = _load_checkpoint(filename, map_location=map_location)

        # users can do some modification after loading checkpoint
        if callback is not None:
            callback(checkpoint)

        state_dict = checkpoint.pop('state_dict')
        self.load_model_state_dict(
            state_dict, strict=strict, revise_keys=revise_keys)
        return checkpoint

    def resume(
        self,
        filename: str,
        *,
        resume_optimizer: bool = True,
        resume_param_scheduler: bool = True,
        map_location: Union[str, Callable] = 'default',
        callback: Optional[Callable] = None,
    ) -> dict:
        """Resume training from the given ``filename``.

        Four types of states will be resumed:

        - model state
        - optimizer state
        - scheduler state
        - randomness state

        Args:
            filename (str): Accepts a local filepath, URL,
                ``torchvision://xxx`` or ``open-mmlab://xxx``.

        Keyword Args:
            resume_optimizer (bool): Whether to resume the optimizer state.
                Defaults to True.
            resume_param_scheduler (bool): Whether to resume the param
                scheduler state. Defaults to True.
            map_location (str or callable): A string or a callable function
                specifying how to remap storage locations.
                Defaults to 'default'.
            callback (callable, optional): Callback function to modify the
                checkpoint after loading it.
                Defaults to None.
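
        An illustrative sketch; the checkpoint path is a placeholder.

        Examples:
            >>> strategy = SingleDeviceStrategy()
            >>> # assume ``prepare`` has already been called
            >>> checkpoint = strategy.resume(
            ...     'work_dirs/epoch_1.pth',
            ...     resume_optimizer=True,
            ...     resume_param_scheduler=True)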
| """ | |
        self.logger.info(f'Resume checkpoint from {filename}')

        checkpoint = self.load_checkpoint(
            filename, map_location=map_location, callback=callback)

        if resume_optimizer:
            self.load_optim_state_dict(checkpoint.pop('optimizer'))

        if resume_param_scheduler and hasattr(self, 'param_schedulers'):
            self.load_scheduler_state_dict(checkpoint.pop('param_schedulers'))

        # resume random seed
        resumed_seed = checkpoint['meta'].get('seed', None)
        current_seed = self._randomness.get('seed')
        if resumed_seed is not None and resumed_seed != current_seed:
            if current_seed is not None:
                self.logger.warning(f'The value of random seed in the '
                                    f'checkpoint "{resumed_seed}" is '
                                    f'different from the value in '
                                    f'`randomness` config "{current_seed}"')
            self._randomness.update(seed=resumed_seed)
            self._set_randomness(**self._randomness)

        # resume iter
        cur_iter = checkpoint['meta']['iter']

        if hasattr(self, 'optim_wrapper'):
            accumulative_counts = getattr(self.optim_wrapper,
                                          '_accumulative_counts', 1)
            if accumulative_counts > 1:
                if 'max_iters' not in self.dispatch_kwargs:
                    raise ValueError(
                        '"max_iters" must be specified because '
                        '"accumulative_counts" was set as '
                        f'{accumulative_counts} which is greater than 1.')
                # Initiate inner count of `optim_wrapper`.
                self.optim_wrapper.initialize_count_status(  # type: ignore
                    self.model, cur_iter, self.dispatch_kwargs['max_iters'])

        return checkpoint

    def save_checkpoint(
        self,
        filename: str,
        *,
        save_optimizer: bool = True,
        save_param_scheduler: bool = True,
        extra_ckpt: Optional[dict] = None,
        callback: Optional[Callable] = None,
    ) -> None:
        """Save checkpoint to the given ``filename``.

        Args:
            filename (str): Filename to save the checkpoint to.

        Keyword Args:
            save_optimizer (bool): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            save_param_scheduler (bool): Whether to save the param scheduler
                to the checkpoint. Defaults to True.
            extra_ckpt (dict, optional): Extra checkpoint data to save.
                Defaults to None.
            callback (callable, optional): Callback function to modify the
                checkpoint before saving it.
                Defaults to None.
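
        An illustrative sketch; the filename and extra metadata are
        placeholders and ``strip_aux`` is a hypothetical callback.

        Examples:
            >>> strategy = SingleDeviceStrategy()
            >>> # assume the model has been prepared and trained for a while
            >>> def strip_aux(checkpoint):
            ...     checkpoint['state_dict'].pop('aux_head.weight', None)
            >>> strategy.save_checkpoint(
            ...     'work_dirs/epoch_1.pth',
            ...     extra_ckpt=dict(meta=dict(epoch=1, iter=1000)),
            ...     callback=strip_aux)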
| """ | |
        from mmengine.runner.checkpoint import save_checkpoint

        state_dict: dict = dict()
        state_dict['state_dict'] = self.model_state_dict()

        # save optimizer state dict
        if save_optimizer and hasattr(self, 'optim_wrapper'):
            state_dict['optimizer'] = self.optim_state_dict()

        if save_param_scheduler and hasattr(self, 'param_schedulers'):
            state_dict['param_schedulers'] = self.scheduler_state_dict()

        # save extra checkpoint passed by users
        if extra_ckpt is None:
            extra_ckpt = dict()
        if 'meta' not in extra_ckpt:
            extra_ckpt['meta'] = dict()

        extra_ckpt['meta'].update(
            seed=self.seed,
            time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
            mmengine=mmengine.__version__ + get_git_hash(),
        )
        state_dict.update(extra_ckpt)

        # users can do some modification before saving checkpoint
        if callback is not None:
            callback(state_dict)

        save_checkpoint(state_dict, filename)