from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn

from detect_tools.upn import BACKBONES, build_backbone, build_position_embedding
from detect_tools.upn.models.module import NestedTensor
from detect_tools.upn.models.utils import clean_state_dict


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with an added eps before rsqrt,
    without which any model other than torchvision.models.resnet[18,34,50,101]
    produces NaNs.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Plain BatchNorm2d checkpoints carry a num_batches_tracked buffer that this frozen
        # variant does not register, so drop it before delegating to the parent loader.
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
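
# Illustrative sanity check (not part of the original module): with matching buffers,
# FrozenBatchNorm2d should reproduce nn.BatchNorm2d in eval mode, since both compute
# y = (x - running_mean) / sqrt(running_var + eps) * weight + bias with fixed statistics.
def _example_frozen_batchnorm_matches_eval_bn() -> None:
    bn = nn.BatchNorm2d(8).eval()
    frozen = FrozenBatchNorm2d(8)
    # _load_from_state_dict above drops num_batches_tracked; strict=False is extra defensive.
    frozen.load_state_dict(bn.state_dict(), strict=False)
    x = torch.randn(2, 8, 4, 4)
    assert torch.allclose(bn(x), frozen(x), atol=1e-5)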

class Joiner(nn.Module):
    """A wrapper for the backbone and the position embedding.

    Args:
        backbone (nn.Module): Backbone module.
        position_embedding (nn.Module): Position embedding module.
    """

    def __init__(self, backbone: nn.Module, position_embedding: nn.Module) -> None:
        super().__init__()
        self.backbone = backbone
        self.pos_embed = position_embedding

    def forward(
        self, tensor_list: NestedTensor
    ) -> Tuple[List[NestedTensor], List[torch.Tensor]]:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            List[NestedTensor]: A list of feature maps in NestedTensor format.
            List[torch.Tensor]: A list of position encodings.
        """
        xs = self.backbone(tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for layer_idx, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self.pos_embed(x).to(x.tensors.dtype))
        return out, pos

    def forward_pos_embed_only(self, x: NestedTensor) -> torch.Tensor:
        """Forward function for the position embedding only. This is used to generate the
        position encoding for an additional feature layer.

        Args:
            x (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Position encoding.
        """
        return self.pos_embed(x)
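
# Illustrative sketch (assumption, not from the original file): the backbone is expected to
# return a dict of NestedTensors keyed by layer index, and NestedTensor is assumed to follow
# the DETR-style NestedTensor(tensors, mask) constructor defined elsewhere in detect_tools.upn.
def _example_joiner_usage(backbone: nn.Module, position_embedding: nn.Module) -> None:
    joiner = Joiner(backbone, position_embedding)
    images = torch.randn(2, 3, 224, 224)
    mask = torch.zeros(2, 224, 224, dtype=torch.bool)  # no padded pixels in this toy batch
    features, pos_embeds = joiner(NestedTensor(images, mask))
    # One position encoding per returned feature level.
    assert len(features) == len(pos_embeds)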

class SwinWrapper(nn.Module):
    """A wrapper for the Swin Transformer.

    Args:
        backbone_cfg (Union[Dict, str]): Config dict to build the backbone. If given a str
            name, `get_swin_config` is called to get the config dict.
        dilation (bool): Whether to use dilation in stage 4.
        position_embedding_cfg (Dict): Config dict to build the position embedding.
        lr_backbone (float): Learning rate of the backbone. Must be > 0.
        return_interm_indices (List[int]): Indices of the intermediate layers to return.
        backbone_freeze_keywords (List[str]): Keywords of backbone parameters to freeze.
        use_checkpoint (bool): Whether to use gradient checkpointing. Default: False.
        backbone_ckpt_path (str): Path to the backbone checkpoint. Default: None.
    """

    def __init__(
        self,
        backbone_cfg: Union[Dict, str],
        dilation: bool,
        position_embedding_cfg: Dict,
        lr_backbone: float,
        return_interm_indices: List[int],
        backbone_freeze_keywords: List[str],
        use_checkpoint: bool = False,
        backbone_ckpt_path: str = None,
    ) -> None:
        super(SwinWrapper, self).__init__()
        pos_embedding = build_position_embedding(position_embedding_cfg)
        train_backbone = lr_backbone > 0
        if not train_backbone:
            raise ValueError("Please set lr_backbone > 0")
        assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]

        # build backbone
        if isinstance(backbone_cfg, str):
            assert backbone_cfg in [
                "swin_T_224_1k",
                "swin_B_224_22k",
                "swin_B_384_22k",
                "swin_L_224_22k",
                "swin_L_384_22k",
            ]
            pretrain_img_size = int(backbone_cfg.split("_")[-2])
            backbone_cfg = get_swin_config(
                backbone_cfg,
                pretrain_img_size,
                out_indices=tuple(return_interm_indices),
                dilation=dilation,
                use_checkpoint=use_checkpoint,
            )
        backbone = build_backbone(backbone_cfg)

        # freeze parameters whose names match any of the given keywords
        if backbone_freeze_keywords is not None:
            for name, parameter in backbone.named_parameters():
                for keyword in backbone_freeze_keywords:
                    if keyword in name:
                        parameter.requires_grad_(False)
                        break

        # load checkpoint
        if backbone_ckpt_path is not None:
            print("Loading backbone checkpoint from {}".format(backbone_ckpt_path))
            checkpoint = torch.load(backbone_ckpt_path, map_location="cpu")["model"]
            from collections import OrderedDict

            def key_select_function(keyname):
                # skip the classification head and, when dilation is used, the stage-4 weights
                if "head" in keyname:
                    return False
                if dilation and "layers.3" in keyname:
                    return False
                return True

            _tmp_st = OrderedDict(
                {
                    k: v
                    for k, v in clean_state_dict(checkpoint).items()
                    if key_select_function(k)
                }
            )
            _tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False)
            print(str(_tmp_st_output))

        bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
        assert len(bb_num_channels) == len(
            return_interm_indices
        ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"

        model = Joiner(backbone, pos_embedding)
        model.num_channels = bb_num_channels
        self.num_channels = bb_num_channels
        self.model = model

    def forward(
        self, tensor_list: NestedTensor
    ) -> Tuple[List[NestedTensor], List[torch.Tensor]]:
        """Forward function.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            List[NestedTensor]: A list of feature maps in NestedTensor format.
            List[torch.Tensor]: A list of position encodings.
        """
        return self.model(tensor_list)

    def forward_pos_embed_only(self, tensor_list: NestedTensor) -> torch.Tensor:
        """Forward function to get the position embedding only.

        Args:
            tensor_list (NestedTensor): NestedTensor wrapping the input tensor.

        Returns:
            torch.Tensor: Position embedding.
        """
        return self.model.forward_pos_embed_only(tensor_list)
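
# Illustrative sketch (assumed values, not from the original file): building a Swin-T wrapper.
# The position_embedding_cfg keys are an assumption; they depend on build_position_embedding,
# which is defined elsewhere in detect_tools.upn.
def _example_swin_wrapper() -> SwinWrapper:
    return SwinWrapper(
        backbone_cfg="swin_T_224_1k",  # resolved via get_swin_config below
        dilation=False,
        position_embedding_cfg=dict(type="PositionEmbeddingSine", num_pos_feats=128),  # assumed
        lr_backbone=1e-5,  # must be > 0, otherwise __init__ raises ValueError
        return_interm_indices=[1, 2, 3],
        backbone_freeze_keywords=None,
        use_checkpoint=False,
        backbone_ckpt_path=None,  # skip loading pretrained weights in this sketch
    )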

def get_swin_config(modelname: str, pretrain_img_size: int, **kw):
    """Get the Swin Transformer config dict.

    Args:
        modelname (str): Name of the model.
        pretrain_img_size (int): Image size used to pretrain the model.
        kw (Dict): Other keyword arguments.

    Returns:
        Dict: Config dict.
    """
    assert modelname in [
        "swin_T_224_1k",
        "swin_B_224_22k",
        "swin_B_384_22k",
        "swin_L_224_22k",
        "swin_L_384_22k",
    ]
    model_para_dict = {
        "swin_T_224_1k": dict(
            type="SwinTransformer",
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=7,
        ),
        "swin_B_224_22k": dict(
            type="SwinTransformer",
            embed_dim=128,
            depths=[2, 2, 18, 2],
            num_heads=[4, 8, 16, 32],
            window_size=7,
        ),
        "swin_B_384_22k": dict(
            type="SwinTransformer",
            embed_dim=128,
            depths=[2, 2, 18, 2],
            num_heads=[4, 8, 16, 32],
            window_size=12,
        ),
        "swin_L_224_22k": dict(
            type="SwinTransformer",
            embed_dim=192,
            depths=[2, 2, 18, 2],
            num_heads=[6, 12, 24, 48],
            window_size=7,
        ),
        "swin_L_384_22k": dict(
            type="SwinTransformer",
            embed_dim=192,
            depths=[2, 2, 18, 2],
            num_heads=[6, 12, 24, 48],
            window_size=12,
        ),
    }
    kw_cfg = model_para_dict[modelname]
    kw_cfg.update(kw)
    kw_cfg.update(dict(pretrain_img_size=pretrain_img_size))
    return kw_cfg
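
# Illustrative sketch (not part of the original file): get_swin_config merges the per-model
# hyper-parameters with caller overrides, e.g. the out_indices used by SwinWrapper above.
def _example_swin_config() -> Dict:
    cfg = get_swin_config(
        "swin_T_224_1k",
        pretrain_img_size=224,
        out_indices=(1, 2, 3),
        dilation=False,
        use_checkpoint=False,
    )
    assert cfg["type"] == "SwinTransformer" and cfg["window_size"] == 7
    return cfg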