# Export metadata (scraped page residue, commented out so the module stays importable):
#   author: k-l-lambda
#   update: export from starry-refactor 2026-02-20 15:25
#   commit: 1958836
"""
UNet model implementation.
Matches the architecture from deep-starry/starry/unet/ for loading .chkpt checkpoints.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Spatial size is preserved (padding=1); channels go
    in_channels -> mid_channels -> out_channels. When mid_channels is
    falsy it defaults to out_channels.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, mid, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Keep the attribute name `double_conv` so checkpoint keys match.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Map (N, in_channels, H, W) -> (N, out_channels, H, W)."""
        return self.double_conv(x)
class Down(nn.Module):
    """Encoder stage: halve spatial resolution (2x2 max-pool), then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Attribute name `maxpool_conv` is part of the checkpoint key layout.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Map (N, in_channels, H, W) -> (N, out_channels, H//2, W//2)."""
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Decoder stage: upsample, pad to match the skip tensor, concat, DoubleConv.

    With bilinear=True the upsample is parameter-free, so DoubleConv uses a
    reduced mid width (in_channels // 2); otherwise a transposed conv halves
    the channels before concatenation.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample x1, pad it to x2's spatial size, then conv over [x2, x1].

        x1: lower-resolution feature map from the previous decoder stage.
        x2: skip connection from the matching encoder stage.
        """
        x1 = self.up(x1)
        # Odd input sizes can leave x1 one pixel short of x2 after pooling
        # and upsampling; pad symmetrically to realign.
        dy = x2.size(2) - x1.size(2)
        dx = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, (dx // 2, dx - dx // 2, dy // 2, dy - dy // 2))
        return self.conv(torch.cat([x2, x1], dim=1))
class OutConv(nn.Module):
    """Final 1x1 convolution projecting features to per-pixel class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Map (N, in_channels, H, W) -> (N, out_channels, H, W)."""
        return self.conv(x)
class UNet(nn.Module):
    """UNet encoder-decoder for dense per-pixel prediction.

    Attribute names (`inc`, `outc`, `downs`, `ups`) mirror the original
    deep-starry layout so `.chkpt` state dicts load unchanged.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output classes/channels.
        classify_out: when False, forward returns decoder features and skips
            the final 1x1 projection.
        bilinear: upsample mode for the decoder (see Up).
        depth: number of down/up stages.
        init_width: channel width of the first DoubleConv; doubles per stage.
    """

    def __init__(self, n_channels, n_classes, classify_out=True, bilinear=True, depth=4, init_width=64):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.classify_out = classify_out
        self.depth = depth

        # Bilinear upsampling adds no channels, so the bottleneck and the
        # decoder outputs are halved to keep concat widths consistent.
        factor = 2 if bilinear else 1
        self.inc = DoubleConv(n_channels, init_width)
        self.outc = OutConv(init_width, n_classes)

        down_stages = []
        for d in range(depth):
            c_in = init_width << d
            c_out = c_in * 2
            if d == depth - 1:
                c_out //= factor
            down_stages.append(Down(c_in, c_out))

        up_stages = []
        for d in range(depth):
            c_in = init_width << (depth - d)
            c_out = c_in // 2
            if d < depth - 1:
                c_out //= factor
            up_stages.append(Up(c_in, c_out, bilinear))

        self.downs = nn.ModuleList(down_stages)
        self.ups = nn.ModuleList(up_stages)

    def forward(self, input):
        """Run the encoder, decoder with skip connections, and optional head."""
        skips = []
        x = self.inc(input)
        for stage in self.downs:
            skips.append(x)
            x = stage(x)
        # Deepest skip first: pair each decoder stage with its encoder twin.
        for stage, skip in zip(self.ups, reversed(skips)):
            x = stage(x, skip)
        return self.outc(x) if self.classify_out else x