File size: 9,899 Bytes
56ef371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn

from detect_tools.upn import BACKBONES, build_backbone, build_position_embedding
from detect_tools.upn.models.module import NestedTensor
from detect_tools.upn.models.utils import clean_state_dict


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.

    All four tensors (``weight``, ``bias``, ``running_mean``, ``running_var``)
    are registered as buffers, so they are saved in the state dict but are not
    returned by ``parameters()``.

    Args:
        n (int): Number of channels; every buffer has shape ``(n,)``.
        eps (float): Value added to the running variance before the reciprocal
            square root. Default: 1e-5.
    """

    def __init__(self, n, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load buffers, discarding any ``num_batches_tracked`` entry first.

        This module registers no ``num_batches_tracked`` buffer, so the entry
        is removed before delegating to the parent implementation; otherwise
        it would be reported as an unexpected key.
        """
        state_dict.pop(prefix + "num_batches_tracked", None)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x):
        """Apply the frozen affine normalization to ``x``.

        Computes ``x * scale + shift`` with
        ``scale = weight / sqrt(running_var + eps)`` and
        ``shift = bias - running_mean * scale``. The buffers are reshaped to
        ``(1, n, 1, 1)`` so they broadcast against ``x``.
        """
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        shift = b - rm * scale
        return x * scale + shift


class Joiner(nn.Module):
    """A wrapper for the backbone and the position embedding.

    Args:
        backbone (nn.Module): Backbone module. It is called with a NestedTensor
            and must return a mapping whose values are the feature maps.
        position_embedding (nn.Module): Module called on each feature map to
            produce its position encoding.
    """

    def __init__(self, backbone: nn.Module, position_embedding: nn.Module) -> None:
        super().__init__()
        self.backbone = backbone
        self.pos_embed = position_embedding

    def forward(
        self, tensor_list: NestedTensor
    ) -> Tuple[List[NestedTensor], List[torch.Tensor]]:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            Tuple[List[NestedTensor], List[torch.Tensor]]: The feature maps
            returned by the backbone (in the backbone's iteration order) and,
            at the same positions, their position encodings cast to the dtype
            of the corresponding feature map's ``tensors``.
        """

        xs = self.backbone(tensor_list)
        out: List[NestedTensor] = []
        pos: List[torch.Tensor] = []
        # Only the feature maps are used; the mapping keys are ignored.
        for _, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self.pos_embed(x).to(x.tensors.dtype))

        return out, pos

    def forward_pos_embed_only(self, x: NestedTensor) -> torch.Tensor:
        """Forward function for position embedding only.

        Intended for generating the position encoding of additional layers
        without running the backbone. Unlike ``forward``, the result is
        returned as produced by the position embedding, without a dtype cast.

        Args:
            x (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Position encoding for ``x``.
        """
        return self.pos_embed(x)


@BACKBONES.register_module()
class SwinWrapper(nn.Module):
    """A wrapper for swin transformer.

    Builds the backbone, optionally freezes parameters by keyword, optionally
    loads a checkpoint, and wraps the backbone together with the position
    embedding in a ``Joiner``.

    Args:
        backbone_cfg (Union[Dict, str]): Config dict to build backbone. If given a str name, we
            will call `get_swin_config` to get the config dict. In that case `dilation`,
            `use_checkpoint` and `return_interm_indices` are forwarded into the config;
            a dict is passed to `build_backbone` as-is.
        dilation (bool): Whether to use dilation in stage 4. When True, checkpoint keys
            containing "layers.3" are not loaded.
        position_embedding_cfg (Dict): Config dict to build position embedding.
        lr_backbone (float): Learning rate of the backbone. Must be > 0.
        return_interm_indices (List[int]): Which layers to return. One of
            [0, 1, 2, 3], [1, 2, 3] or [3].
        backbone_freeze_keywords (List[str]): Parameters whose name contains any of these
            keywords get `requires_grad=False`. None disables freezing.
        use_checkpoint (bool): Whether to use checkpoint. Default: False.
        backbone_ckpt_path (str): Checkpoint path. If None, no checkpoint is loaded.
            Default: None.

    Raises:
        ValueError: If `lr_backbone` is not greater than 0.
    """

    def __init__(
        self,
        backbone_cfg: Union[Dict, str],
        dilation: bool,
        position_embedding_cfg: Dict,
        lr_backbone: float,
        return_interm_indices: List[int],
        backbone_freeze_keywords: List[str],
        use_checkpoint: bool = False,
        backbone_ckpt_path: Union[str, None] = None,
    ) -> None:
        super().__init__()
        pos_embedding = build_position_embedding(position_embedding_cfg)
        if not lr_backbone > 0:
            raise ValueError("Please set lr_backbone > 0")
        # Normalise so that a tuple such as (1, 2, 3) is accepted like a list.
        return_interm_indices = list(return_interm_indices)
        assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]

        # build backbone
        if isinstance(backbone_cfg, str):
            assert backbone_cfg in [
                "swin_T_224_1k",
                "swin_B_224_22k",
                "swin_B_384_22k",
                "swin_L_224_22k",
                "swin_L_384_22k",
            ]
            # Preset names look like "swin_<size>_<img>_<data>"; the
            # second-to-last field is the pretraining image size.
            pretrain_img_size = int(backbone_cfg.split("_")[-2])
            backbone_cfg = get_swin_config(
                backbone_cfg,
                pretrain_img_size,
                out_indices=tuple(return_interm_indices),
                dilation=dilation,
                use_checkpoint=use_checkpoint,
            )
        backbone = build_backbone(backbone_cfg)

        # freeze some layers
        if backbone_freeze_keywords is not None:
            for name, parameter in backbone.named_parameters():
                if any(keyword in name for keyword in backbone_freeze_keywords):
                    parameter.requires_grad_(False)

        # load checkpoint
        if backbone_ckpt_path is not None:
            self._load_backbone_checkpoint(backbone, backbone_ckpt_path, dilation)

        # Keep the channel counts of the last len(return_interm_indices) entries.
        bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
        assert len(bb_num_channels) == len(
            return_interm_indices
        ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"

        model = Joiner(backbone, pos_embedding)
        model.num_channels = bb_num_channels
        self.num_channels = bb_num_channels
        self.model = model

    @staticmethod
    def _load_backbone_checkpoint(
        backbone: nn.Module, ckpt_path: str, dilation: bool
    ) -> None:
        """Load the "model" entry of a checkpoint file into ``backbone``.

        Keys containing "head" are always skipped; keys containing "layers.3"
        are skipped as well when ``dilation`` is set. Loading is non-strict and
        the result reported by ``load_state_dict`` is printed.
        """
        from collections import OrderedDict

        print(f"Loading backbone checkpoint from {ckpt_path}")
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location="cpu")["model"]

        def key_select_function(keyname):
            if "head" in keyname:
                return False
            if dilation and "layers.3" in keyname:
                return False
            return True

        selected = OrderedDict(
            (k, v)
            for k, v in clean_state_dict(checkpoint).items()
            if key_select_function(k)
        )
        load_result = backbone.load_state_dict(selected, strict=False)
        print(str(load_result))

    def forward(
        self, tensor_list: NestedTensor
    ) -> Tuple[List[NestedTensor], List[torch.Tensor]]:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            Tuple[List[NestedTensor], List[torch.Tensor]]: A list of feature maps in
            NestedTensor format and a list of the matching position encodings.
        """

        return self.model(tensor_list)

    def forward_pos_embed_only(self, tensor_list: NestedTensor) -> torch.Tensor:
        """Forward function to get position embedding only.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Position embedding.
        """
        return self.model.forward_pos_embed_only(tensor_list)


def get_swin_config(
    modelname: str, pretrain_img_size: Union[int, Tuple[int, int]], **kw
) -> Dict:
    """Get swin config dict.

    Args:
        modelname (str): Name of the model. Must be one of "swin_T_224_1k",
            "swin_B_224_22k", "swin_B_384_22k", "swin_L_224_22k", "swin_L_384_22k".
        pretrain_img_size (Union[int, Tuple[int, int]]): Image size of the pretrain model.
            Stored in the config under "pretrain_img_size".
        kw (Dict): Other key word arguments. They are merged into the config and
            override preset entries with the same key.

    Returns:
        Dict: Config dict. A fresh dict is built on every call, so callers may
        mutate the result freely.
    """
    model_para_dict = {
        "swin_T_224_1k": dict(
            type="SwinTransformer",
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=7,
        ),
        "swin_B_224_22k": dict(
            type="SwinTransformer",
            embed_dim=128,
            depths=[2, 2, 18, 2],
            num_heads=[4, 8, 16, 32],
            window_size=7,
        ),
        "swin_B_384_22k": dict(
            type="SwinTransformer",
            embed_dim=128,
            depths=[2, 2, 18, 2],
            num_heads=[4, 8, 16, 32],
            window_size=12,
        ),
        "swin_L_224_22k": dict(
            type="SwinTransformer",
            embed_dim=192,
            depths=[2, 2, 18, 2],
            num_heads=[6, 12, 24, 48],
            window_size=7,
        ),
        "swin_L_384_22k": dict(
            type="SwinTransformer",
            embed_dim=192,
            depths=[2, 2, 18, 2],
            num_heads=[6, 12, 24, 48],
            window_size=12,
        ),
    }
    # Validate against the table itself so the list of names cannot drift.
    assert modelname in model_para_dict, (
        f"Unknown swin model name {modelname!r}; "
        f"expected one of {sorted(model_para_dict)}"
    )
    cfg = model_para_dict[modelname]
    cfg.update(kw)
    cfg["pretrain_img_size"] = pretrain_img_size
    return cfg