# ProArd/utils/__init__.py
# NOTE(review): the lines above this file's imports were Hugging Face web-page
# residue ("smi08's picture", "Upload folder using huggingface_hub",
# revision d008243); converted to this comment so the module parses.
import math
from typing import Dict, List, Union
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
__all__ = ["init_modules", "load_state_dict"]
def init_modules(
    module: Union[nn.Module, List[nn.Module]], init_type="he_fout"
) -> None:
    """Initialize the weights and biases of ``module`` in place.

    Args:
        module: A single ``nn.Module`` or a list of modules. Every submodule
            is visited via ``module.modules()``.
        init_type: Initialization scheme. ``"he_fout"`` applies He (fan-out)
            normal initialization to ``nn.Conv2d`` weights; any other value
            falls back to Kaiming uniform. An optional numeric suffix after
            ``"@"`` (e.g. ``"kaiming_uniform@3"``) is used (via ``sqrt``) as
            the negative-slope argument ``a``; otherwise ``sqrt(5)`` is used.
    """
    parts = init_type.split("@")
    # Optional "@<float>" suffix carries the Kaiming negative-slope parameter.
    init_param = float(parts[1]) if len(parts) > 1 else None

    if isinstance(module, list):
        for sub_module in module:
            # Bug fix: forward init_type — the original recursed with the
            # default, silently resetting the scheme for list inputs.
            init_modules(sub_module, init_type)
        return

    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            if init_type == "he_fout":
                # He (fan-out): std = sqrt(2 / (k_h * k_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            else:
                # The original's "kaiming_uniform*" branch and its fallback
                # were identical, so a single else suffices.
                nn.init.kaiming_uniform_(m.weight, a=math.sqrt(init_param or 5))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, _BatchNorm):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                m.bias.data.zero_()
        else:
            # Generic fallback for modules exposing weight/bias Parameters.
            weight = getattr(m, "weight", None)
            bias = getattr(m, "bias", None)
            if isinstance(weight, torch.nn.Parameter):
                nn.init.kaiming_uniform_(weight, a=math.sqrt(init_param or 5))
            if isinstance(bias, torch.nn.Parameter):
                bias.data.zero_()
def load_state_dict(
    model: nn.Module, state_dict: Dict[str, torch.Tensor], strict=True
) -> None:
    """Copy ``state_dict`` into ``model``, with optional shape tolerance.

    Args:
        model: Target module whose parameters/buffers are overwritten.
        state_dict: Source tensors keyed by parameter/buffer name.
        strict: When ``True``, raise ``ValueError`` on an unknown key or a
            shape mismatch; when ``False``, print a warning and skip the
            offending entry.
    """
    current_state_dict = model.state_dict()
    for key, tensor in state_dict.items():
        # Robustness fix: a key absent from the model used to raise a bare
        # KeyError below (even with strict=False); report it explicitly.
        if key not in current_state_dict:
            if strict:
                raise ValueError("%s is not found in the model" % key)
            print("Skip loading %s: not found in the model" % key)
            continue
        if current_state_dict[key].shape != tensor.shape:
            detail = "(src=%s, target=%s)" % (
                list(tensor.shape),
                list(current_state_dict[key].shape),
            )
            if strict:
                raise ValueError("%s shape mismatch %s" % (key, detail))
            print("Skip loading %s due to shape mismatch %s" % (key, detail))
        else:
            # Copy in place so untouched entries keep the model's values.
            current_state_dict[key].copy_(tensor)
    model.load_state_dict(current_state_dict)