File size: 2,811 Bytes
d008243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import math
from typing import Dict, List, Union

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = ["init_modules", "load_state_dict"]


def init_modules(
    module: Union[nn.Module, List[nn.Module]], init_type="he_fout"
) -> None:
    """Initialize the parameters of ``module`` (or of each module in a list) in place.

    ``init_type`` selects the ``nn.Conv2d`` weight scheme and may carry a
    numeric argument after ``@`` (e.g. ``"kaiming_uniform@5"``):

    * ``"he_fout"`` (exact match) -- normal(0, sqrt(2 / (kh * kw * out_channels))).
    * anything else -- ``nn.init.kaiming_uniform_`` with ``a = sqrt(arg)``,
      where ``arg`` is the ``@`` value, or 5 when none is given.

    Other sub-modules:

    * BatchNorm / GroupNorm / LayerNorm: weight filled with 1, bias zeroed
      (skipped when the layer has no affine parameters).
    * ``nn.Linear``: truncated-normal weights (std=0.02), zero bias.
    * Anything else exposing a ``weight`` Parameter with at least 2 dims gets
      ``kaiming_uniform_``; a ``bias`` Parameter is zeroed.

    Args:
        module: a module, or a list/tuple of modules (nesting allowed).
        init_type: scheme name, optionally suffixed with ``"@<float>"``.

    Raises:
        ValueError: if the ``@`` suffix is not a valid float.
    """
    parts = init_type.split("@")
    init_param = float(parts[1]) if len(parts) > 1 else None
    # Compare against None (not truthiness) so an explicit "@0" yields a=0.
    kaiming_a_sq = 5.0 if init_param is None else init_param

    if isinstance(module, (list, tuple)):
        for sub_module in module:
            # Forward init_type; otherwise list elements would silently use "he_fout".
            init_modules(sub_module, init_type)
        return

    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            if init_type == "he_fout":
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            else:
                nn.init.kaiming_uniform_(m.weight, a=math.sqrt(kaiming_a_sq))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
            # Non-affine norm layers register weight/bias as None.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                m.bias.data.zero_()
        else:
            weight = getattr(m, "weight", None)
            bias = getattr(m, "bias", None)
            # kaiming_uniform_ computes fan-in, which needs a tensor of >= 2 dims.
            if isinstance(weight, nn.Parameter) and weight.dim() >= 2:
                nn.init.kaiming_uniform_(weight, a=math.sqrt(kaiming_a_sq))
            if isinstance(bias, nn.Parameter):
                bias.data.zero_()


def load_state_dict(
    model: nn.Module, state_dict: Dict[str, torch.Tensor], strict=True
) -> None:
    """Copy tensors from ``state_dict`` into ``model``, checking keys and shapes.

    Entries of the model that are absent from ``state_dict`` keep their current
    values. For each entry of ``state_dict`` that the model does not have, or
    whose shape differs from the model's tensor of the same name:

    * ``strict=True``  -- raise (``KeyError`` / ``ValueError`` respectively);
    * ``strict=False`` -- print a message and skip that entry.

    Args:
        model: module to load into (modified in place).
        state_dict: mapping from parameter/buffer names to tensors.
        strict: whether mismatching entries are errors or are skipped.

    Raises:
        KeyError: ``strict`` is true and a key is not in ``model.state_dict()``.
        ValueError: ``strict`` is true and a tensor's shape does not match.
    """
    current_state_dict = model.state_dict()
    # Tensors from state_dict() share storage with the model; copy without
    # recording autograd history in case a source tensor requires grad.
    with torch.no_grad():
        for key, src in state_dict.items():
            if key not in current_state_dict:
                if strict:
                    raise KeyError("%s not found in model state_dict" % key)
                print("Skip loading %s: key not found in model state_dict" % key)
                continue
            dst = current_state_dict[key]
            if dst.shape != src.shape:
                src_shape, dst_shape = list(src.shape), list(dst.shape)
                if strict:
                    raise ValueError(
                        "%s shape mismatch (src=%s, target=%s)"
                        % (key, src_shape, dst_shape)
                    )
                print(
                    "Skip loading %s due to shape mismatch (src=%s, target=%s)"
                    % (key, src_shape, dst_shape)
                )
                continue
            dst.copy_(src)
    model.load_state_dict(current_state_dict)