# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn

from .optimizer_wrapper import OptimWrapper


class OptimWrapperDict(OptimWrapper):
    """A dictionary container of :obj:`OptimWrapper`.

    If the runner is training with multiple optimizers, all optimizer
    wrappers should be managed by :obj:`OptimWrapperDict`, which is built by
    ``CustomOptimWrapperConstructor``. ``OptimWrapperDict`` will load and
    save the state dictionaries of all optimizer wrappers.

    Considering the semantic ambiguity of calling :meth:`update_params` or
    :meth:`backward` of all optimizer wrappers, ``OptimWrapperDict`` does
    not implement these methods.

    Examples:
        >>> import torch.nn as nn
        >>> from torch.optim import SGD
        >>> from mmengine.optim import OptimWrapperDict, OptimWrapper
        >>> model1 = nn.Linear(1, 1)
        >>> model2 = nn.Linear(1, 1)
        >>> optim_wrapper1 = OptimWrapper(SGD(model1.parameters(), lr=0.1))
        >>> optim_wrapper2 = OptimWrapper(SGD(model2.parameters(), lr=0.1))
        >>> optim_wrapper_dict = OptimWrapperDict(model1=optim_wrapper1,
        >>>                                       model2=optim_wrapper2)

    Note:
        The optimizer wrappers contained in ``OptimWrapperDict`` can be
        accessed in the same way as a `dict`.
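
        A minimal sketch of the dict-style access, reusing the wrappers
        constructed in the example above:

        >>> optim_wrapper_dict['model1'] is optim_wrapper1
        True
        >>> 'model2' in optim_wrapper_dict
        True
        >>> len(optim_wrapper_dict)
        2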

    Args:
        **optim_wrapper_dict: A dictionary of ``OptimWrapper`` instances.
    """

    def __init__(self, **optim_wrapper_dict: OptimWrapper):
        for key, value in optim_wrapper_dict.items():
            assert isinstance(value, OptimWrapper), (
                '`OptimWrapperDict` only accepts OptimWrapper instances, '
                f'but got {key}: {type(value)}')
        self.optim_wrappers = optim_wrapper_dict

    def update_params(  # type: ignore
            self,
            loss: torch.Tensor,
            step_kwargs: Optional[Dict] = None,
            zero_kwargs: Optional[Dict] = None) -> None:
        """Updating all optimizer wrappers at once would lead to duplicate
        backward errors, and ``OptimWrapperDict`` does not know which
        optimizer wrapper should be updated.

        Therefore, this method is not implemented. Each optimizer wrapper in
        ``OptimWrapperDict`` should be accessed directly and have its
        ``update_params`` called.
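
        A minimal sketch of the intended per-wrapper usage (``loss1`` and
        ``loss2`` are hypothetical per-model losses computed elsewhere):

        >>> optim_wrapper_dict['model1'].update_params(loss1)
        >>> optim_wrapper_dict['model2'].update_params(loss2)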
| """ | |
| raise NotImplementedError('`update_params` should be called by each ' | |
| 'optimizer separately`') | |

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Since ``OptimWrapperDict`` doesn't know which optimizer wrapper's
        backward method should be called (``loss_scaler`` may be different
        in different :obj:`AmpOptimWrapper`), this method is not
        implemented.

        Each optimizer wrapper in ``OptimWrapperDict`` should be accessed
        directly and have its ``backward`` called.
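
        A minimal sketch of driving one wrapper manually (``loss1`` is a
        hypothetical loss computed elsewhere):

        >>> optim_wrapper_dict['model1'].backward(loss1)
        >>> optim_wrapper_dict['model1'].step()
        >>> optim_wrapper_dict['model1'].zero_grad()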
| """ | |
| raise NotImplementedError('`backward` should be called by each ' | |
| 'optimizer separately`') | |

    def step(self, **kwargs) -> None:
        """Since the backward method is not implemented, the step method
        should not be implemented either."""
        raise NotImplementedError('`step` should be called by each '
                                  'optimizer separately')

    def zero_grad(self, **kwargs) -> None:
        """Set the gradients of all optimizer wrappers to zero."""
        for optim_wrapper in self.optim_wrappers.values():
            optim_wrapper.zero_grad()

    def optim_context(self, model: nn.Module):
        """``optim_context`` should be called by each optimizer
        separately."""
        raise NotImplementedError(
            '`optim_context` should be called by each optimizer separately')

    def initialize_count_status(self, model: nn.Module, cur_iter: int,
                                max_iters: int) -> None:
        """Do nothing but provide a unified interface with
        :obj:`OptimWrapper`.

        Since ``OptimWrapperDict`` does not know the correspondence between
        models and optimizer wrappers, ``initialize_count_status`` does
        nothing here, and each optimizer wrapper should call its own
        ``initialize_count_status`` separately.
        """
        return

    @property
    def param_groups(self):
        """Returns the parameter groups of each ``OptimWrapper``."""
        param_groups = dict()
        for key, value in self.optim_wrappers.items():
            param_groups[key] = value.param_groups
        return param_groups

    def get_lr(self) -> Dict[str, List[float]]:
        """Get the learning rate of all optimizers.

        Returns:
            Dict[str, List[float]]: Learning rate of all optimizers.
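
        A minimal sketch of the expected key layout for the ``model1`` and
        ``model2`` wrappers from the class example (assuming no parameter
        scheduler has recorded a ``base_lr`` yet):

        >>> optim_wrapper_dict.get_lr()
        {'model1.lr': [0.1], 'model2.lr': [0.1]}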
| """ | |
| lr_dict = dict() | |
| for name, optim_wrapper in self.optim_wrappers.items(): | |
| inner_lr_dict = optim_wrapper.get_lr() | |
| if 'base_lr' in inner_lr_dict: | |
| lr_dict[f'{name}.base_lr'] = inner_lr_dict['base_lr'] | |
| lr_dict[f'{name}.lr'] = inner_lr_dict['lr'] | |
| return lr_dict | |

    def get_momentum(self) -> Dict[str, List[float]]:
        """Get the momentum of all optimizers.

        Returns:
            Dict[str, List[float]]: Momentum of all optimizers.
        """
        momentum_dict = dict()
        for name, optim_wrapper in self.optim_wrappers.items():
            momentum_dict[f'{name}.momentum'] = optim_wrapper.get_momentum(
            )['momentum']
        return momentum_dict

    def state_dict(self) -> dict:
        """Get the state dictionary of all optimizer wrappers.

        Returns:
            dict: Each key-value pair in the dictionary represents the name
            and state dictionary of the corresponding :obj:`OptimWrapper`.
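
        A minimal checkpointing sketch (the file name is a hypothetical
        placeholder):

        >>> ckpt = optim_wrapper_dict.state_dict()
        >>> sorted(ckpt.keys())
        ['model1', 'model2']
        >>> torch.save(ckpt, 'optim_wrappers.pth')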
| """ | |
| state_dict = dict() | |
| for name, optim_wrapper in self.optim_wrappers.items(): | |
| state_dict[name] = optim_wrapper.state_dict() | |
| return state_dict | |

    def load_state_dict(self, state_dict: dict) -> None:
        """Load the state dictionary from the given ``state_dict``.

        Args:
            state_dict (dict): Each key-value pair in ``state_dict``
                represents the name and the state dictionary of the
                corresponding :obj:`OptimWrapper`.
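
        A minimal resume sketch, mirroring the ``state_dict`` example above
        (the file name is a hypothetical placeholder):

        >>> ckpt = torch.load('optim_wrappers.pth')
        >>> optim_wrapper_dict.load_state_dict(ckpt)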
| """ | |
| for name, _state_dict in state_dict.items(): | |
| assert name in self.optim_wrappers, ( | |
| f'Mismatched `state_dict`! cannot found {name} in ' | |
| 'OptimWrapperDict') | |
| self.optim_wrappers[name].load_state_dict(_state_dict) | |

    def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
        """A generator to get the name and the corresponding
        :obj:`OptimWrapper`."""
        yield from self.optim_wrappers.items()

    def values(self) -> Iterator[OptimWrapper]:
        """A generator to get each :obj:`OptimWrapper`."""
        yield from self.optim_wrappers.values()

    def keys(self) -> Iterator[str]:
        """A generator to get the name of each :obj:`OptimWrapper`."""
        yield from self.optim_wrappers.keys()

    def __getitem__(self, key: str) -> OptimWrapper:
        assert key in self.optim_wrappers, (
            f'Cannot find {key} in OptimWrapperDict, please check '
            'your optimizer constructor.')
        return self.optim_wrappers[key]

    def __contains__(self, key: str) -> bool:
        return key in self.optim_wrappers

    def __len__(self) -> int:
        return len(self.optim_wrappers)

    def __repr__(self) -> str:
        desc = ''
        for name, optim_wrapper in self.optim_wrappers.items():
            desc += f'name: {name}\n'
            desc += repr(optim_wrapper)
        return desc