""" UNet model implementation. Matches the architecture from deep-starry/starry/unet/ for loading .chkpt checkpoints. """ import torch import torch.nn as nn import torch.nn.functional as F class DoubleConv(nn.Module): """(convolution => [BN] => ReLU) * 2""" def __init__(self, in_channels, out_channels, mid_channels=None): super().__init__() if not mid_channels: mid_channels = out_channels self.double_conv = nn.Sequential( nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ) def forward(self, x): return self.double_conv(x) class Down(nn.Module): """Downscaling with maxpool then double conv""" def __init__(self, in_channels, out_channels): super().__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) ) def forward(self, x): return self.maxpool_conv(x) class Up(nn.Module): """Upscaling then double conv""" def __init__(self, in_channels, out_channels, bilinear=True): super().__init__() if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) self.conv = DoubleConv(in_channels, out_channels) def forward(self, x1, x2): x1 = self.up(x1) diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) x = torch.cat([x2, x1], dim=1) return self.conv(x) class OutConv(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): return self.conv(x) class UNet(nn.Module): def __init__(self, n_channels, n_classes, classify_out=True, bilinear=True, depth=4, init_width=64): super().__init__() self.n_channels = n_channels self.n_classes = n_classes self.classify_out = classify_out self.depth = depth factor = 2 if bilinear else 1 self.inc = DoubleConv(n_channels, init_width) self.outc = OutConv(init_width, n_classes) downs = [] ups = [] for d in range(depth): ic = init_width * (2 ** d) oc = ic * 2 if d == depth - 1: oc //= factor downs.append(Down(ic, oc)) for d in range(depth): ic = init_width * (2 ** (depth - d)) oc = ic // 2 if d < depth - 1: oc //= factor ups.append(Up(ic, oc, bilinear)) self.downs = nn.ModuleList(modules=downs) self.ups = nn.ModuleList(modules=ups) def forward(self, input): xs = [] x = self.inc(input) for down in self.downs: xs.append(x) x = down(x) xs.reverse() for i, up in enumerate(self.ups): xi = xs[i] x = up(x, xi) if not self.classify_out: return x logits = self.outc(x) return logits