import math
from typing import Dict, List, Union

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = ["init_modules", "load_state_dict"]


def init_modules(
    module: Union[nn.Module, List[nn.Module]], init_type: str = "he_fout"
) -> None:
    """Initialize the parameters of a module (or list of modules) in place.

    ``init_type`` may carry a numeric argument after an ``@`` separator,
    e.g. ``"kaiming_uniform@3"`` uses ``a = sqrt(3)``.
    """
    init_params = init_type.split("@")
    init_params = float(init_params[1]) if len(init_params) > 1 else None
    if isinstance(module, list):
        for sub_module in module:
            init_modules(sub_module, init_type)  # propagate the chosen scheme
    else:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                if init_type == "he_fout":
                    # He init scaled by fan-out: n = k_h * k_w * out_channels.
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2.0 / n))
                else:
                    # Covers "kaiming_uniform@<a^2>" and any other value:
                    # Kaiming uniform with a = sqrt(init_params), falling back
                    # to PyTorch's default a = sqrt(5).
                    nn.init.kaiming_uniform_(m.weight, a=math.sqrt(init_params or 5))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, _BatchNorm):
                # Norm layers start as the identity transform.
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    m.bias.data.zero_()
            else:
                # Generic fallback for custom modules that expose plain
                # weight/bias parameters.
                weight = getattr(m, "weight", None)
                bias = getattr(m, "bias", None)
                if isinstance(weight, torch.nn.Parameter):
                    nn.init.kaiming_uniform_(weight, a=math.sqrt(init_params or 5))
                if isinstance(bias, torch.nn.Parameter):
                    bias.data.zero_()
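

# Usage sketch (illustrative only; _example_init_modules and the architecture
# below are hypothetical, not part of this module's API):
def _example_init_modules() -> None:
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True),
        nn.Flatten(),
        nn.Linear(16 * 8 * 8, 10),
    )
    # Default: fan-out He init for conv layers, trunc-normal for linear.
    init_modules(model, init_type="he_fout")
    # "@"-parameterized variant: Kaiming uniform with a = sqrt(3).
    init_modules(model, init_type="kaiming_uniform@3")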


def load_state_dict(
    model: nn.Module, state_dict: Dict[str, torch.Tensor], strict: bool = True
) -> None:
    """Load ``state_dict`` into ``model``, skipping mismatches when ``strict=False``."""
    current_state_dict = model.state_dict()
    for key in state_dict:
        if key not in current_state_dict:
            # Guard against keys the model does not have; without this, the
            # shape check below raises KeyError even when strict=False.
            if strict:
                raise KeyError("unexpected key %s in state_dict" % key)
            print("Skip loading %s: no such key in model" % key)
            continue
        if current_state_dict[key].shape != state_dict[key].shape:
            if strict:
                raise ValueError(
                    "%s shape mismatch (src=%s, target=%s)"
                    % (
                        key,
                        list(state_dict[key].shape),
                        list(current_state_dict[key].shape),
                    )
                )
            else:
                print(
                    "Skip loading %s due to shape mismatch (src=%s, target=%s)"
                    % (
                        key,
                        list(state_dict[key].shape),
                        list(current_state_dict[key].shape),
                    )
                )
        else:
            current_state_dict[key].copy_(state_dict[key])
    model.load_state_dict(current_state_dict)
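

# Usage sketch (illustrative only; _example_load_state_dict and the toy
# networks below are hypothetical): partial checkpoint loading.
def _example_load_state_dict() -> None:
    src = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 10))
    dst = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 5))
    # The matching first layer is copied; the resized second layer is
    # skipped with a printed warning instead of raising, thanks to strict=False.
    load_state_dict(dst, src.state_dict(), strict=False)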