import copy
import inspect
import math
import warnings
from typing import Any, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor

from ..configs.config.config import Config, ConfigDict
from ..utils.manager import ManagerMixin
from ..utils.registry import Registry

WEIGHT_INITIALIZERS = Registry('weight initializer')

@WEIGHT_INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit:
    """Initialize module by loading a pretrained model.

    Args:
        checkpoint (str): the checkpoint file of the pretrained model to be
            loaded.
        prefix (str, optional): the prefix of a sub-module in the pretrained
            model. It is for loading a part of the pretrained model to
            initialize. For example, if we would like to only load the
            backbone of a detector model, we can set ``prefix='backbone.'``.
            Defaults to None.
        map_location (str): map tensors into proper locations. Defaults to
            'cpu'.
    """

    def __init__(self, checkpoint, prefix=None, map_location='cpu'):
        self.checkpoint = checkpoint
        self.prefix = prefix
        self.map_location = map_location

    def __call__(self, module):
        from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
                                                load_checkpoint,
                                                load_state_dict)
        if self.prefix is None:
            load_checkpoint(
                module,
                self.checkpoint,
                map_location=self.map_location,
                strict=False,
                logger='current')
        else:
            # Load only the sub-dict of the checkpoint whose keys start with
            # `self.prefix`, then load it into the module non-strictly.
            state_dict = _load_checkpoint_with_prefix(
                self.prefix, self.checkpoint, map_location=self.map_location)
            load_state_dict(module, state_dict, strict=False, logger='current')

        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: load from {self.checkpoint}'
        return info
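# A minimal usage sketch for ``PretrainedInit``. The torchvision:// URL below
# is illustrative (it also appears in the ``initialize`` docstring); any
# checkpoint accepted by ``load_checkpoint`` works:
#
#   >>> import torchvision
#   >>> model = torchvision.models.resnet50()
#   >>> init = PretrainedInit(checkpoint='torchvision://resnet50')
#   >>> init(model)  # loads the full checkpoint with strict=False
#   >>> # Load only the backbone part of a detector checkpoint:
#   >>> init = PretrainedInit(checkpoint='checkpoint.pth', prefix='backbone.')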
def update_init_info(module, init_info):
    """Update the ``_params_init_info`` in the module if the values of the
    parameters have changed.

    Args:
        module (obj:`nn.Module`): The PyTorch module with a user-defined
            attribute ``_params_init_info`` which records the initialization
            information.
        init_info (str): The string that describes the initialization.
    """
    assert hasattr(
        module,
        '_params_init_info'), f'Cannot find `_params_init_info` in {module}'
    for name, param in module.named_parameters():

        assert param in module._params_init_info, (
            f'Found a new :obj:`Parameter` named `{name}` while executing '
            f'the `init_weights` of `{module.__class__.__name__}`. Please do '
            'not add or replace parameters while executing `init_weights`.')

        # A parameter is regarded as updated by the current initializer if
        # the mean of its values has changed since it was last recorded.
        mean_value = param.data.mean().cpu()
        if module._params_init_info[param]['tmp_mean_value'] != mean_value:
            module._params_init_info[param]['init_info'] = init_info
            module._params_init_info[param]['tmp_mean_value'] = mean_value
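# ``_params_init_info`` is assumed to map each parameter to a dict with the
# keys ``'init_info'`` and ``'tmp_mean_value'`` (that is how it is indexed
# above). A minimal sketch of how a module could set it up before calling
# ``update_init_info``:
#
#   >>> from collections import defaultdict
#   >>> module._params_init_info = defaultdict(dict)
#   >>> for _, param in module.named_parameters():
#   ...     module._params_init_info[param]['init_info'] = 'uninitialized'
#   ...     module._params_init_info[param]['tmp_mean_value'] = \
#   ...         param.data.mean().cpu()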
def initialize(module, init_cfg):
    r"""Initialize a module.

    Args:
        module (``torch.nn.Module``): The module to be initialized.
        init_cfg (dict | list[dict]): Initialization configuration dict to
            define initializer. OpenMMLab has implemented 6 initializers
            including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
            ``Kaiming``, and ``Pretrained``.

    Example:
        >>> module = nn.Linear(2, 3, bias=True)
        >>> init_cfg = dict(type='Constant', layer='Linear', val=1, bias=2)
        >>> initialize(module, init_cfg)

        >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1, 2))
        >>> # define key ``'layer'`` for initializing layers with different
        >>> # configurations
        >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
                dict(type='Constant', layer='Linear', val=2)]
        >>> initialize(module, init_cfg)

        >>> # define key ``'override'`` to initialize some specific part in
        >>> # module
        >>> class FooNet(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.feat = nn.Conv2d(3, 16, 3)
        >>>         self.reg = nn.Conv2d(16, 10, 3)
        >>>         self.cls = nn.Conv2d(16, 5, 3)
        >>> model = FooNet()
        >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
        >>>     override=dict(type='Constant', name='reg', val=3, bias=4))
        >>> initialize(model, init_cfg)

        >>> model = ResNet(depth=50)
        >>> # Initialize weights with the pretrained model.
        >>> init_cfg = dict(type='Pretrained',
                checkpoint='torchvision://resnet50')
        >>> initialize(model, init_cfg)

        >>> # Initialize weights of a sub-module with the specific part of
        >>> # a pretrained model by using "prefix".
        >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
        >>>     'retinanet_r50_fpn_1x_coco/'\
        >>>     'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
        >>> init_cfg = dict(type='Pretrained',
                checkpoint=url, prefix='backbone.')
    """
    if not isinstance(init_cfg, (dict, list)):
        raise TypeError(f'init_cfg must be a dict or a list of dict, '
                        f'but got {type(init_cfg)}')

    if isinstance(init_cfg, dict):
        init_cfg = [init_cfg]

    for cfg in init_cfg:
        # Deep-copy the config because cfg may be shared by other modules
        # (e.g., one init_cfg shared by multiple bottleneck blocks); popping
        # keys from a shared dict would change the initialization behavior
        # of the other modules.
        cp_cfg = copy.deepcopy(cfg)
        override = cp_cfg.pop('override', None)
        _initialize(module, cp_cfg)

        if override is not None:
            cp_cfg.pop('layer', None)
            _initialize_override(module, override, cp_cfg)
        else:
            # All attributes in module share the same initialization.
            pass
def _initialize(module, cfg, wholemodule=False):
    func = build_from_cfg(cfg, WEIGHT_INITIALIZERS)
    # The `wholemodule` flag is for the override mode: an override config has
    # no `layer` key, and the initializer applies to the whole sub-module
    # selected by the `name` key.
    func.wholemodule = wholemodule
    func(module)
def _initialize_override(module, override, cfg):
    if not isinstance(override, (dict, list)):
        raise TypeError(f'override must be a dict or a list of dict, '
                        f'but got {type(override)}')

    override = [override] if isinstance(override, dict) else override

    for override_ in override:

        cp_override = copy.deepcopy(override_)
        name = cp_override.pop('name', None)
        if name is None:
            raise ValueError('`override` must contain the key "name", '
                             f'but got {cp_override}')

        # If the override only has the `name` key, it inherits the
        # initialization arguments from the outer `cfg`.
        if not cp_override:
            cp_override.update(cfg)
        # If the override carries arguments besides `name`, it must specify
        # its own `type`.
        elif 'type' not in cp_override.keys():
            raise ValueError(
                f'`override` needs the key "type", but got {cp_override}')

        if hasattr(module, name):
            _initialize(getattr(module, name), cp_override, wholemodule=True)
        else:
            raise RuntimeError(f'module has no attribute {name}, '
                               f'but init_cfg is {cp_override}.')
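# Override resolution sketch, following the code above: given
#
#   cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
#              override=dict(name='reg'))
#
# the bare ``override`` inherits the outer config (``initialize`` strips the
# ``layer`` key first), so ``module.reg`` is initialized as
# Constant(val=1, bias=2) applied to the whole sub-module
# (``wholemodule=True``).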
def build_from_cfg(
        cfg: Union[dict, ConfigDict, Config],
        registry: Registry,
        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
    """Build a module from a config dict when it is a class configuration, or
    call a function from a config dict when it is a function configuration.

    If the global default scope (:obj:`DefaultScope`) exists, :meth:`build`
    will first get the corresponding registry and then call its own
    :meth:`build`.

    At least one of ``cfg`` and ``default_args`` must contain the key "type",
    which should be either a str or a class. If both contain it, the key in
    ``cfg`` is used because ``cfg`` has a higher priority than
    ``default_args``; that is, if a key exists in both of them, the value of
    the key will be ``cfg[key]``. They are merged first, the key "type" is
    popped, and the remaining keys are used as initialization arguments.

    Examples:
        >>> from mmengine import Registry, build_from_cfg
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     def __init__(self, depth, stages=4):
        >>>         self.depth = depth
        >>>         self.stages = stages
        >>> cfg = dict(type='ResNet', depth=50)
        >>> model = build_from_cfg(cfg, MODELS)
        >>> # Returns an instantiated object
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
        >>> # Returns the result of calling the function

    Args:
        cfg (dict or ConfigDict or Config): Config dict. It should at least
            contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict or ConfigDict or Config, optional): Default
            initialization arguments. Defaults to None.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, (dict, ConfigDict, Config)):
        raise TypeError(
            f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}')

    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')

    if not isinstance(registry, Registry):
        raise TypeError('registry must be a mmengine.Registry object, '
                        f'but got {type(registry)}')

    if not (isinstance(default_args,
                       (dict, ConfigDict, Config)) or default_args is None):
        raise TypeError(
            'default_args should be a dict, ConfigDict, Config or None, '
            f'but got {type(default_args)}')

    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    # The instance should be built under the target scope. If `_scope_` is
    # defined in `cfg`, the current default scope temporarily switches to
    # the specified scope.
    scope = args.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        obj_type = args.pop('type')
        if isinstance(obj_type, str):
            obj_cls = registry.get(obj_type)
            if obj_cls is None:
                raise KeyError(
                    f'{obj_type} is not in the {registry.scope}::'
                    f'{registry.name} registry. Please check whether the '
                    f'value of `{obj_type}` is correct or it was registered '
                    'as expected. More details can be found at '
                    'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'
                )
        elif callable(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        # If `obj_cls` inherits from `ManagerMixin`, it should be built by
        # `get_instance` so that the instance can be accessed globally.
        if inspect.isclass(obj_cls) and issubclass(obj_cls, ManagerMixin):
            obj = obj_cls.get_instance(**args)
        else:
            obj = obj_cls(**args)
        return obj
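# A sketch of how ``default_args`` merging behaves, reusing the ``ResNet``
# registration from the docstring above (keys in ``cfg`` win):
#
#   >>> cfg = dict(type='ResNet', depth=101)
#   >>> model = build_from_cfg(cfg, MODELS, default_args=dict(depth=50))
#   >>> model.depth
#   101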
def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def trunc_normal_init(module: nn.Module,
                      mean: float = 0,
                      std: float = 1,
                      a: float = -2,
                      b: float = 2,
                      bias: float = 0) -> None:
    if hasattr(module, 'weight') and module.weight is not None:
        trunc_normal_(module.weight, mean, std, a, b)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    assert distribution in ['uniform', 'normal']
    if hasattr(module, 'weight') and module.weight is not None:
        if distribution == 'uniform':
            nn.init.kaiming_uniform_(
                module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
        else:
            nn.init.kaiming_normal_(
                module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
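# Usage sketch for the functional initializers above:
#
#   >>> conv = nn.Conv2d(3, 16, 3)
#   >>> kaiming_init(conv, mode='fan_out', nonlinearity='relu')
#   >>> bn = nn.BatchNorm2d(16)
#   >>> constant_init(bn, val=1, bias=0)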
def trunc_normal_(tensor: Tensor,
                  mean: float = 0.,
                  std: float = 1.,
                  a: float = -2.,
                  b: float = 2.) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated normal
    distribution. The values are effectively drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
    outside :math:`[a, b]` redrawn until they are within the bounds. The
    method used for generating the random values works best when
    :math:`a \leq \text{mean} \leq b`.

    Modified from
    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

    Args:
        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
        mean (float): the mean of the normal distribution.
        std (float): the standard deviation of the normal distribution.
        a (float): the minimum cutoff value.
        b (float): the maximum cutoff value.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
                           b: float) -> Tensor:
    # Values are generated by sampling a truncated uniform distribution and
    # then applying the inverse CDF of the normal distribution.

    def norm_cdf(x):
        # Computes the standard normal cumulative distribution function.
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Get the lower and upper CDF values of the truncation bounds.
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill the tensor with values from [lower, upper], then
        # translate to [2 * lower - 1, 2 * upper - 1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Apply the inverse CDF transform of the standard normal
        # distribution to obtain a truncated standard normal.
        tensor.erfinv_()

        # Transform to the target mean and std.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure the values are within the proper range.
        tensor.clamp_(min=a, max=b)
        return tensor
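# Usage sketch: draw weights from N(0, 0.02^2) truncated to [-0.04, 0.04]:
#
#   >>> w = torch.empty(3, 5)
#   >>> w = trunc_normal_(w, mean=0., std=0.02, a=-0.04, b=0.04)
#   >>> bool((w >= -0.04).all() and (w <= 0.04).all())
#   True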