from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from detect_tools.upn import BACKBONES, build_backbone, build_position_embedding
from detect_tools.upn.models.module import NestedTensor
from detect_tools.upn.models.utils import clean_state_dict
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
    Copied from torchvision.misc.ops with an added eps before rsqrt,
    without which models other than torchvision.models.resnet[18,34,50,101]
    produce NaNs.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
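# Illustrative sketch (not part of the original API): FrozenBatchNorm2d is a
# drop-in replacement for nn.BatchNorm2d whose statistics come from a checkpoint
# and are never updated. The hypothetical helper below shows one way such a swap
# could be performed after loading pretrained weights; it is not used elsewhere
# in this file.
def _freeze_batch_norm_sketch(module: nn.Module) -> nn.Module:
    """Recursively replace nn.BatchNorm2d layers with FrozenBatchNorm2d (illustrative only)."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            frozen = FrozenBatchNorm2d(child.num_features)
            # Copy the learned affine parameters and the running statistics.
            frozen.weight.data.copy_(child.weight.data)
            frozen.bias.data.copy_(child.bias.data)
            frozen.running_mean.data.copy_(child.running_mean.data)
            frozen.running_var.data.copy_(child.running_var.data)
            setattr(module, name, frozen)
        else:
            _freeze_batch_norm_sketch(child)
    return module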
class Joiner(nn.Module):
"""A wrapper for the backbone and the position embedding.
Args:
        backbone (nn.Module): The backbone module.
        position_embedding (nn.Module): The position embedding module.
"""
def __init__(self, backbone: nn.Module, position_embedding: nn.Module) -> None:
super().__init__()
self.backbone = backbone
self.pos_embed = position_embedding
    def forward(
        self, tensor_list: NestedTensor
    ) -> Tuple[List[NestedTensor], List[torch.Tensor]]:
"""Forward function.
Args:
tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
Returns:
            List[NestedTensor]: A list of feature maps in NestedTensor format.
            List[torch.Tensor]: A list of position encodings.
"""
xs = self.backbone(tensor_list)
out: List[NestedTensor] = []
pos = []
        for x in xs.values():
            out.append(x)
            # Position encoding for this feature level, cast to the feature dtype.
            pos.append(self.pos_embed(x).to(x.tensors.dtype))
return out, pos
def forward_pos_embed_only(self, x: NestedTensor) -> torch.Tensor:
"""Forward function for position embedding only. This is used to generate additional layer
Args:
x (NestedTensor): NestedTensor wrapping the input tensor.
Returns:
            torch.Tensor: The position encoding for the input.
"""
return self.pos_embed(x)
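# Usage sketch (illustrative, not part of the original file). Assuming NestedTensor
# follows the usual DETR convention of bundling a padded image batch with its
# padding mask, the Joiner pairs each backbone feature level with its position
# encoding; the config dicts below are placeholders supplied by the caller:
#
#     backbone = build_backbone(backbone_cfg)
#     pos_embed = build_position_embedding(position_embedding_cfg)
#     joiner = Joiner(backbone, pos_embed)
#     features, pos = joiner(samples)   # samples: NestedTensor
#     # features[i] is a NestedTensor; pos[i] is the matching position encoding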
@BACKBONES.register_module()
class SwinWrapper(nn.Module):
"""A wrapper for swin transformer.
Args:
        backbone_cfg (Union[Dict, str]): Config dict to build the backbone. If given a str
            name, `get_swin_config` is called to resolve the config dict.
        dilation (bool): Whether to use dilation in stage 4.
        position_embedding_cfg (Dict): Config dict to build the position embedding.
        lr_backbone (float): Learning rate of the backbone. Must be > 0.
        return_interm_indices (List[int]): Indices of the intermediate stages to return.
        backbone_freeze_keywords (List[str]): Backbone parameters whose names contain any of
            these keywords are frozen. Use None to freeze nothing.
        use_checkpoint (bool): Whether to use gradient checkpointing in the backbone.
            Default: False.
        backbone_ckpt_path (str): Path to a backbone checkpoint to load. Default: None.
"""
def __init__(
self,
backbone_cfg: Union[Dict, str],
dilation: bool,
position_embedding_cfg: Dict,
lr_backbone: float,
return_interm_indices: List[int],
backbone_freeze_keywords: List[str],
use_checkpoint: bool = False,
        backbone_ckpt_path: Optional[str] = None,
) -> None:
super(SwinWrapper, self).__init__()
pos_embedding = build_position_embedding(position_embedding_cfg)
train_backbone = lr_backbone > 0
if not train_backbone:
raise ValueError("Please set lr_backbone > 0")
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
# build backbone
if isinstance(backbone_cfg, str):
            assert backbone_cfg in [
                "swin_T_224_1k",
                "swin_B_224_22k",
                "swin_B_384_22k",
                "swin_L_224_22k",
                "swin_L_384_22k",
            ]
pretrain_img_size = int(backbone_cfg.split("_")[-2])
backbone_cfg = get_swin_config(
backbone_cfg,
pretrain_img_size,
out_indices=tuple(return_interm_indices),
dilation=dilation,
use_checkpoint=use_checkpoint,
)
backbone = build_backbone(backbone_cfg)
# freeze some layers
if backbone_freeze_keywords is not None:
for name, parameter in backbone.named_parameters():
for keyword in backbone_freeze_keywords:
if keyword in name:
parameter.requires_grad_(False)
break
# load checkpoint
if backbone_ckpt_path is not None:
print("Loading backbone checkpoint from {}".format(backbone_ckpt_path))
checkpoint = torch.load(backbone_ckpt_path, map_location="cpu")["model"]
def key_select_function(keyname):
if "head" in keyname:
return False
if dilation and "layers.3" in keyname:
return False
return True
_tmp_st = OrderedDict(
{
k: v
for k, v in clean_state_dict(checkpoint).items()
if key_select_function(k)
}
)
_tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False)
print(str(_tmp_st_output))
bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
assert len(bb_num_channels) == len(
return_interm_indices
), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
model = Joiner(backbone, pos_embedding)
model.num_channels = bb_num_channels
self.num_channels = bb_num_channels
self.model = model
    def forward(
        self, tensor_list: NestedTensor
    ) -> Tuple[List[NestedTensor], List[torch.Tensor]]:
"""Forward function.
Args:
tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
Returns:
            List[NestedTensor]: A list of feature maps in NestedTensor format.
            List[torch.Tensor]: A list of position encodings.
"""
return self.model(tensor_list)
def forward_pos_embed_only(self, tensor_list: NestedTensor) -> torch.Tensor:
"""Forward function to get position embedding only.
Args:
tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
Returns:
torch.Tensor: Position embedding.
"""
return self.model.forward_pos_embed_only(tensor_list)
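# Construction sketch (illustrative). The keyword values below are assumptions,
# not defaults from this repository, and the position embedding type name is
# hypothetical; passing a model name as backbone_cfg makes the wrapper resolve
# the Swin configuration through get_swin_config:
#
#     wrapper = SwinWrapper(
#         backbone_cfg="swin_T_224_1k",
#         dilation=False,
#         position_embedding_cfg=dict(type="PositionEmbeddingSine"),  # hypothetical type
#         lr_backbone=1e-5,
#         return_interm_indices=[1, 2, 3],
#         backbone_freeze_keywords=None,
#     )
#     features, pos = wrapper(samples)   # samples: NestedTensor
#     print(wrapper.num_channels)        # channel count of each returned stage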
def get_swin_config(modelname: str, pretrain_img_size: int, **kw):
"""Get swin config dict.
Args:
modelname (str): Name of the model.
        pretrain_img_size (int): Image size used during pretraining (e.g. 224 or 384).
kw (Dict): Other key word arguments.
    Returns:
        Dict: Config dict for building the Swin Transformer backbone.
"""
assert modelname in [
"swin_T_224_1k",
"swin_B_224_22k",
"swin_B_384_22k",
"swin_L_224_22k",
"swin_L_384_22k",
]
model_para_dict = {
"swin_T_224_1k": dict(
type="SwinTransformer",
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
),
"swin_B_224_22k": dict(
type="SwinTransformer",
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=7,
),
"swin_B_384_22k": dict(
type="SwinTransformer",
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=12,
),
"swin_L_224_22k": dict(
type="SwinTransformer",
embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=7,
),
"swin_L_384_22k": dict(
type="SwinTransformer",
embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=12,
),
}
    kw_cfg = model_para_dict[modelname]
    kw_cfg.update(kw)
    kw_cfg.update(dict(pretrain_img_size=pretrain_img_size))
    return kw_cfg
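# Example (illustrative): resolve the Swin-Tiny configuration the same way
# SwinWrapper does when given a model name, keeping the last three stages.
# The keyword values are assumptions for demonstration only:
#
#     cfg = get_swin_config(
#         "swin_T_224_1k",
#         pretrain_img_size=224,
#         out_indices=(1, 2, 3),
#         dilation=False,
#         use_checkpoint=False,
#     )
#     # cfg["type"] == "SwinTransformer", cfg["embed_dim"] == 96,
#     # cfg["depths"] == [2, 2, 6, 2], cfg["pretrain_img_size"] == 224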