import torch
import torch.nn as nn
import math
import torch.nn.functional as F

from config import (ds1_in_channels, ds1_out_channels, ds2_in_channels, ds2_out_channels,
                    ds3_in_channels, ds3_out_channels, ds4_in_channels, ds4_out_channels,
                    es1_in_channels, es1_out_channels, es2_in_channels,
                    es2_out_channels, es3_in_channels, es3_out_channels,
                    es4_in_channels, es4_out_channels, n_groupnorm_groups, shift_size,
                    timestep_embed_dim, initial_conv_out_channels, num_heads, window_size,
                    in_channels, dropout, mlp_ratio, swin_embed_dim, use_scale_shift_norm,
                    attention_resolutions, image_size)
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Truncated normal initialization."""
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor

def sinusoidal_embedding(timesteps, dim=timestep_embed_dim):
    """
    timesteps: (B,) int64 tensor
    dim: embedding dimension
    returns: (B, dim) tensor

    Analogous to the positional encodings used in Transformers.
    """
    device = timesteps.device
    half = dim // 2
    freq = torch.exp(-math.log(10000) * torch.arange(half, device=device) / (half - 1))
    args = timesteps[:, None] * freq[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, dim)
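
# Shape sketch (illustrative; assumes timestep_embed_dim is even):
#   t = torch.randint(0, 1000, (4,))         # (4,)
#   emb = sinusoidal_embedding(t, dim=128)   # (4, 128): 64 sine + 64 cosine features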

class TimeEmbeddingMLP(nn.Module):
    def __init__(self, emb_dim, out_channels):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels)
        )

    def forward(self, t_emb):
        return self.mlp(t_emb)  # (B, out_channels)

class InitialConv(nn.Module):
    '''
    Input : the input image concatenated with the LR image (6 channels total)
    Output: sent to encoder stage 1
    '''
    def __init__(self, input_channels=None):
        '''
        Input shape  --> [256 x 256 x input_channels]
        Output shape --> [256 x 256 x initial_conv_out_channels]
        '''
        super().__init__()
        if input_channels is None:
            input_channels = in_channels
        self.net = nn.Conv2d(in_channels=input_channels, out_channels=initial_conv_out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.net(x)

class ResidualBlock(nn.Module):
    '''
    Inside the residual block the spatial resolution stays the same; channels may change
    from in_channels to out_channels (a 1x1 conv on the skip handles the mismatch).
    Input : from the previous encoder stage / initial conv
    Output: to the downsampling block; the result is also saved as the skip connection
            for the corresponding decoder stage
    '''
    def __init__(self, in_channels, out_channels, sin_embed_dim=timestep_embed_dim, dropout_rate=dropout, use_scale_shift=use_scale_shift_norm):
        '''
        This ResBlock is used with in_channels from [64, 128, 256, 512]
        and out_channels from [64, 128, 256, 512].
        '''
        super().__init__()
        self.use_scale_shift = use_scale_shift

        ## 1st res block (in_layers)
        self.norm1 = nn.GroupNorm(num_groups=n_groupnorm_groups, num_channels=in_channels)  ## 8 groups is a common default
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

        ## timestep embedding MLP
        # If use_scale_shift_norm, output 2*out_channels (for scale and shift);
        # otherwise output out_channels (for additive conditioning)
        embed_out_dim = 2 * out_channels if use_scale_shift else out_channels
        self.MLP_embed = TimeEmbeddingMLP(sin_embed_dim, out_channels=embed_out_dim)

        ## 2nd res block (out_layers)
        self.norm2 = nn.GroupNorm(num_groups=n_groupnorm_groups, num_channels=out_channels)
        self.act2 = nn.SiLU()
        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate > 0 else nn.Identity()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

        ## skip connection
        self.skip = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, t_emb):  ## t_emb is the pre-computed time embedding (B, timestep_embed_dim)
        # in_layers: norm -> SiLU -> conv
        h = self.conv1(self.act1(self.norm1(x)))

        # Time embedding conditioning
        emb_out = self.MLP_embed(t_emb)  # (B, embed_out_dim)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None, None]  # (B, embed_out_dim, 1, 1)

        if self.use_scale_shift:
            # FiLM conditioning: h = norm(h) * (1 + scale) + shift
            scale, shift = torch.chunk(emb_out, 2, dim=1)  # each (B, out_channels, 1, 1)
            h = self.norm2(h) * (1 + scale) + shift
            h = self.act2(h)
            h = self.dropout(h)
            h = self.conv2(h)
        else:
            # Additive conditioning: h = h + emb_out
            h = h + emb_out
            h = self.conv2(self.dropout(self.act2(self.norm2(h))))

        return h + self.skip(x)
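
# Conditioning sketch (illustrative): with use_scale_shift_norm the time MLP emits
# 2*out_channels values per sample, split into (scale, shift) and applied as
#   h = GroupNorm(h) * (1 + scale) + shift      # FiLM-style modulation
# otherwise the out_channels-dim embedding is simply broadcast-added to h.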

class Downsample(nn.Module):
    '''
    A downsampling layer using strided convolution.
    Reduces spatial resolution by half (stride=2) while keeping channels the same.
    Note: channel changes happen in ResBlocks, not in this downsample layer.
    This matches the original ResShift implementation when conv_resample=True.
    Input : from each encoder stage
    Output: to the next encoder stage (same channels, half resolution)
    '''
    def __init__(self, in_channels, out_channels):
        '''
        Args:
            in_channels: input channel count
            out_channels: output channel count (should equal in_channels in our usage)
        '''
        super().__init__()
        # Strided convolution: 3x3 conv with stride=2, padding=1
        # This halves the spatial resolution (e.g., 64x64 -> 32x32)
        self.net = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.net(x)
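
# Output-size sketch (standard conv arithmetic): with kernel 3, stride 2, padding 1,
#   H_out = floor((H + 2*1 - 3) / 2) + 1 = floor((H - 1) / 2) + 1,
# so even inputs halve exactly (64 -> 32, 32 -> 16, 16 -> 8).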

class EncoderStage(nn.Module):
    '''
    Combines ResBlocks and a Downsample.
    With y = channels and x = resolution:
    Input : [y, x, x]
    Output: [2y, x/2, x/2]
    (the channel change is done by the ResBlocks; Downsample only halves the resolution)
    '''
    def __init__(self, in_channels, out_channels, downsample=True, resolution=None, use_attention=False):
        super().__init__()
        self.res1 = ResidualBlock(in_channels=in_channels, out_channels=out_channels)

        # Add attention after the first res block if the resolution matches and use_attention is True
        self.attention = None
        if use_attention and resolution in attention_resolutions:
            # BasicLayer equivalent: 2 SwinTransformerBlocks (one with shift=0, one with shift=window_size//2)
            self.attention = nn.Sequential(
                SwinTransformerBlock(in_channels=out_channels, num_heads=num_heads, shift_size=0,
                                     embed_dim=swin_embed_dim, mlp_ratio_val=mlp_ratio),
                SwinTransformerBlock(in_channels=out_channels, num_heads=num_heads, shift_size=window_size // 2,
                                     embed_dim=swin_embed_dim, mlp_ratio_val=mlp_ratio)
            )

        self.res2 = ResidualBlock(in_channels=out_channels, out_channels=out_channels)

        # downsample=False is used for the last encoder stage (stage 4).
        # Downsample only reduces spatial resolution and keeps channels the same;
        # channel changes happen in the ResBlocks.
        self.downsample = Downsample(out_channels, out_channels) if downsample else nn.Identity()

    def forward(self, x, t_emb):
        out = self.res1(x, t_emb)  # out is h + skip(x)
        # Apply attention if present (attention does not use t_emb)
        if self.attention is not None:
            out = self.attention(out)
        out_skipconnection = self.res2(out, t_emb)
        # print(f'The shape after Encoder Stage before downsampling is {out_skipconnection.shape}')
        out_downsampled = self.downsample(out_skipconnection)
        # print(f'The shape after Encoder Stage after downsampling is {out_downsampled.shape}')
        return out_downsampled, out_skipconnection

class FullEncoderModule(nn.Module):
    '''
    Connects all 4 encoder stages (for now).
    '''
    def __init__(self, input_channels=None):
        '''
        Passes the input through the encoder stages one by one.
        Args:
            input_channels: number of input channels (default: in_channels from config)
        '''
        super().__init__()
        if input_channels is None:
            input_channels = in_channels
        self.initial_conv = InitialConv(input_channels=input_channels)

        # Add attention after the initial conv if image_size is in attention_resolutions
        self.attention_initial = None
        if image_size in attention_resolutions:
            self.attention_initial = nn.Sequential(
                SwinTransformerBlock(in_channels=initial_conv_out_channels, num_heads=num_heads, shift_size=0,
                                     embed_dim=swin_embed_dim, mlp_ratio_val=mlp_ratio),
                SwinTransformerBlock(in_channels=initial_conv_out_channels, num_heads=num_heads, shift_size=window_size // 2,
                                     embed_dim=swin_embed_dim, mlp_ratio_val=mlp_ratio)
            )

        # Track resolutions (for image_size=64): after initial_conv=64, after stage1=32, after stage2=16, after stage3=8, after stage4=8
        self.encoderstage_1 = EncoderStage(es1_in_channels, es1_out_channels, downsample=True, resolution=image_size, use_attention=True)
        self.encoderstage_2 = EncoderStage(es2_in_channels, es2_out_channels, downsample=True, resolution=image_size // 2, use_attention=True)
        self.encoderstage_3 = EncoderStage(es3_in_channels, es3_out_channels, downsample=True, resolution=image_size // 4, use_attention=True)
        self.encoderstage_4 = EncoderStage(es4_in_channels, es4_out_channels, downsample=False, resolution=image_size // 8, use_attention=True)

    def forward(self, x, t_emb):
        out = self.initial_conv(x)
        # Apply attention after the initial conv if present
        if self.attention_initial is not None:
            out = self.attention_initial(out)
        out_1, skip_1 = self.encoderstage_1(out, t_emb)
        #print(f'The shape after Encoder Stage 1 after downsampling is {out_1.shape}')
        out_2, skip_2 = self.encoderstage_2(out_1, t_emb)
        #print(f'The shape after Encoder Stage 2 after downsampling is {out_2.shape}')
        out_3, skip_3 = self.encoderstage_3(out_2, t_emb)
        #print(f'The shape after Encoder Stage 3 after downsampling is {out_3.shape}')
        out_4, skip_4 = self.encoderstage_4(out_3, t_emb)
        #print(f'The shape after Encoder Stage 4 is {out_4.shape}')
        # Return these so the corresponding decoder stages can use the skips
        return (out_1, skip_1), (out_2, skip_2), (out_3, skip_3), (out_4, skip_4)
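
# Resolution flow sketch (illustrative; assumes image_size=64, channel counts come from config):
#   input (B, 6, 64, 64) -> initial_conv -> stage1 -> (B, es1_out, 32, 32)
#   -> stage2 -> (B, es2_out, 16, 16) -> stage3 -> (B, es3_out, 8, 8)
#   -> stage4 (no downsample) -> (B, es4_out, 8, 8); each stage also returns a full-resolution skip.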

class WindowAttention(nn.Module):
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    Supports both shifted and non-shifted windows.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size if isinstance(window_size, (tuple, list)) else (window_size, window_size)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Relative position bias table
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
        )

        # Get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Initialize relative position bias
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B_, num_heads, N, C // num_heads)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1).contiguous())

        # Add relative position bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0).to(attn.dtype)

        # Apply mask if provided
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).contiguous().reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
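
# Bias-table sketch (illustrative): for a 7x7 window there are (2*7-1)*(2*7-1) = 169 possible
# relative offsets, so relative_position_bias_table has shape (169, num_heads) and
# relative_position_index maps each of the 49*49 token pairs into that table.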

class SwinTransformerBlock(nn.Module):
    def __init__(self, in_channels, num_heads=num_heads, shift_size=0, embed_dim=None, mlp_ratio_val=None):
        '''
        When a feature map (e.g. 512 x 32 x 32) comes in, it is split into window_size x window_size
        patches (with a 7x7 window the 32x32 map is padded to 35x35, giving 25 windows).
        Each window is flattened to (49 x 512): think of it as 49 tokens, each represented by a
        512-dim vector, just as a feature map represents each pixel as a C-dim embedding
        (a 3 x 4 x 4 map represents each of its 16 pixels with a 3-dim vector).
        '''
        super().__init__()
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads  # Store num_heads for mask generation

        # Use embed_dim if provided, otherwise fall back to swin_embed_dim from config
        self.embed_dim = embed_dim if embed_dim is not None else swin_embed_dim
        self.mlp_ratio = mlp_ratio_val if mlp_ratio_val is not None else mlp_ratio

        # Projection layers if embed_dim differs from in_channels
        if self.embed_dim != in_channels:
            self.proj_in = nn.Conv2d(in_channels, self.embed_dim, kernel_size=1)
            self.proj_out = nn.Conv2d(self.embed_dim, in_channels, kernel_size=1)
        else:
            self.proj_in = nn.Identity()
            self.proj_out = nn.Identity()

        # Custom WindowAttention with relative position bias
        self.attn = WindowAttention(
            dim=self.embed_dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_scale=None,
            attn_drop=0.,
            proj_drop=0.
        )

        self.mlp = nn.Sequential(
            nn.Linear(self.embed_dim, int(self.embed_dim * self.mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(self.embed_dim * self.mlp_ratio), self.embed_dim)
        )
        self.norm1 = nn.LayerNorm(self.embed_dim)
        self.norm2 = nn.LayerNorm(self.embed_dim)

        # Attention mask for shifted windows (computed in forward based on the input size)
        if self.shift_size > 0:
            self.register_buffer("attn_mask", None, persistent=False)
        else:
            self.attn_mask = None

    def get_windowed_tokens(self, x):
        '''
        Partition the feature map into non-overlapping windows of window_size x window_size
        tokens, each token represented by a C-dim vector.
        '''
        B, C, H, W = x.size()
        ws = self.window_size
        # move channel to last dim to make reshaping intuitive
        x = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)
        # reshape into blocks: (B, H//ws, ws, W//ws, ws, C)
        x = x.view(B, H // ws, ws, W // ws, ws, C)
        # reorder to (B, num_h, num_w, ws, ws, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()  # (B, Nh, Nw, ws, ws, C)
        # merge windows: (B * Nh * Nw, ws * ws, C)
        windows_tokens = x.view(-1, ws * ws, C)
        return windows_tokens
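
    # Partition sketch (illustrative): a (B, C, 35, 35) map with window_size=7 yields
    # Nh = Nw = 5 windows per side, i.e. (B*25, 49, C) window tokens;
    # window_reverse below inverts this mapping.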
    def window_reverse(self, windows, H, W, B):
        """Merge windows back to a feature map."""
        ws = self.window_size
        num_windows_h = H // ws
        num_windows_w = W // ws
        x = windows.view(B, num_windows_h, num_windows_w, ws, ws, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
        return x.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W)

    def calculate_mask(self, H, W, device):
        """Calculate attention mask for SW-MSA."""
        if self.shift_size == 0:
            return None
        img_mask = torch.zeros((1, H, W, 1), device=device)
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        # Convert to (B, C, H, W) format for window partitioning
        img_mask = img_mask.permute(0, 3, 1, 2).contiguous()  # (1, 1, H, W)
        mask_windows = self.get_windowed_tokens(img_mask)  # (num_windows, ws*ws, 1)
        mask_windows = mask_windows.squeeze(-1)  # (num_windows, ws*ws)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        # shape: (num_windows, ws*ws, ws*ws)
        return attn_mask
    def forward(self, x):
        '''
        There are two types of Swin blocks:
        1. windowed Swin block
        2. shifted-window Swin block
        In our code we use them one after the other. The first computes local attention without
        shifting; the second shifts first, then computes local attention, then shifts back.
        '''
        B, C, H, W = x.size()

        # Project to embed_dim if needed
        x = self.proj_in(x)  # (B, embed_dim, H, W)
        C_emb = x.shape[1]

        # Save the shortcut AFTER projection (in embed_dim space for the residual)
        shortcut = x

        # Pad if needed (with a 7x7 window a 32x32 map is padded to 35x35)
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (0, pad_r, 0, pad_b))
            shortcut = F.pad(shortcut, (0, pad_r, 0, pad_b))
        H_pad, W_pad = x.shape[2], x.shape[3]

        # Normalize BEFORE windowing (original behavior)
        # Convert to (B, H, W, C) for LayerNorm
        x_norm = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)
        x_norm = self.norm1(x_norm)  # Normalize spatial features
        x_norm = x_norm.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W)

        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x_norm, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
        else:
            shifted_x = x_norm

        # Partition windows
        x_windows = self.get_windowed_tokens(shifted_x)  # (num_windows*B, ws*ws, C)

        # Calculate mask for shifted windows
        if self.shift_size > 0:
            mask = self.calculate_mask(H_pad, W_pad, x.device)
        else:
            mask = None

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=mask)  # (num_windows*B, ws*ws, C)

        # Merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C_emb)
        shifted_x = self.window_reverse(attn_windows, H_pad, W_pad, B)  # (B, C_emb, H_pad, W_pad)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(2, 3))
        else:
            x = shifted_x

        # Residual connection around attention (original: shortcut + drop_path(x));
        # both tensors are still padded here, padding is cropped after the final projection
        x = shortcut + x  # add in embed_dim space

        # FFN
        # Convert to (B, H, W, C) for LayerNorm
        x_norm2 = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)
        x_norm2 = self.norm2(x_norm2)  # (B, H, W, C)
        x_mlp = self.mlp(x_norm2)  # (B, H, W, C)
        x_mlp = x_mlp.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W)

        # Residual connection around MLP
        x = x + x_mlp

        # Project back to in_channels if needed
        if self.embed_dim != C:
            x = self.proj_out(x)  # (B, in_channels, H, W)

        # Crop the padding back to the original size
        x = x[:, :, :H, :W]
        return x

class Bottleneck(nn.Module):
    def __init__(self, in_channels=es4_out_channels, out_channels=ds1_in_channels):
        super().__init__()
        self.res1 = ResidualBlock(in_channels=in_channels, out_channels=out_channels)
        # Use swin_embed_dim from config for the projection inside the Swin blocks
        self.swintransformer1 = SwinTransformerBlock(in_channels=out_channels, num_heads=num_heads, shift_size=0, embed_dim=swin_embed_dim, mlp_ratio_val=mlp_ratio)
        self.swintransformer2 = SwinTransformerBlock(in_channels=out_channels, num_heads=num_heads, shift_size=window_size // 2, embed_dim=swin_embed_dim, mlp_ratio_val=mlp_ratio)
        self.res2 = ResidualBlock(in_channels=out_channels, out_channels=out_channels)

    def forward(self, x, t_emb):
        res_out = self.res1(x, t_emb)
        swin_out_1 = self.swintransformer1(res_out)
        # print(f'swin_out_1 shape is {swin_out_1.shape}')
        swin_out_2 = self.swintransformer2(swin_out_1)
        # print(f'swin_out_2 shape is {swin_out_2.shape}')
        res_out_2 = self.res2(swin_out_2, t_emb)
        return res_out_2

class Upsample(nn.Module):
    '''
    Increases the spatial resolution.
    Input : from each decoder stage
    Output: to the next decoder stage
    '''
    def __init__(self, in_channels, out_channels):
        '''
        Doubles the spatial resolution (nearest-neighbor upsampling followed by a 3x3 conv);
        the channel count is set by out_channels.
        '''
        super().__init__()
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        return self.net(x)
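
# Shape sketch (illustrative): (B, C_in, 8, 8) -> nearest 2x -> (B, C_in, 16, 16)
# -> 3x3 conv (stride 1, padding 1) -> (B, C_out, 16, 16).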

class DecoderStage(nn.Module):
    """
    Decoder block:
    - Optional upsample
    - Concatenate the skip (channels grow by skip_channels)
    - Two residual blocks
    """
    def __init__(self, in_channels, skip_channels, out_channels, upsample=True, resolution=None, use_attention=False):
        super().__init__()
        # Upsample first, but keep the same number of channels
        self.upsample = Upsample(in_channels, in_channels) if upsample else nn.Identity()

        merged_channels = in_channels + skip_channels

        # First ResBlock processes the merged tensor
        self.res1 = ResidualBlock(in_channels=merged_channels, out_channels=out_channels)

        # Add attention after the first res block if the resolution matches and use_attention is True
        self.attention = None
        if use_attention and resolution in attention_resolutions:
            # BasicLayer equivalent: 2 SwinTransformerBlocks (one with shift=0, one with shift=window_size//2)
            self.attention = nn.Sequential(
                SwinTransformerBlock(in_channels=out_channels, num_heads=num_heads, shift_size=0,
                                     embed_dim=swin_embed_dim, mlp_ratio_val=mlp_ratio),
                SwinTransformerBlock(in_channels=out_channels, num_heads=num_heads, shift_size=window_size // 2,
                                     embed_dim=swin_embed_dim, mlp_ratio_val=mlp_ratio)
            )

        # Second ResBlock keeps the output channels the same
        self.res2 = ResidualBlock(in_channels=out_channels, out_channels=out_channels)

    def forward(self, x, skip, t_emb):
        """
        x    : (B, C, H, W) decoder input
        skip : (B, C_skip, H, W) encoder skip feature
        t_emb: (B, timestep_embed_dim) pre-computed time embedding
        """
        x = self.upsample(x)             # optional upsample
        x = torch.cat([x, skip], dim=1)  # concat along channels
        x = self.res1(x, t_emb)
        # Apply attention if present (attention does not use t_emb)
        if self.attention is not None:
            x = self.attention(x)
        x = self.res2(x, t_emb)
        return x

class FullDecoderModule(nn.Module):
    '''
    Connects all 4 decoder stages (for now).
    '''
    def __init__(self):
        '''
        Passes the features through the decoder stages one by one.
        '''
        super().__init__()
        # Track resolutions (for image_size=64): after bottleneck=8, after stage1=8, after stage2=16, after stage3=32, after stage4=64
        self.decoderstage_1 = DecoderStage(in_channels=ds1_in_channels, skip_channels=es4_out_channels, out_channels=ds1_out_channels, upsample=False, resolution=image_size // 8, use_attention=True)
        self.decoderstage_2 = DecoderStage(in_channels=ds2_in_channels, skip_channels=es3_out_channels, out_channels=ds2_out_channels, upsample=True, resolution=image_size // 4, use_attention=True)  # input channels include the skip connection
        self.decoderstage_3 = DecoderStage(in_channels=ds3_in_channels, skip_channels=es2_out_channels, out_channels=ds3_out_channels, upsample=True, resolution=image_size // 2, use_attention=True)  # input channels include the skip connection
        self.decoderstage_4 = DecoderStage(in_channels=ds4_in_channels, skip_channels=es1_out_channels, out_channels=ds4_out_channels, upsample=True, resolution=image_size, use_attention=True)  # input channels include the skip connection

        # Normalization and activation before the final conv, matching the original
        self.final_norm = nn.GroupNorm(num_groups=n_groupnorm_groups, num_channels=ds4_out_channels)
        self.final_act = nn.SiLU()
        self.finalconv = nn.Conv2d(in_channels=ds4_out_channels, out_channels=3, kernel_size=3, stride=1, padding=1)

    def forward(self, bottleneck_output, encoder_outputs, t_emb):
        # Unpack encoder outputs
        (out_1_enc, skip_1), (out_2_enc, skip_2), (out_3_enc, skip_3), (out_4_enc, skip_4) = encoder_outputs

        # Decoder stages, passing skip connections
        out_1_dec = self.decoderstage_1(bottleneck_output, skip_4, t_emb)  # first decoder stage uses the bottleneck output and the last encoder skip
        #print(f'The shape after Decoder Stage 1 is {out_1_dec.shape}')
        out_2_dec = self.decoderstage_2(out_1_dec, skip_3, t_emb)  # subsequent stages use the previous decoder output and the corresponding encoder skip
        #print(f'The shape after Decoder Stage 2 after upsampling is {out_2_dec.shape}')
        out_3_dec = self.decoderstage_3(out_2_dec, skip_2, t_emb)
        #print(f'The shape after Decoder Stage 3 after upsampling is {out_3_dec.shape}')
        out_4_dec = self.decoderstage_4(out_3_dec, skip_1, t_emb)
        #print(f'The shape after Decoder Stage 4 after upsampling is {out_4_dec.shape}')

        # Apply normalization and activation before the final conv (matching the original)
        final_out = self.finalconv(self.final_act(self.final_norm(out_4_dec)))
        #print(f'The shape after final conv is {final_out.shape}')
        return final_out
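
# Resolution flow sketch (illustrative; assumes image_size=64): bottleneck (8x8) -> stage1 (8x8, skip_4)
# -> stage2 (16x16, skip_3) -> stage3 (32x32, skip_2) -> stage4 (64x64, skip_1) -> final 3-channel conv.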

class FullUNET(nn.Module):
    def __init__(self):
        """
        Full U-Net model with required LR conditioning.
        Concatenates the LR image directly with the input (assumes the same resolution).
        """
        super().__init__()
        # Input channels = original input (3) + LR image channels (3) = 6
        input_channels = in_channels + in_channels  # 3 + 3 = 6
        self.enc = FullEncoderModule(input_channels=input_channels)
        self.bottleneck = Bottleneck()
        self.dec = FullDecoderModule()

    def forward(self, x, t, lq):
        """
        Forward pass with required LR conditioning.
        Args:
            x : (B, C, H, W) input tensor
            t : (B,) timestep tensor
            lq: (B, C_lq, H_lq, W_lq) LR image for conditioning (required, same resolution as x)
        Returns:
            out: (B, out_channels, H, W) output tensor
        """
        # Compute the time embedding once for efficiency
        t_emb = sinusoidal_embedding(t)  # (B, timestep_embed_dim)

        # Concatenate the LR image directly with the input along the channel dimension
        # (assumes lq has the same spatial dimensions as x)
        x = torch.cat([x, lq], dim=1)

        encoder_outputs = self.enc(x, t_emb)  # with pre-computed time embedding
        (out_1_enc, skip_1), (out_2_enc, skip_2), (out_3_enc, skip_3), (out_4_enc, skip_4) = encoder_outputs
        bottle_neck_output = self.bottleneck(out_4_enc, t_emb)
        out = self.dec(bottle_neck_output, encoder_outputs, t_emb)
        return out
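

# Minimal smoke-test sketch (illustrative; the actual values of in_channels, image_size, etc.
# come from config, so this assumes in_channels=3 and an image_size divisible by 8):
if __name__ == "__main__":
    model = FullUNET()
    x = torch.randn(2, in_channels, image_size, image_size)    # noisy input
    lq = torch.randn(2, in_channels, image_size, image_size)   # LR conditioning image (same resolution)
    t = torch.randint(0, 1000, (2,))                           # diffusion timesteps
    out = model(x, t, lq)
    print(out.shape)  # expected: (2, 3, image_size, image_size)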