'''Convnet encoder module.

'''

import torch
import torch.nn as nn

from cortex.built_ins.networks.utils import get_nonlinearity  # required below by create_conv_layers / create_linear_layers
from cortex_DIM.nn_modules.misc import Fold, Unfold, View


def infer_conv_size(w, k, s, p):
    '''Infers the next size after convolution.

    Args:
        w: Input size.
        k: Kernel size.
        s: Stride.
        p: Padding.

    Returns:
        int: Output size.

    '''
    x = (w - k + 2 * p) // s + 1
    return x
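
# Quick sanity check of the size formula above (an illustrative example, not
# part of the original module): a 32-pixel input with a 4-pixel kernel,
# stride 2 and padding 1 gives (32 - 4 + 2 * 1) // 2 + 1 = 16.
assert infer_conv_size(32, 4, 2, 1) == 16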


class Convnet(nn.Module):
    '''Basic convnet convenience class.

    Attributes:
        conv_layers: nn.Sequential of nn.Conv2d layers with batch norm,
            dropout, nonlinearity.
        fc_layers: nn.Sequential of nn.Linear layers with batch norm,
            dropout, nonlinearity.
        reshape: Simple reshape layer.
        conv_shape: Shape of the convolutional output.

    '''

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.create_layers(*args, **kwargs)

    def create_layers(self, shape, conv_args=None, fc_args=None):
        '''Creates layers.

        conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
        fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)

        Args:
            shape: Shape of input.
            conv_args: List of tuple of convolutional arguments.
            fc_args: List of tuple of fully-connected arguments.

        '''
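        # Illustrative argument tuples (values assumed, not from the source):
        #   conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None)] adds one
        #   64-channel 4x4 conv (stride 2, pad 1) with batch norm and ReLU;
        #   fc_args=[(1024, True, False, 'ReLU')] adds one 1024-unit layer.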
        self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)

        dim_x, dim_y, dim_out = self.conv_shape
        dim_r = dim_x * dim_y * dim_out
        self.reshape = View(-1, dim_r)

        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_conv_layers(self, shape, conv_args):
        '''Creates a set of convolutional layers.

        Args:
            shape: Input shape.
            conv_args: List of tuple of convolutional arguments.

        Returns:
            nn.Sequential: a sequence of convolutional layers.
            tuple: the output shape (dim_x, dim_y, dim_out).

        '''
        conv_layers = nn.Sequential()
        conv_args = conv_args or []

        dim_x, dim_y, dim_in = shape

        for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            conv_block = nn.Sequential()

            if dim_out is not None:
                conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p,
                                 bias=not batch_norm)
                conv_block.add_module(name + 'conv', conv)
                dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
            else:
                dim_out = dim_in

            if dropout:
                conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
            if batch_norm:
                bn = nn.BatchNorm2d(dim_out)
                conv_block.add_module(name + 'bn', bn)

            if nonlinearity:
                nonlinearity = get_nonlinearity(nonlinearity)
                conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity)

            if pool:
                (pool_type, kernel, stride) = pool
                Pool = getattr(nn, pool_type)
                conv_block.add_module(name + 'pool', Pool(kernel_size=kernel, stride=stride))
                dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)

            conv_layers.add_module(name, conv_block)

            dim_in = dim_out

        dim_out = dim_in  # If conv_args is empty, the output shape is the input shape.

        return conv_layers, (dim_x, dim_y, dim_out)
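
    # Shape bookkeeping example for create_conv_layers (illustrative values,
    # not from the source): shape (32, 32, 3) with
    # conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None)] gives
    # dim_x = dim_y = (32 - 4 + 2 * 1) // 2 + 1 = 16 and dim_out = 64,
    # so conv_shape == (16, 16, 64).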

    def create_linear_layers(self, dim_in, fc_args):
        '''Creates a set of fully-connected layers.

        Args:
            dim_in: Number of input units.
            fc_args: List of tuple of fully-connected arguments.

        Returns:
            nn.Sequential: a sequence of fully-connected layers.
            int: the number of output units.

        '''
        fc_layers = nn.Sequential()
        fc_args = fc_args or []

        for i, (dim_out, batch_norm, dropout, nonlinearity) in enumerate(fc_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            fc_block = nn.Sequential()

            if dim_out is not None:
                fc_block.add_module(name + 'fc', nn.Linear(dim_in, dim_out))
            else:
                dim_out = dim_in

            if dropout:
                fc_block.add_module(name + 'do', nn.Dropout(p=dropout))
            if batch_norm:
                bn = nn.BatchNorm1d(dim_out)
                fc_block.add_module(name + 'bn', bn)

            if nonlinearity:
                nonlinearity = get_nonlinearity(nonlinearity)
                fc_block.add_module(nonlinearity.__class__.__name__, nonlinearity)

            fc_layers.add_module(name, fc_block)
            dim_in = dim_out

        return fc_layers, dim_in

    def next_size(self, dim_x, dim_y, k, s, p):
        '''Infers the next size of a convolutional layer.

        Args:
            dim_x: First dimension.
            dim_y: Second dimension.
            k: Kernel size.
            s: Stride.
            p: Padding.

        Returns:
            (int, int): (First output dimension, Second output dimension)

        '''
        if isinstance(k, int):
            kx, ky = (k, k)
        else:
            kx, ky = k

        if isinstance(s, int):
            sx, sy = (s, s)
        else:
            sx, sy = s

        if isinstance(p, int):
            px, py = (p, p)
        else:
            px, py = p

        return (infer_conv_size(dim_x, kx, sx, px),
                infer_conv_size(dim_y, ky, sy, py))

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass.

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            (conv output, fc output): each a torch.Tensor, or a list of
                torch.Tensor if return_full_list is set.

        '''
        if return_full_list:
            conv_out = []
            for conv_layer in self.conv_layers:
                x = conv_layer(x)
                conv_out.append(x)
        else:
            conv_out = self.conv_layers(x)
            x = conv_out

        x = self.reshape(x)

        if return_full_list:
            fc_out = []
            for fc_layer in self.fc_layers:
                x = fc_layer(x)
                fc_out.append(x)
        else:
            fc_out = self.fc_layers(x)

        return conv_out, fc_out
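
# Note on return_full_list (an illustrative sketch, not from the source):
#   conv_outs, fc_outs = net(x, return_full_list=True)
# returns the per-block activations as lists, so intermediate feature maps
# can be read off for downstream objectives.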


class FoldedConvnet(Convnet):
    '''Convnet with strided crop input.

    '''

    def create_layers(self, shape, crop_size=8, conv_args=None, fc_args=None):
        '''Creates layers.

        conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
        fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)

        Args:
            shape: Shape of input.
            crop_size: Size of crops.
            conv_args: List of tuple of convolutional arguments.
            fc_args: List of tuple of fully-connected arguments.

        '''
        self.crop_size = crop_size

        dim_x, dim_y, dim_in = shape
        if dim_x != dim_y:
            raise ValueError('x and y dimensions must be the same to use Folded encoders.')

        # This corresponds to overlapping crops taken at stride crop_size // 2:
        # a dim_x-wide input yields 2 * (dim_x // crop_size) - 1 crops per side.
        self.final_size = 2 * (dim_x // self.crop_size) - 1

        self.unfold = Unfold(dim_x, self.crop_size)
        self.refold = Fold(dim_x, self.crop_size)

        shape = (self.crop_size, self.crop_size, dim_in)

        self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)

        dim_x, dim_y, dim_out = self.conv_shape
        dim_r = dim_x * dim_y * dim_out
        self.reshape = View(-1, dim_r)

        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_conv_layers(self, shape, conv_args):
        '''Creates a set of convolutional layers.

        Args:
            shape: Input shape.
            conv_args: List of tuple of convolutional arguments.

        Returns:
            nn.Sequential: a sequence of convolutional layers.
            tuple: the output shape (dim_x, dim_y, dim_out).

        '''
        conv_layers = nn.Sequential()
        conv_args = conv_args or []

        dim_x, dim_y, dim_in = shape

        for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            conv_block = nn.Sequential()

            if dim_out is not None:
                conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p,
                                 bias=not batch_norm)
                conv_block.add_module(name + 'conv', conv)
                dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
            else:
                dim_out = dim_in

            if dropout:
                conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
            if batch_norm:
                bn = nn.BatchNorm2d(dim_out)
                conv_block.add_module(name + 'bn', bn)

            if nonlinearity:
                nonlinearity = get_nonlinearity(nonlinearity)
                conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity)

            if pool:
                (pool_type, kernel, stride) = pool
                Pool = getattr(nn, pool_type)
                conv_block.add_module(name + 'pool', Pool(kernel_size=kernel, stride=stride))
                dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)

            conv_layers.add_module(name, conv_block)

            dim_in = dim_out

        if dim_x != dim_y:
            raise ValueError('dim_x and dim_y do not match.')

        if dim_x == 1:
            # Once crops have been reduced to 1x1, forward refolds them into a
            # final_size x final_size map, so report that shape here.
            dim_x = self.final_size
            dim_y = self.final_size

        dim_out = dim_in

        return conv_layers, (dim_x, dim_y, dim_out)

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass.

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            (conv output, fc output): each a torch.Tensor, or a list of
                torch.Tensor if return_full_list is set.

        '''
        x = self.unfold(x)

        conv_out = []
        for conv_layer in self.conv_layers:
            x = conv_layer(x)
            if x.size(2) == 1:
                # The crop dimension has collapsed to 1x1: reassemble the
                # per-crop outputs into a spatial map.
                x = self.refold(x)
            conv_out.append(x)

        x = self.reshape(x)

        if return_full_list:
            fc_out = []
            for fc_layer in self.fc_layers:
                x = fc_layer(x)
                fc_out.append(x)
        else:
            fc_out = self.fc_layers(x)

        if not return_full_list:
            conv_out = conv_out[-1]

        return conv_out, fc_out
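

if __name__ == '__main__':
    # Minimal smoke test, a sketch under assumed hyperparameters (not taken
    # from the original experiments; 'ReLU' assumes get_nonlinearity resolves
    # torch.nn activation names). Two strided 4x4 convolutions map a 32x32
    # RGB input down to an 8x8 feature map, followed by one linear layer.
    net = Convnet((32, 32, 3),
                  conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None),
                             (128, 4, 2, 1, True, False, 'ReLU', None)],
                  fc_args=[(64, False, False, None)])
    conv_out, fc_out = net(torch.randn(16, 3, 32, 32))
    print(conv_out.shape)  # expected: torch.Size([16, 128, 8, 8])
    print(fc_out.shape)    # expected: torch.Size([16, 64])

    # FoldedConvnet sketch: 32x32 inputs are unfolded into 8x8 crops, and the
    # conv stack must reduce each crop to 1x1 before it is refolded into a
    # final_size x final_size map (here 2 * (32 // 8) - 1 = 7).
    folded = FoldedConvnet((32, 32, 3), crop_size=8,
                           conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None),
                                      (128, 4, 2, 1, True, False, 'ReLU', None),
                                      (256, 2, 1, 0, True, False, 'ReLU', None)],
                           fc_args=[(64, False, False, None)])
    conv_out, fc_out = folded(torch.randn(16, 3, 32, 32))
    print(conv_out.shape)  # expected: torch.Size([16, 256, 7, 7])
    print(fc_out.shape)    # expected: torch.Size([16, 64])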