import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init
from mmseg.ops import resize


class BaseDecodeHead(nn.Module):
    """Base class for decode heads.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        in_index (int|Sequence[int]): Input feature index. Default: -1.
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in the FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into the decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        ignore_index (int|None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """
    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False):
        super(BaseDecodeHead, self).__init__()
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.ignore_index = ignore_index
        self.align_corners = align_corners
        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature
        map will be selected, so in_channels and in_index must be of type
        int. When input_transform is not None, in_channels and in_index
        must be a list or tuple of the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in the FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into the decode head.
                None: Only one selected feature map is allowed.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.conv_seg, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for the decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]
        return inputs
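
    # Illustrative shapes for the three modes (assumed example values, not
    # from the source): with in_index=[0, 1] and 'resize_concat', maps of
    # shape (N, 48, H, W) and (N, 96, H/2, W/2) are resized to (H, W) and
    # concatenated into one (N, 144, H, W) tensor; 'multiple_select' returns
    # the selected maps unchanged as a list; input_transform=None with
    # in_index=-1 returns only the last feature map.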

    def forward(self, inputs):
        """Placeholder for the forward function; subclasses override it."""
        pass

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output
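

# A minimal sketch of a concrete subclass (hypothetical, for illustration
# only; `IdentityHead` is not part of this file). It assumes
# in_channels == channels so the selected feature map can be passed
# straight to the classifier.
class IdentityHead(BaseDecodeHead):
    """Select a feature map via _transform_inputs and classify each pixel."""

    def forward(self, inputs):
        x = self._transform_inputs(inputs)
        return self.cls_seg(x)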


class FCNHead(BaseDecodeHead):
    """Fully Convolutional Networks for Semantic Segmentation.

    This head is an implementation of `FCNNet
    <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether to concatenate the input and output of
            convs before the classification layer. Default: True.
    """
    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 **kwargs):
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            assert self.in_channels == self.channels
        convs = []
        convs.append(
            ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        for _ in range(num_convs - 1):
            convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
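

# A minimal usage sketch (assumed shapes, for illustration only). With
# num_convs=0 and concat_input=False the conv stack degenerates to
# nn.Identity, so in_channels must equal channels:
#
#     head = FCNHead(in_channels=64, channels=64, num_classes=19,
#                    num_convs=0, concat_input=False)
#     feats = [torch.randn(2, 64, 32, 32)]
#     logits = head(feats)  # -> torch.Size([2, 19, 32, 32])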


class MultiHeadFCNHead(nn.Module):
    """Fully Convolutional Networks for Semantic Segmentation with multiple
    parallel classification heads.

    This head follows `FCNNet <https://arxiv.org/abs/1411.4038>`_, but builds
    num_head independent conv stacks and classifiers over the same input and
    returns one prediction per head.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether to concatenate the input and output of
            convs before the classification layer. Default: True.
        num_head (int): Number of parallel heads. Default: 18.
    """
    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 num_head=18,
                 **kwargs):
        super(MultiHeadFCNHead, self).__init__()
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.num_head = num_head
        self.ignore_index = ignore_index
        self.align_corners = align_corners
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            # forward() checks `self.dropout is not None`, so the attribute
            # must exist even when dropout is disabled.
            self.dropout = None
        conv_seg_head_list = []
        for _ in range(self.num_head):
            conv_seg_head_list.append(
                nn.Conv2d(channels, num_classes, kernel_size=1))
        self.conv_seg_head_list = nn.ModuleList(conv_seg_head_list)
        self.init_weights()
        if num_convs == 0:
            assert self.in_channels == self.channels
        convs_list = []
        conv_cat_list = []
        for _ in range(self.num_head):
            convs = []
            convs.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        self.channels,
                        self.channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            if num_convs == 0:
                convs_list.append(nn.Identity())
            else:
                convs_list.append(nn.Sequential(*convs))
            if self.concat_input:
                conv_cat_list.append(
                    ConvModule(
                        self.in_channels + self.channels,
                        self.channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
        self.convs_list = nn.ModuleList(convs_list)
        self.conv_cat_list = nn.ModuleList(conv_cat_list)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output_list = []
        for head_idx in range(self.num_head):
            output = self.convs_list[head_idx](x)
            if self.concat_input:
                output = self.conv_cat_list[head_idx](
                    torch.cat([x, output], dim=1))
            if self.dropout is not None:
                output = self.dropout(output)
            output = self.conv_seg_head_list[head_idx](output)
            output_list.append(output)
        return output_list

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature
        map will be selected, so in_channels and in_index must be of type
        int. When input_transform is not None, in_channels and in_index
        must be a list or tuple of the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in the FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into the decode head.
                None: Only one selected feature map is allowed.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of the classification layers."""
        for conv_seg_head in self.conv_seg_head_list:
            normal_init(conv_seg_head, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for the decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]
        return inputs
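

# A minimal smoke test, a sketch under assumed shapes (requires mmcv and
# mmseg to be installed for ConvModule and resize):
if __name__ == '__main__':
    feats = [torch.randn(2, 64, 32, 32)]

    fcn = FCNHead(in_channels=64, channels=64, num_classes=19)
    fcn.init_weights()
    print(fcn(feats).shape)  # expected: torch.Size([2, 19, 32, 32])

    multi = MultiHeadFCNHead(64, 64, num_classes=2, num_head=3)
    outs = multi(feats)
    print(len(outs), outs[0].shape)  # expected: 3 torch.Size([2, 2, 32, 32])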