# https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/buildingblocks.py
#
# MIT License
#
# Copyright (c) 2018 Adrian Wolny
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from functools import partial

import torch
from torch import nn as nn
from torch.nn import functional as F

# from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D


def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding,
                dropout_prob, is3d):
    """
    Create a list of modules which together constitute a single conv layer with non-linearity
    and optional batchnorm/groupnorm.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int or tuple): size of the convolving kernel
        order (string): order of the layers, e.g.
            'cr' -> conv + ReLU
            'gcr' -> groupnorm + conv + ReLU
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
            'bcr' -> batchnorm + conv + ReLU
            'cbrd' -> conv + batchnorm + ReLU + dropout
            'cbrD' -> conv + batchnorm + ReLU + dropout2d
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all sides of the input
        dropout_prob (float): dropout probability
        is3d (bool): if True use Conv3d, otherwise use Conv2d

    Return:
        list of tuple (name, module)
    """
    assert 'c' in order, "Conv layer MUST be present"
    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'

    modules = []
    for i, char in enumerate(order):
        if char == 'r':
            modules.append(('ReLU', nn.ReLU(inplace=True)))
        elif char == 'l':
            modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True)))
        elif char == 'e':
            modules.append(('ELU', nn.ELU(inplace=True)))
        elif char == 'c':
            # add learnable bias only in the absence of batchnorm/groupnorm
            bias = not ('g' in order or 'b' in order)
            if is3d:
                conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
            else:
                conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
            modules.append(('conv', conv))
        elif char == 'g':
            is_before_conv = i < order.index('c')
            if is_before_conv:
                num_channels = in_channels
            else:
                num_channels = out_channels

            # use only one group if the given number of groups is greater than the number of channels
            if num_channels < num_groups:
                num_groups = 1

            assert num_channels % num_groups == 0, \
                f'Expected number of channels in input to be divisible by num_groups. ' \
                f'num_channels={num_channels}, num_groups={num_groups}'
            modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
        elif char == 'b':
            is_before_conv = i < order.index('c')
            if is3d:
                bn = nn.BatchNorm3d
            else:
                bn = nn.BatchNorm2d

            if is_before_conv:
                modules.append(('batchnorm', bn(in_channels)))
            else:
                modules.append(('batchnorm', bn(out_channels)))
        elif char == 'd':
            modules.append(('dropout', nn.Dropout(p=dropout_prob)))
        elif char == 'D':
            modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob)))
        else:
            raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']")

    return modules
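

# Illustrative sketch (not part of the upstream module): each character in `order`
# contributes one (name, module) pair, so order='gcr' yields groupnorm -> conv -> ReLU.
# A quick sanity check could look like this:
#
#   blocks = create_conv(16, 32, kernel_size=3, order='gcr', num_groups=8,
#                        padding=1, dropout_prob=0.1, is3d=True)
#   assert [name for name, _ in blocks] == ['groupnorm', 'conv', 'ReLU']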


class SingleConv(nn.Sequential):
    """
    Basic convolutional module consisting of a Conv3d/Conv2d, non-linearity and optional
    batchnorm/groupnorm. The order of operations can be specified via the `order` parameter.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int or tuple): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all sides of the input
        dropout_prob (float): dropout probability, default 0.1
        is3d (bool): if True use Conv3d, otherwise use Conv2d
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8,
                 padding=1, dropout_prob=0.1, is3d=True):
        super(SingleConv, self).__init__()

        for name, module in create_conv(in_channels, out_channels, kernel_size, order,
                                        num_groups, padding, dropout_prob, is3d):
            self.add_module(name, module)


class DoubleConv(nn.Sequential):
    """
    A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
    We use (GroupNorm3d+Conv3d+ReLU) by default (order='gcr').
    This can be changed however by providing the 'order' argument, e.g. in order
    to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
    Use padded convolutions to make sure that the output (H_out, W_out) is the same
    as (H_in, W_in), so that you don't have to crop in the decoder path.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
        kernel_size (int or tuple): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all sides of the input
        upscale (int): channel growth between the two encoder convolutions: if 1, the first
            convolution already outputs `out_channels`, otherwise it outputs `out_channels // 2`
            (default: 2)
        dropout_prob (float or tuple): dropout probability for each convolution, default 0.1
        is3d (bool): if True use Conv3d instead of Conv2d layers
    """

    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
        super(DoubleConv, self).__init__()
        if encoder:
            # we're in the encoder path
            conv1_in_channels = in_channels
            if upscale == 1:
                conv1_out_channels = out_channels
            else:
                conv1_out_channels = out_channels // 2
            if conv1_out_channels < in_channels:
                conv1_out_channels = in_channels
            conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
        else:
            # we're in the decoder path, decrease the number of channels in the 1st convolution
            conv1_in_channels, conv1_out_channels = in_channels, out_channels
            conv2_in_channels, conv2_out_channels = out_channels, out_channels

        # check if dropout_prob is a tuple and if so
        # split it for different dropout probabilities for each convolution.
        if isinstance(dropout_prob, (list, tuple)):
            dropout_prob1 = dropout_prob[0]
            dropout_prob2 = dropout_prob[1]
        else:
            dropout_prob1 = dropout_prob2 = dropout_prob

        # conv1
        self.add_module('SingleConv1',
                        SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
                                   padding=padding, dropout_prob=dropout_prob1, is3d=is3d))
        # conv2
        self.add_module('SingleConv2',
                        SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
                                   padding=padding, dropout_prob=dropout_prob2, is3d=is3d))
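

# Illustrative sketch (not part of the upstream module) of how DoubleConv splits channels:
# in the encoder path the first conv goes halfway to `out_channels` (with upscale=2), in the
# decoder path the first conv reduces the concatenated channels straight to `out_channels`:
#
#   enc = DoubleConv(64, 128, encoder=True)        # SingleConv1: 64 -> 64,  SingleConv2: 64 -> 128
#   dec = DoubleConv(128 + 64, 64, encoder=False)  # SingleConv1: 192 -> 64, SingleConv2: 64 -> 64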


class ResNetBlock(nn.Module):
    """
    Residual block that can be used instead of standard DoubleConv in the Encoder module.
    Motivated by: https://arxiv.org/pdf/1706.00120.pdf

    Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs):
        super(ResNetBlock, self).__init__()

        if in_channels != out_channels:
            # conv1x1 for increasing the number of channels
            if is3d:
                self.conv1 = nn.Conv3d(in_channels, out_channels, 1)
            else:
                self.conv1 = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.conv1 = nn.Identity()

        self.conv2 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order,
                                num_groups=num_groups, is3d=is3d)
        # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
        n_order = order
        for c in 'rel':
            n_order = n_order.replace(c, '')
        self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
                                num_groups=num_groups, is3d=is3d)

        # create non-linearity separately
        if 'l' in order:
            self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif 'e' in order:
            self.non_linearity = nn.ELU(inplace=True)
        else:
            self.non_linearity = nn.ReLU(inplace=True)

    def forward(self, x):
        # apply first convolution to bring the number of channels to out_channels
        residual = self.conv1(x)

        out = self.conv2(x)
        out = self.conv3(out)

        out += residual
        out = self.non_linearity(out)
        return out
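

# Illustrative sketch (not part of the upstream module): the block adds a 1x1-conv projection
# of the input to the output of the two convolutions, then applies the non-linearity, so the
# spatial dimensions are preserved while the channels change:
#
#   block = ResNetBlock(32, 64, is3d=True)
#   y = block(torch.randn(1, 32, 16, 32, 32))
#   assert y.shape == (1, 64, 16, 32, 32)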


class Encoder(nn.Module):
    """
    A single module from the encoder path consisting of an optional max/avg pooling layer
    (one may specify the pooling kernel_size to be different from the standard (2, 2, 2), e.g.
    if the volumetric data is anisotropic; make sure to use a complementary scale_factor in the
    decoder path) followed by a basic module (DoubleConv or ResNetBlock).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int or tuple): size of the convolving kernel
        apply_pooling (bool): if True use MaxPool3d before DoubleConv
        pool_kernel_size (int or tuple): the size of the pooling window
        pool_type (str): pooling layer: 'max' or 'avg'
        basic_module (nn.Module): either ResNetBlock or DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all sides of the input
        upscale (int): channel upscale factor passed to `DoubleConv` (see `DoubleConv`), default: 2
        dropout_prob (float or tuple): dropout probability, default 0.1
        is3d (bool): use 3d or 2d convolutions/pooling operation
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
                 pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
        super(Encoder, self).__init__()
        assert pool_type in ['max', 'avg']
        if apply_pooling:
            if pool_type == 'max':
                if is3d:
                    self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
                else:
                    self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size)
            else:
                if is3d:
                    self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
                else:
                    self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size)
        else:
            self.pooling = None

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=True,
                                         kernel_size=conv_kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups,
                                         padding=padding,
                                         upscale=upscale,
                                         dropout_prob=dropout_prob,
                                         is3d=is3d)

    def forward(self, x):
        if self.pooling is not None:
            x = self.pooling(x)
        x = self.basic_module(x)
        return x
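

# Illustrative sketch (not part of the upstream module): with the default settings an Encoder
# halves the spatial dimensions via MaxPool3d(2) and then maps channels with a DoubleConv:
#
#   enc = Encoder(32, 64)
#   y = enc(torch.randn(1, 32, 16, 32, 32))
#   assert y.shape == (1, 64, 8, 16, 16)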


class Decoder(nn.Module):
    """
    A single module for the decoder path consisting of an upsampling layer
    (either learned ConvTranspose3d or nearest neighbor interpolation)
    followed by a basic module (DoubleConv or ResNetBlock).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int or tuple): size of the convolving kernel
        scale_factor (int or tuple): used as the multiplier for the image H/W/D in
            case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the
            MaxPool3d operation from the corresponding encoder
        basic_module (nn.Module): either ResNetBlock or DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all sides of the input
        upsample (str): algorithm used for upsampling:
            InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
            TransposeConvUpsampling: 'deconv'
            No upsampling: None
            Default: 'default' (chooses automatically)
        dropout_prob (float or tuple): dropout probability, default 0.1
        is3d (bool): if True use 3d operations, otherwise 2d
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
                 conv_layer_order='gcr', num_groups=8, padding=1, upsample='default',
                 dropout_prob=0.1, is3d=True):
        super(Decoder, self).__init__()

        # perform concat joining per default
        concat = True

        # don't adapt channels after join operation
        adapt_channels = False

        if upsample is not None and upsample != 'none':
            if upsample == 'default':
                if basic_module == DoubleConv:
                    upsample = 'nearest'  # use nearest neighbor interpolation for upsampling
                    concat = True  # use concat joining
                    adapt_channels = False  # don't adapt channels
                elif basic_module == ResNetBlock:  # or basic_module == ResNetBlockSE
                    upsample = 'deconv'  # use deconvolution upsampling
                    concat = False  # use summation joining
                    adapt_channels = True  # adapt channels after joining

            # perform deconvolution upsampling if mode is deconv
            if upsample == 'deconv':
                self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
                                                          kernel_size=conv_kernel_size, scale_factor=scale_factor,
                                                          is3d=is3d)
            else:
                self.upsampling = InterpolateUpsampling(mode=upsample)
        else:
            # no upsampling
            self.upsampling = NoUpsampling()
            # concat joining
            self.joining = partial(self._joining, concat=True)

        # perform joining operation
        self.joining = partial(self._joining, concat=concat)

        # adapt the number of in_channels for the ResNetBlock
        if adapt_channels is True:
            in_channels = out_channels

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=False,
                                         kernel_size=conv_kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups,
                                         padding=padding,
                                         dropout_prob=dropout_prob,
                                         is3d=is3d)

    def forward(self, encoder_features, x):
        x = self.upsampling(encoder_features=encoder_features, x=x)
        x = self.joining(encoder_features, x)
        x = self.basic_module(x)
        return x

    @staticmethod
    def _joining(encoder_features, x, concat):
        if concat:
            return torch.cat((encoder_features, x), dim=1)
        else:
            return encoder_features + x
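

# Illustrative sketch (not part of the upstream module): with DoubleConv the decoder upsamples
# its input to the skip connection's spatial size and concatenates along the channel axis, so
# `in_channels` must be the sum of the skip channels and the channels coming from below:
#
#   dec = Decoder(in_channels=64 + 32, out_channels=32)
#   skip = torch.randn(1, 32, 16, 32, 32)   # encoder feature map
#   below = torch.randn(1, 64, 8, 16, 16)   # output of the deeper stage
#   assert dec(skip, below).shape == (1, 32, 16, 32, 32)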


def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
                    conv_upscale, dropout_prob,
                    layer_order, num_groups, pool_kernel_size, is3d):
    # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
    encoders = []
    for i, out_feature_num in enumerate(f_maps):
        if i == 0:
            encoder = Encoder(in_channels, out_feature_num,
                              apply_pooling=False,  # skip pooling in the first encoder
                              basic_module=basic_module,
                              conv_layer_order=layer_order,
                              conv_kernel_size=conv_kernel_size,
                              num_groups=num_groups,
                              padding=conv_padding,
                              upscale=conv_upscale,
                              dropout_prob=dropout_prob,
                              is3d=is3d)
        else:
            encoder = Encoder(f_maps[i - 1], out_feature_num,
                              basic_module=basic_module,
                              conv_layer_order=layer_order,
                              conv_kernel_size=conv_kernel_size,
                              num_groups=num_groups,
                              pool_kernel_size=pool_kernel_size,
                              padding=conv_padding,
                              upscale=conv_upscale,
                              dropout_prob=dropout_prob,
                              is3d=is3d)

        encoders.append(encoder)

    return nn.ModuleList(encoders)


def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
                    num_groups, upsample, dropout_prob, is3d):
    # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
    decoders = []
    reversed_f_maps = list(reversed(f_maps))
    for i in range(len(reversed_f_maps) - 1):
        if basic_module == DoubleConv and upsample != 'deconv':
            in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
        else:
            in_feature_num = reversed_f_maps[i]

        out_feature_num = reversed_f_maps[i + 1]

        decoder = Decoder(in_feature_num, out_feature_num,
                          basic_module=basic_module,
                          conv_layer_order=layer_order,
                          conv_kernel_size=conv_kernel_size,
                          num_groups=num_groups,
                          padding=conv_padding,
                          upsample=upsample,
                          dropout_prob=dropout_prob,
                          is3d=is3d)
        decoders.append(decoder)
    return nn.ModuleList(decoders)
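

# Illustrative sketch (not part of the upstream module): for f_maps = [32, 64, 128] and the
# default DoubleConv/interpolation setup, the decoder path consumes the concatenated skip
# connections, i.e. two decoders with (in, out) channels of (128 + 64, 64) and (64 + 32, 32):
#
#   decoders = create_decoders([32, 64, 128], DoubleConv, conv_kernel_size=3, conv_padding=1,
#                              layer_order='gcr', num_groups=8, upsample='default',
#                              dropout_prob=0.1, is3d=True)
#   assert len(decoders) == 2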


class AbstractUpsampling(nn.Module):
    """
    Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either
    interpolation or learned transposed convolution.
    """

    def __init__(self, upsample):
        super(AbstractUpsampling, self).__init__()
        self.upsample = upsample

    def forward(self, encoder_features, x):
        # get the spatial dimensions of the output given the encoder_features
        output_size = encoder_features.size()[2:]
        # upsample the input and return
        return self.upsample(x, output_size)


class InterpolateUpsampling(AbstractUpsampling):
    """
    Args:
        mode (str): algorithm used for upsampling:
            'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
    """

    def __init__(self, mode='nearest'):
        upsample = partial(self._interpolate, mode=mode)
        super().__init__(upsample)

    @staticmethod
    def _interpolate(x, size, mode):
        return F.interpolate(x, size=size, mode=mode)


class TransposeConvUpsampling(AbstractUpsampling):
    """
    Upsampling via a learned transposed convolution.

    Args:
        in_channels (int): number of input channels for the transposed conv
        out_channels (int): number of output channels for the transposed conv
        kernel_size (int or tuple): size of the convolving kernel
        scale_factor (int or tuple): stride of the convolution
        is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d
    """

    class Upsample(nn.Module):
        """
        Workaround for the 'ValueError: requested an output size...' in the `_output_padding` method of
        transposed convolution. It performs transposed conv followed by interpolation to the correct size if necessary.
        """

        def __init__(self, conv_transposed, is3d):
            super().__init__()
            self.conv_transposed = conv_transposed
            self.is3d = is3d

        def forward(self, x, size):
            x = self.conv_transposed(x)
            return F.interpolate(x, size=size)

    def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True):
        # make sure that the output size reverses the MaxPool3d from the corresponding encoder
        if is3d is True:
            conv_transposed = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size,
                                                 stride=scale_factor, padding=1, bias=False)
        else:
            conv_transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                                 stride=scale_factor, padding=1, bias=False)
        upsample = self.Upsample(conv_transposed, is3d)
        super().__init__(upsample)


class NoUpsampling(AbstractUpsampling):
    def __init__(self):
        super().__init__(self._no_upsampling)

    @staticmethod
    def _no_upsampling(x, size):
        return x
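

# Illustrative end-to-end sketch (not part of the upstream module): wiring create_encoders and
# create_decoders into a minimal forward pass with skip connections, roughly the pattern the
# full UNet models built on top of these blocks follow.
if __name__ == '__main__':
    f_maps = [16, 32, 64]
    encoders = create_encoders(in_channels=1, f_maps=f_maps, basic_module=DoubleConv,
                               conv_kernel_size=3, conv_padding=1, conv_upscale=2,
                               dropout_prob=0.1, layer_order='gcr', num_groups=8,
                               pool_kernel_size=2, is3d=True)
    decoders = create_decoders(f_maps, basic_module=DoubleConv, conv_kernel_size=3,
                               conv_padding=1, layer_order='gcr', num_groups=8,
                               upsample='default', dropout_prob=0.1, is3d=True)

    x = torch.randn(1, 1, 32, 64, 64)
    skips = []
    for encoder in encoders:
        x = encoder(x)
        skips.insert(0, x)

    # the deepest feature map is the decoder input; the remaining maps act as skip connections
    for decoder, skip in zip(decoders, skips[1:]):
        x = decoder(skip, x)

    # back to the original resolution with f_maps[0] channels
    assert x.shape == (1, f_maps[0], 32, 64, 64)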