# XYScanNet_Demo / models / XYScanNetP.py
# HanzhouLiu
# Add application file
# b56342d
import numbers
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
def to_3d(x):
    """Flatten a ``(b, c, h, w)`` feature map into ``(b, h*w, c)`` tokens."""
    flatten_pattern = 'b c h w -> b (h w) c'
    return rearrange(x, flatten_pattern)
def to_4d(x, h, w):
    """Fold ``(b, h*w, c)`` tokens back into a ``(b, c, h, w)`` feature map.

    ``h`` and ``w`` give the spatial size used to split the token axis.
    """
    unflatten_pattern = 'b (h w) c -> b c h w'
    return rearrange(x, unflatten_pattern, h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
    """Scale-only normalization over the last dimension.

    The input is divided by the square root of its (biased) variance along
    the last axis and multiplied by a learnable per-feature weight. The mean
    is not subtracted from the input and no bias term is added.
    """

    def __init__(self, normalized_shape):
        """Create the layer for a 1-D feature shape (an int or a 1-tuple)."""
        super().__init__()
        # A bare integer is shorthand for a one-element shape.
        shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, numbers.Integral)
            else normalized_shape
        )
        shape = torch.Size(shape)
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        """Return ``x / sqrt(var(x) + 1e-5) * weight`` along the last axis."""
        variance = x.var(-1, keepdim=True, unbiased=False)
        denom = torch.sqrt(variance + 1e-5)
        return x / denom * self.weight
class WithBias_LayerNorm(nn.Module):
    """Layer normalization over the last dimension with weight and bias.

    The input is centred by its mean, divided by the square root of its
    (biased) variance along the last axis, then scaled and shifted by
    learnable per-feature parameters.
    """

    def __init__(self, normalized_shape):
        """Create the layer for a 1-D feature shape (an int or a 1-tuple)."""
        super().__init__()
        # A bare integer is shorthand for a one-element shape.
        shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, numbers.Integral)
            else normalized_shape
        )
        shape = torch.Size(shape)
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var + 1e-5) * weight + bias``."""
        mean = x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        centred = x - mean
        return centred / torch.sqrt(variance + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
    """Channel-wise layer norm for ``(b, c, h, w)`` feature maps.

    The map is flattened to ``(b, h*w, c)`` tokens, normalized over the
    channel axis by the selected 1-D norm, and folded back to 4-D.
    """

    def __init__(self, dim, LayerNorm_type):
        """Pick the bias-free variant for ``'BiasFree'``, else the biased one."""
        super().__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        """Normalize ``x`` over channels, returning a tensor of the same shape."""
        height, width = x.shape[-2:]
        tokens = to_3d(x)
        normed = self.body(tokens)
        return to_4d(normed, height, width)
##########################################################################
def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
    """Build an ``nn.Conv2d`` whose padding is ``kernel_size // 2``."""
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )
"""
Borrowed from "https://github.com/state-spaces/mamba.git"
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
"""
class Mamba(nn.Module):
    """Selective state-space (Mamba) sequence layer.

    Takes a ``(B, L, D)`` sequence and returns a tensor of the same shape.
    The input is projected to two ``d_inner``-wide streams ``x`` and ``z``;
    ``x`` goes through a depth-wise causal 1-D convolution and a selective
    scan whose step size, B and C are computed from the input, and ``z`` is
    handed to the scan as its gate before the output projection.

    Args:
        d_model: width ``D`` of the input and output features.
        d_state: size of the state dimension per inner channel.
        d_conv: kernel width of the depth-wise causal convolution.
        expand: ``d_inner = int(expand * d_model)``.
        dt_rank: rank of the step-size projection; ``"auto"`` means
            ``ceil(d_model / 16)``.
        dt_min, dt_max: range used to sample the initial step sizes.
        dt_init: ``"random"`` (uniform) or ``"constant"`` initialization of
            the ``dt_proj`` weight.
        dt_scale: multiplier on the ``dt_proj`` weight init magnitude.
        dt_init_floor: lower clamp on the sampled initial step sizes.
        conv_bias: whether the depth-wise convolution has a bias.
        bias: whether the input/output linear projections have a bias.
        use_fast_path: use the fused ``mamba_inner_fn`` kernel when possible.
        layer_idx: key of this layer in the inference cache; required only
            when ``inference_params`` is passed to ``forward``.
        device, dtype: factory arguments for the created parameters.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        # One projection produces both the scan input x and the gate z.
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        # Depth-wise conv; the left padding of d_conv - 1 is trimmed back to
        # seqlen in forward so the convolution stays causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        # Produces the low-rank dt input plus the input-dependent B and C.
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, hidden_states, inference_params=None):
        """Run the layer over a full sequence.

        Args:
            hidden_states: tensor of shape ``(B, L, D)``.
            inference_params: optional cache object exposing
                ``key_value_memory_dict`` and ``seqlen_offset``. When its
                ``seqlen_offset`` is positive the call is delegated to
                :meth:`step`.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        """Advance the layer by one token, updating both states in place.

        Args:
            hidden_states: tensor of shape ``(B, 1, D)``.
            conv_state: rolling convolution window, ``(B, d_inner, d_conv)``.
            ssm_state: scan state, ``(B, d_inner, d_state)``.

        Returns:
            ``(out, conv_state, ssm_state)`` where ``out`` is ``(B, 1, D)``.
        """
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return zero-filled ``(conv_state, ssm_state)`` for ``batch_size``.

        ``max_seqlen`` and extra keyword arguments are accepted but not used.
        When ``dtype`` is ``None`` each state takes the dtype of the matching
        layer weight.
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        # Use d_inner so the state width always matches the layers and is an
        # int even when ``expand`` is a float.
        conv_state = torch.zeros(
            batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_inner, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """Fetch this layer's states from the cache, creating them if absent.

        States are stored in ``inference_params.key_value_memory_dict`` under
        ``self.layer_idx``. Cached states are zeroed in place when
        ``initialize_states`` is true.
        """
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            # Same sizing rule as allocate_inference_cache: always d_inner.
            conv_state = torch.zeros(
                batch_size,
                self.d_inner,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_inner,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
##########################################################################
## Feed-forward Network
class FFN(nn.Module):
    """Gated feed-forward block with a learnable spatial modulator.

    The input is expanded by a 1x1 conv to ``2 * hidden_features`` channels
    and passed through a 3x3 depth-wise conv. The map is then split into a
    ``win_size x win_size`` grid of cells, and every cell is scaled by its
    own per-channel entry of ``modulator``. The two channel halves are
    multiplied together and projected back to ``dim`` channels.

    Args:
        dim: number of input and output channels.
        ffn_expansion_factor: ``hidden_features = int(dim * factor)``.
        bias: whether the convolutions use a bias.
        win_size: number of modulation cells along each spatial axis.
            The input height and width must be divisible by it.
    """

    def __init__(self, dim, ffn_expansion_factor, bias, win_size=8):
        super().__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
        self.win_size = win_size
        # The modulator multiplies the dwconv output, so its channel axis
        # must be hidden_features*2; dim*2 only matches when the factor is 1.
        self.modulator = nn.Parameter(torch.ones(self.win_size, self.win_size, hidden_features*2))  # modulator
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply the block to a ``(b, dim, h, w)`` map; the shape is kept."""
        h, w = x.shape[-2:]
        h1, w1 = h//self.win_size, w//self.win_size
        x = self.project_in(x)
        x = self.dwconv(x)
        # Move channels last so the (win, win, c) modulator broadcasts per cell.
        x_win = rearrange(x, 'b c (wsh h1) (wsw w1) -> b h1 w1 wsh wsw c', wsh=self.win_size, wsw=self.win_size)
        x_win = x_win * self.modulator
        x = rearrange(x_win, 'b h1 w1 wsh wsw c -> b c (wsh h1) (wsw w1)', wsh=self.win_size, wsw=self.win_size, h1=h1, w1=w1)
        x1, x2 = x.chunk(2, dim=1)
        x = x1 * x2
        x = self.project_out(x)
        return x
##########################################################################
## Gated Depth-wise Feed-forward Network (GDFN)
class GDFN(nn.Module):
    """Gated depth-wise feed-forward network.

    A 1x1 conv expands the input to ``2 * hidden`` channels, a 3x3
    depth-wise conv mixes each channel spatially, and the result is split in
    two halves: ``silu(first) * second`` is projected back to ``dim``.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        """Build the block with ``hidden = int(dim * ffn_expansion_factor)``."""
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        doubled = hidden * 2
        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(
            doubled,
            doubled,
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1,
            groups=doubled,
            bias=bias,
        )
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply the gated FFN to a ``(b, dim, h, w)`` map; the shape is kept."""
        expanded = self.dwconv(self.project_in(x))
        gate, value = expanded.chunk(2, dim=1)
        return self.project_out(F.silu(gate) * value)
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    """Embed an image with one 3x3, stride-1, padding-1 convolution."""

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        """Create the ``in_c -> embed_dim`` projection."""
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels=in_c,
            out_channels=embed_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        """Map ``(b, in_c, h, w)`` to ``(b, embed_dim, h, w)``."""
        return self.proj(x)
##########################################################################
## Resizing modules
class Downsample(nn.Module):
    """Halve the spatial size (bilinear) and double the channel count."""

    def __init__(self, n_feat):
        """Build a 0.5x bilinear resize followed by a bias-free 3x3 conv."""
        super().__init__()
        shrink = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
        widen = nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(shrink, widen)

    def forward(self, x):
        """Map ``(b, n_feat, h, w)`` to ``(b, 2*n_feat, h/2, w/2)``."""
        out = self.body(x)
        return out
class Upsample(nn.Module):
    """Double the spatial size (bilinear) and halve the channel count."""

    def __init__(self, n_feat):
        """Build a 2x bilinear resize followed by a bias-free 3x3 conv."""
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        narrow = nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(enlarge, narrow)

    def forward(self, x):
        """Map ``(b, n_feat, h, w)`` to ``(b, n_feat//2, 2*h, 2*w)``."""
        out = self.body(x)
        return out
"""
Borrowed from "https://github.com/pp00704831/Stripformer-ECCV-2022-.git"
@inproceedings{Tsai2022Stripformer,
author = {Fu-Jen Tsai and Yan-Tsung Peng and Yen-Yu Lin and Chung-Chi Tsai and Chia-Wen Lin},
title = {Stripformer: Strip Transformer for Fast Image Deblurring},
booktitle = {ECCV},
year = {2022}
}
"""
class Intra_VSSM(nn.Module):
    """Gated block that scans inside every row and every column with Mamba.

    After a 1x1 expansion and a 3x3 depth-wise conv, one half of the
    channels is the scanned branch and the other half is a SiLU gate. The
    scanned branch is split again: one part is read as ``B*H`` row sequences
    of length ``W``, the other as ``B*W`` column sequences of length ``H``.
    Both go through the same Mamba layer and are fused by a 1x1 conv.
    """

    def __init__(self, dim, vssm_expansion_factor, bias):  # gated = True
        """Build the block with ``hidden = int(dim * vssm_expansion_factor)``."""
        super().__init__()
        hidden = int(dim * vssm_expansion_factor)
        doubled = hidden * 2
        self.proj_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1, padding=1, groups=doubled, bias=bias)
        self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
        self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        # Rows and columns each carry half of the hidden channels.
        self.mamba = Mamba(d_model=hidden // 2)

    def forward_core(self, x):
        """Scan rows and columns of a ``(B, C, H, W)`` map; the shape is kept."""
        B, C, H, W = x.size()
        half = C // 2
        halves = torch.chunk(self.conv_input(x), 2, dim=1)

        # Rows: (B, C/2, H, W) -> (B*H, W, C/2)
        rows = halves[0].permute(0, 2, 3, 1).contiguous().view(B * H, W, half)
        # Columns: (B, C/2, H, W) -> (B*W, H, C/2)
        cols = halves[1].permute(0, 3, 2, 1).contiguous().view(B * W, H, half)

        if H == W:
            # Same sequence shape, so both directions share one Mamba call.
            stacked = torch.cat((rows, cols), dim=0)  # B * H * 2, W, C//2
            scanned = torch.chunk(self.mamba(stacked), 2, dim=0)
            rows_out = scanned[0]
            cols_out = scanned[1]
        else:
            rows_out = self.mamba(rows)
            cols_out = self.mamba(cols)

        # Restore the (B, C/2, H, W) layout for both directions.
        rows_out = rows_out.view(B, H, W, half).permute(0, 3, 1, 2).contiguous()
        cols_out = cols_out.view(B, W, H, half).permute(0, 3, 2, 1).contiguous()
        merged = torch.cat((rows_out, cols_out), dim=1)
        return self.fuse_out(merged)

    def forward(self, x):
        """Apply the gated strip scan to a ``(B, dim, H, W)`` map."""
        expanded = self.dwconv(self.proj_in(x))
        content, gate = expanded.chunk(2, dim=1)
        scanned = self.forward_core(content)
        gated = F.silu(gate) * scanned
        return self.proj_out(gated)
class Inter_VSSM(nn.Module):
    """Gated block that scans across rows and across columns with Mamba.

    After a 1x1 expansion and a 3x3 depth-wise conv, one half of the
    channels is the processed branch and the other half is a SiLU gate. In
    the processed branch every row (and, for the second channel part, every
    column) is average-pooled to one descriptor. The descriptor sequence is
    passed through Mamba and a sigmoid, and the result rescales the
    original row/column features before a 1x1 fusion conv.
    """

    def __init__(self, dim, vssm_expansion_factor, bias):  # gated = True
        """Build the block with ``hidden = int(dim * vssm_expansion_factor)``."""
        super().__init__()
        hidden = int(dim * vssm_expansion_factor)
        doubled = hidden * 2
        self.proj_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1, padding=1, groups=doubled, bias=bias)
        self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
        # Keeps the second-to-last axis and pools the last axis to size 1.
        self.avg_pool = nn.AdaptiveAvgPool2d((None,1))
        self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.mamba = Mamba(d_model=hidden // 2)
        self.sigmoid = nn.Sigmoid()

    def forward_core(self, x):
        """Reweight rows and columns of a ``(B, C, H, W)`` map; shape is kept."""
        B, C, H, W = x.size()
        half = C // 2
        halves = torch.chunk(self.conv_input(x), 2, dim=1)  # B, C, H, W

        rows = halves[0].permute(0, 2, 1, 3).contiguous()  # B, H, C//2, W
        row_desc = self.avg_pool(rows).view(B, H, -1)  # B, H, C//2
        cols = halves[1].permute(0, 3, 1, 2).contiguous()  # B, W, C//2, H
        col_desc = self.avg_pool(cols).view(B, W, -1)  # B, W, C//2

        if H == W:
            # Same sequence shape, so both directions share one Mamba call.
            stacked = torch.cat((row_desc, col_desc), dim=0)  # B * 2, W or H, C//2
            scanned = torch.chunk(self.mamba(stacked), 2, dim=0)
            row_score = scanned[0]
            col_score = scanned[1]
        else:
            row_score = self.mamba(row_desc)
            col_score = self.mamba(col_desc)

        row_score = self.sigmoid(row_score)
        col_score = self.sigmoid(col_score)

        # Broadcast each (B, N, C//2) score along the pooled axis.
        rows = rows * row_score[:, :, :, None]
        cols = cols * col_score[:, :, :, None]

        # Restore the (B, C//2, H, W) layout for both directions.
        rows = rows.view(B, H, half, W).permute(0, 2, 1, 3).contiguous()
        cols = cols.view(B, W, half, H).permute(0, 2, 3, 1).contiguous()
        merged = torch.cat((rows, cols), dim=1)
        return self.fuse_out(merged)

    def forward(self, x):
        """Apply the gated strip reweighting to a ``(B, dim, H, W)`` map."""
        expanded = self.dwconv(self.proj_in(x))
        content, gate = expanded.chunk(2, dim=1)
        weighted = self.forward_core(content)
        gated = F.silu(gate) * weighted
        return self.proj_out(gated)
##########################################################################
class Strip_VSSB(nn.Module):
    """Residual block with two optional strip-scan stages and two GDFNs.

    With ``ssm`` enabled the block computes, each step as a pre-norm
    residual: intra scan -> GDFN -> inter scan -> GDFN. With ``ssm``
    disabled the scan sub-modules are not created and only the two GDFN
    residual steps run.

    Args:
        dim: number of channels of the input and output.
        vssm_expansion_factor: expansion factor for the scan modules.
        ffn_expansion_factor: expansion factor for the GDFN modules.
        bias: whether the convolutions use a bias.
        ssm: enable the Intra_VSSM / Inter_VSSM stages (expected to be a bool).
        LayerNorm_type: ``'BiasFree'`` or anything else for the biased norm.
    """

    def __init__(self, dim, vssm_expansion_factor, ffn_expansion_factor, bias=False, ssm=False, LayerNorm_type='WithBias'):
        super().__init__()
        self.ssm = ssm
        if self.ssm:
            self.norm1_ssm = LayerNorm(dim, LayerNorm_type)
            self.norm2_ssm = LayerNorm(dim, LayerNorm_type)
            self.intra = Intra_VSSM(dim, vssm_expansion_factor, bias)
            self.inter = Inter_VSSM(dim, vssm_expansion_factor, bias)
        self.norm1_ffn = LayerNorm(dim, LayerNorm_type)
        self.norm2_ffn = LayerNorm(dim, LayerNorm_type)
        self.ffn1 = GDFN(dim, ffn_expansion_factor, bias)
        self.ffn2 = GDFN(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        """Apply the block to a ``(b, dim, h, w)`` map; the shape is kept."""
        if self.ssm:
            x = x + self.intra(self.norm1_ssm(x))
        x = x + self.ffn1(self.norm1_ffn(x))
        if self.ssm:
            x = x + self.inter(self.norm2_ssm(x))
        x = x + self.ffn2(self.norm2_ffn(x))
        return x
##########################################################################
##---------- Cross-level Feature Fusion by Adding Sigmoid(KL-Div) * Multi-Scale Feat -----------------------
class CLFF(nn.Module):
    """Cross-level feature fusion gated by a KL-divergence score.

    The current-level features and two neighbouring-level features (already
    resized to the same spatial size) are each projected to ``dim`` channels
    by a 1x1 conv. For each neighbour, ``F.kl_div`` is evaluated between the
    log-sigmoid of the neighbour and of the current level, and a sigmoid of
    that value weights the neighbour. No ``reduction`` is passed, so the
    library default applies and each gate is a single scalar. The weighted
    sum is concatenated with the current features and fused back to ``dim``
    channels.
    """

    def __init__(self, dim, dim_n1, dim_n2, bias=False):
        """Create the three input projections and the fusion conv."""
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.conv_n1 = nn.Conv2d(dim_n1, dim, kernel_size=1, bias=bias)
        self.conv_n2 = nn.Conv2d(dim_n2, dim, kernel_size=1, bias=bias)
        self.fuse_out1 = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)
        self.log_sigmoid = nn.LogSigmoid()
        self.sigmoid = nn.Sigmoid()

    def _gate(self, neighbour, reference):
        """Return ``sigmoid(kl_div(logsigmoid(neighbour), logsigmoid(reference)))``."""
        divergence = F.kl_div(
            input=self.log_sigmoid(neighbour),
            target=self.log_sigmoid(reference),
            log_target=True,
        )
        return self.sigmoid(divergence)

    def forward(self, x, n1, n2):
        """Fuse ``x`` with neighbours ``n1`` and ``n2``; returns ``dim`` channels."""
        current = self.conv(x)
        first = self.conv_n1(n1)
        second = self.conv_n2(n2)
        w_first = self._gate(first, current)
        w_second = self._gate(second, current)
        neighbours = w_first * first + w_second * second
        return self.fuse_out1(torch.cat((current, neighbours), dim=1))
##########################################################################
##---------- StripScanNet -----------------------
class XYScanNetP(nn.Module):
    """Three-level encoder/decoder restoration network built from Strip_VSSB.

    The encoder uses blocks with the scan stages disabled and the decoder
    uses blocks with them enabled. Encoder outputs of levels 1 and 2 are
    fused with resized features from the other levels (CLFF) before being
    used as skip connections. The network predicts a residual that is added
    to the input image.

    Args:
        inp_channels: channels of the input image.
        out_channels: channels of the predicted residual. It is added to the
            input image, so it must be compatible with ``inp_channels``.
        dim: channel width at level 1 (levels 2 and 3 use 2x and 4x).
        num_blocks: number of blocks at levels 1, 2 and 3.
        vssm_expansion_factor: expansion factor of the scan modules.
        ffn_expansion_factor: expansion factor of the GDFN modules.
        bias: whether the convolutions use a bias.
        LayerNorm_type: ``'WithBias'`` or ``'BiasFree'``.

    The input height and width must be divisible by 4: the features are
    downsampled by 2 twice and the upsampled maps are concatenated with the
    encoder maps, which only line up for such sizes.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 144,  # 48, 72, 96, 120, 144
        num_blocks = (3, 3, 6),  # tuple: avoids a shared mutable default
        vssm_expansion_factor = 1,  # 1 or 2
        ffn_expansion_factor = 1,  # 1 or 3
        bias = False,
        LayerNorm_type = 'WithBias',  ## Other option 'BiasFree'
    ):
        super().__init__()

        def make_stage(width, depth, ssm):
            """Stack ``depth`` Strip_VSSB blocks of the given channel width."""
            return nn.Sequential(*[
                Strip_VSSB(dim=width, vssm_expansion_factor=vssm_expansion_factor,
                           ffn_expansion_factor=ffn_expansion_factor,
                           bias=bias, ssm=ssm, LayerNorm_type=LayerNorm_type)
                for _ in range(depth)
            ])

        dim_l1 = int(dim*2**0)
        dim_l2 = int(dim*2**1)
        dim_l3 = int(dim*2**2)

        # Sub-modules are registered in the same order as before so that
        # parameter ordering and seeded initialization are unchanged.
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
        self.encoder_level1 = make_stage(dim, num_blocks[0], ssm=False)

        self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
        self.encoder_level2 = make_stage(dim_l2, num_blocks[1], ssm=False)

        self.down2_3 = Downsample(dim_l2)  ## From Level 2 to Level 3
        self.encoder_level3 = make_stage(dim_l3, num_blocks[2], ssm=False)

        self.decoder_level3 = make_stage(dim_l3, num_blocks[2], ssm=True)

        self.up3_2 = Upsample(dim_l3)  ## From Level 3 to Level 2
        self.clff_level2 = CLFF(dim_l2, dim_n1=dim_l1, dim_n2=dim_l3, bias=bias)
        self.reduce_chan_level2 = nn.Conv2d(dim_l3, dim_l2, kernel_size=1, bias=bias)
        self.decoder_level2 = make_stage(dim_l2, num_blocks[1], ssm=True)

        self.up2_1 = Upsample(dim_l2)  ## From Level 2 to Level 1
        self.clff_level1 = CLFF(dim_l1, dim_n1=dim_l2, dim_n2=dim_l3, bias=bias)
        self.reduce_chan_level1 = nn.Conv2d(dim_l2, dim_l1, kernel_size=1, bias=bias)
        self.decoder_level1 = make_stage(dim_l1, num_blocks[0], ssm=True)

        self.output = nn.Conv2d(dim_l1, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):
        """Restore ``inp_img`` of shape ``(b, inp_channels, h, w)``.

        Returns:
            A 3-tuple ``(restored, residual, None)`` where ``residual`` is
            the raw network output and ``restored = residual + inp_img``.
        """
        # Encoder
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        out_enc_level1_2 = F.interpolate(out_enc_level1, scale_factor=0.5)  # dim, lvl1 down-scaled to lvl2

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)
        out_enc_level2_1 = F.interpolate(out_enc_level2, scale_factor=2)  # dim*2, lvl2 up-scaled to lvl1

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)
        out_enc_level3_2 = F.interpolate(out_enc_level3, scale_factor=2)  # dim*2**2, lvl3 up-scaled to lvl2 (lvl3->lvl2)
        out_enc_level3_1 = F.interpolate(out_enc_level3_2, scale_factor=2)  # dim*2**2, lvl3 up-scaled to lvl1 (lvl3->lvl2->lvl1)

        # Cross-level fusion. The resized neighbours were taken before this
        # point, so both fusions see un-fused features from the other levels.
        out_enc_level1 = self.clff_level1(out_enc_level1, out_enc_level2_1, out_enc_level3_1)
        out_enc_level2 = self.clff_level2(out_enc_level2, out_enc_level1_2, out_enc_level3_2)

        # Decoder
        out_dec_level3_decomp1 = self.decoder_level3(out_enc_level3)

        inp_dec_level2_decomp1 = self.up3_2(out_dec_level3_decomp1)
        inp_dec_level2_decomp1 = self.reduce_chan_level2(torch.cat((inp_dec_level2_decomp1, out_enc_level2), dim=1))
        out_dec_level2_decomp1 = self.decoder_level2(inp_dec_level2_decomp1)

        inp_dec_level1_decomp1 = self.up2_1(out_dec_level2_decomp1)
        inp_dec_level1_decomp1 = self.reduce_chan_level1(torch.cat((inp_dec_level1_decomp1, out_enc_level1), dim=1))
        out_dec_level1_decomp1 = self.decoder_level1(inp_dec_level1_decomp1)

        out_dec_level1_decomp1 = self.output(out_dec_level1_decomp1)
        out_dec_level1 = out_dec_level1_decomp1 + inp_img

        return out_dec_level1, out_dec_level1_decomp1, None
def count_parameters(model):
    """Print the total and trainable parameter counts of ``model``."""
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
def main():
    """Build the default model, print it and its size, and run a dummy pass."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = XYScanNetP().to(device)

    print("Model architecture:\n")
    print(model)

    count_parameters(model)

    # Optionally test with a dummy input.
    # This is only a shape check, so skip autograd bookkeeping to save memory
    # and create the input directly on the target device.
    dummy_input = torch.randn(1, 3, 256, 256, device=device)
    with torch.no_grad():
        output, _, _ = model(dummy_input)
    print(f"Output shape: {output.shape}")


if __name__ == "__main__":
    main()