|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from .network_blocks import BaseConv, DWConv |
|
|
|
|
|
|
|
|
# Major and minor components of the installed torch version, e.g. [1, 10].
_TORCH_VER = list(map(int, torch.__version__.split(".")[:2]))


def meshgrid(*tensors):
    """
    Version-tolerant wrapper around ``torch.meshgrid``.

    Copied from YOLOX/yolox/utils/compat.py

    When the installed torch version is 1.10 or newer, ``indexing="ij"`` is
    passed explicitly; on older versions ``torch.meshgrid`` is called without
    the keyword.

    Args:
        *tensors: forwarded unchanged to ``torch.meshgrid``.

    Returns:
        Whatever ``torch.meshgrid`` returns for the given tensors.
    """
    if _TORCH_VER < [1, 10]:
        # Older releases are called without the ``indexing`` keyword.
        return torch.meshgrid(*tensors)
    return torch.meshgrid(*tensors, indexing="ij")
|
|
|
|
|
|
|
|
def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
    """
    Pairwise intersection-over-union between two sets of boxes.

    Copied from YOLOX/yolox/utils/boxes.py

    Args:
        bboxes_a: tensor whose second dimension has size 4.
        bboxes_b: tensor whose second dimension has size 4.
        xyxy (bool): when True the four columns are read as two corners
            (first two columns = top-left, last two = bottom-right); when
            False they are read as a centre (first two columns) followed by
            a size (last two columns).

    Returns:
        Tensor with one row per box in ``bboxes_a`` and one column per box in
        ``bboxes_b`` holding intersection area / union area.

    Raises:
        IndexError: if the second dimension of either input is not 4.
    """
    if not (bboxes_a.shape[1] == 4 and bboxes_b.shape[1] == 4):
        raise IndexError

    if xyxy:
        tl_a, br_a = bboxes_a[:, :2], bboxes_a[:, 2:]
        tl_b, br_b = bboxes_b[:, :2], bboxes_b[:, 2:]
        area_a = torch.prod(br_a - tl_a, 1)
        area_b = torch.prod(br_b - tl_b, 1)
    else:
        # Centre/size layout: corners are centre -/+ half the size.
        half_a = bboxes_a[:, 2:] / 2
        half_b = bboxes_b[:, 2:] / 2
        tl_a, br_a = bboxes_a[:, :2] - half_a, bboxes_a[:, :2] + half_a
        tl_b, br_b = bboxes_b[:, :2] - half_b, bboxes_b[:, :2] + half_b
        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)

    # Broadcast every box of A against every box of B.
    tl = torch.max(tl_a[:, None], tl_b)
    br = torch.min(br_a[:, None], br_b)

    # 1 where the boxes overlap along both axes, 0 otherwise; this zeroes the
    # (otherwise meaningless) product of negative extents for disjoint boxes.
    overlaps = (tl < br).type(tl.type()).prod(dim=2)
    area_i = torch.prod(br - tl, 2) * overlaps
    union = area_a[:, None] + area_b - area_i
    return area_i / union
|
|
|
|
|
|
|
|
class YOLOXHead(nn.Module):
    """
    Detection head with separate classification and regression branches.

    For every input feature level the head applies a 1x1 stem convolution,
    then two independent stacks of 3x3 convolutions. The classification
    stack feeds a 1x1 class predictor; the regression stack feeds a 1x1 box
    predictor (4 channels) and a 1x1 objectness predictor (1 channel).
    """

    def __init__(
        self,
        num_classes,
        width=1.0,
        strides=(8, 16, 32),
        in_channels=(256, 512, 1024),
        act="silu",
        depthwise=False,
    ):
        """
        Args:
            num_classes (int): number of output channels of each class
                predictor.
            width (float): multiplier applied to ``in_channels`` and to the
                256-channel hidden width. Default value: 1.0.
            strides (sequence of int): one stride per feature level, used to
                scale decoded boxes. Default value: (8, 16, 32).
            in_channels (sequence of int): unscaled channel count of each
                input feature level. Default value: (256, 512, 1024).
            act (str): activation type of conv. Default value: "silu".
            depthwise (bool): whether apply depthwise conv in conv branch.
                Default value: False.
        """
        super().__init__()

        self.num_classes = num_classes
        self.decode_in_inference = True

        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()
        Conv = DWConv if depthwise else BaseConv
        hidden = int(256 * width)

        def conv_tower():
            # Two stacked 3x3 convs that keep the hidden channel width.
            return nn.Sequential(
                Conv(
                    in_channels=hidden,
                    out_channels=hidden,
                    ksize=3,
                    stride=1,
                    act=act,
                ),
                Conv(
                    in_channels=hidden,
                    out_channels=hidden,
                    ksize=3,
                    stride=1,
                    act=act,
                ),
            )

        def pred_layer(out_channels):
            # 1x1 conv mapping hidden features to raw predictions.
            return nn.Conv2d(
                in_channels=hidden,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

        # Modules are created in the same order as before (stem, cls tower,
        # reg tower, cls/reg/obj predictors) so seeded initialisation and
        # state_dict keys are unchanged.
        for level_channels in in_channels:
            self.stems.append(
                BaseConv(
                    in_channels=int(level_channels * width),
                    out_channels=hidden,
                    ksize=1,
                    stride=1,
                    act=act,
                )
            )
            self.cls_convs.append(conv_tower())
            self.reg_convs.append(conv_tower())
            self.cls_preds.append(pred_layer(self.num_classes))
            self.reg_preds.append(pred_layer(4))
            self.obj_preds.append(pred_layer(1))

        self.use_l1 = False
        self.l1_loss = nn.L1Loss(reduction="none")
        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
        self.iou_loss = None
        # Private copy so later mutation of this attribute cannot leak into
        # the caller's sequence.
        self.strides = list(strides)
        # Placeholder grids; get_output_and_grid replaces them on first use.
        self.grids = [torch.zeros(1)] * len(in_channels)
        # Spatial size of each level's output; filled in by forward().
        self.hw = []

    def forward(self, xin, labels=None, imgs=None):
        """
        Run the head on a sequence of feature maps.

        Args:
            xin: sequence of feature maps, one per level, in the same order
                as ``strides``.
            labels: accepted for interface compatibility; not used here.
            imgs: accepted for interface compatibility; not used here.

        Returns:
            Tensor of shape (batch, total_locations, 5 + num_classes) where
            the last dimension is laid out as 4 box values, objectness
            (sigmoid) and class scores (sigmoid). When
            ``decode_in_inference`` is True the box values are passed through
            ``decode_outputs`` first.
        """
        outputs = []
        # self.strides takes part in the zip only so that the number of
        # processed levels is capped exactly as before.
        for k, (cls_conv, reg_conv, _stride, x) in enumerate(
            zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)

            cls_feat = cls_conv(x)
            cls_output = self.cls_preds[k](cls_feat)

            reg_feat = reg_conv(x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

            outputs.append(
                torch.cat(
                    [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
                )
            )

        # Remember each level's (H, W); decode_outputs rebuilds grids from it.
        self.hw = [x.shape[-2:] for x in outputs]

        # (B, C, H, W) per level -> (B, sum(H*W), C)
        outputs = torch.cat(
            [x.flatten(start_dim=2) for x in outputs], dim=2
        ).permute(0, 2, 1)
        if self.decode_in_inference:
            return self.decode_outputs(outputs, dtype=xin[0].type())
        return outputs

    def get_output_and_grid(self, output, k, stride, dtype):
        """
        Decode one level's raw output and return it with its location grid.

        Args:
            output: tensor that can be viewed as
                (batch, 5 + num_classes, H, W).
            k (int): level index into the grid cache ``self.grids``.
            stride: scale applied to the decoded centre and size.
            dtype: tensor type the grid is converted to via ``Tensor.type``.

        Returns:
            Tuple ``(output, grid)``: ``output`` reshaped to
            (batch, H*W, 5 + num_classes) with the first two values offset by
            the grid and scaled by ``stride`` and the next two exponentiated
            and scaled by ``stride``; ``grid`` of shape (1, H*W, 2).
        """
        grid = self.grids[k]

        batch_size = output.shape[0]
        n_ch = 5 + self.num_classes
        hsize, wsize = output.shape[-2:]
        if grid.shape[2:4] != output.shape[2:4]:
            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2)

        # The cache is validated by spatial shape only, so a cached grid may
        # have a stale type or live on another device. Both conversions are
        # no-ops when nothing needs to change.
        grid = grid.type(dtype).to(output.device)
        self.grids[k] = grid

        output = output.view(batch_size, 1, n_ch, hsize, wsize)
        output = output.permute(0, 1, 3, 4, 2).reshape(
            batch_size, hsize * wsize, -1
        )
        grid = grid.view(1, -1, 2)
        output[..., :2] = (output[..., :2] + grid) * stride
        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
        return output, grid

    def decode_outputs(self, outputs, dtype):
        """
        Convert raw box predictions of all levels to stride-scaled values.

        Uses the per-level sizes stored in ``self.hw`` by ``forward``.

        Args:
            outputs: tensor of shape (batch, total_locations, channels).
            dtype: tensor type the grids and strides are converted to via
                ``Tensor.type``.

        Returns:
            Tensor of the same layout where the first two channels are
            ``(raw + grid) * stride``, the next two are ``exp(raw) * stride``
            and the remaining channels are passed through unchanged.
        """
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            strides.append(torch.full((*grid.shape[:2], 1), stride))

        # Converting with a type string does not guarantee the result lands
        # on the device that holds `outputs`, so pin the device explicitly.
        grids = torch.cat(grids, dim=1).type(dtype).to(outputs.device)
        strides = torch.cat(strides, dim=1).type(dtype).to(outputs.device)

        return torch.cat(
            [
                (outputs[..., 0:2] + grids) * strides,
                torch.exp(outputs[..., 2:4]) * strides,
                outputs[..., 4:],
            ],
            dim=-1,
        )
|
|
|