import numbers
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)
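

# Illustrative sketch (not part of the original model): the LayerNorm wrapper
# normalizes a 4D feature map over its channel dimension by flattening the
# spatial dims to tokens, normalizing the last axis, then restoring B C H W.
# The tensor sizes below are assumptions chosen only for demonstration.
def _layer_norm_shape_sketch():
    x = torch.randn(2, 48, 32, 32)                      # (B, C, H, W), arbitrary sizes
    norm = LayerNorm(dim=48, LayerNorm_type='WithBias')
    y = norm(x)                                         # internally: to_3d -> (2, 1024, 48) -> to_4d
    assert y.shape == x.shape
    return y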
##########################################################################
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, stride=stride)
| """ | |
| Borrow from "https://github.com/state-spaces/mamba.git" | |
| @article{mamba, | |
| title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, | |
| author={Gu, Albert and Dao, Tri}, | |
| journal={arXiv preprint arXiv:2312.00752}, | |
| year={2023} | |
| } | |
| """ | |
class Mamba(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out
    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )
        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state
    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
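

# Illustrative sketch (not part of the original model): Mamba consumes a
# (batch, sequence_length, d_model) tensor and returns a tensor of the same
# shape. The sizes below are assumptions chosen only for demonstration, and
# the fused selective-scan kernels require a CUDA build of mamba_ssm.
def _mamba_shape_sketch():
    layer = Mamba(d_model=64).cuda()
    tokens = torch.randn(2, 128, 64, device="cuda")  # (B, L, D)
    out = layer(tokens)                              # (B, L, D), same shape as the input
    assert out.shape == tokens.shape
    return out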
##########################################################################
## Feed-forward Network
class FFN(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FFN, self).__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias, dilation=1)
        self.win_size = 8
        self.modulator = nn.Parameter(torch.ones(self.win_size, self.win_size, dim * 2))  # modulator
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape
        h1, w1 = h // self.win_size, w // self.win_size
        x = self.project_in(x)
        x = self.dwconv(x)
        x_win = rearrange(x, 'b c (wsh h1) (wsw w1) -> b h1 w1 wsh wsw c', wsh=self.win_size, wsw=self.win_size)
        x_win = x_win * self.modulator
        x = rearrange(x_win, 'b h1 w1 wsh wsw c -> b c (wsh h1) (wsw w1)', wsh=self.win_size, wsw=self.win_size, h1=h1, w1=w1)
        x1, x2 = x.chunk(2, dim=1)
        x = x1 * x2
        x = self.project_out(x)
        return x
##########################################################################
## Gated Depth-wise Feed-forward Network (GDFN)
class GDFN(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(GDFN, self).__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias, dilation=1)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x = self.dwconv(x)
        x1, x2 = x.chunk(2, dim=1)
        x = F.silu(x1) * x2
        x = self.project_out(x)
        return x
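

# Illustrative sketch (not part of the original model): GDFN expands the
# channels, applies a depth-wise 3x3 conv, splits the result into a SiLU gate
# and a value branch, multiplies them, and projects back to `dim` channels.
# The sizes below are assumptions used only for a quick CPU shape check.
def _gdfn_shape_sketch():
    ffn = GDFN(dim=48, ffn_expansion_factor=1, bias=False)
    feat = torch.randn(1, 48, 32, 32)
    out = ffn(feat)              # same (B, C, H, W) shape as the input
    assert out.shape == feat.shape
    return out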
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        return x
##########################################################################
## Resizing modules
class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()
        self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
                                  nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False))

    def forward(self, x):
        return self.body(x)


class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()
        self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                                  nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False))

    def forward(self, x):
        return self.body(x)
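

# Illustrative sketch (not part of the original model): Downsample halves the
# spatial resolution and doubles the channel count via bilinear resizing plus
# a 3x3 conv; Upsample does the inverse. The tensor sizes are assumptions.
def _resizing_shape_sketch():
    feat = torch.randn(1, 48, 64, 64)      # (B, C, H, W)
    down = Downsample(48)(feat)            # -> (1, 96, 32, 32)
    up = Upsample(96)(down)                # -> (1, 48, 64, 64)
    assert down.shape == (1, 96, 32, 32) and up.shape == (1, 48, 64, 64)
    return up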
| """ | |
| Borrow from "https://github.com/pp00704831/Stripformer-ECCV-2022-.git" | |
| @inproceedings{Tsai2022Stripformer, | |
| author = {Fu-Jen Tsai and Yan-Tsung Peng and Yen-Yu Lin and Chung-Chi Tsai and Chia-Wen Lin}, | |
| title = {Stripformer: Strip Transformer for Fast Image Deblurring}, | |
| booktitle = {ECCV}, | |
| year = {2022} | |
| } | |
| """ | |
class Intra_VSSM(nn.Module):
    def __init__(self, dim, vssm_expansion_factor, bias):  # gated = True
        super(Intra_VSSM, self).__init__()
        hidden = int(dim * vssm_expansion_factor)
        self.proj_in = nn.Conv2d(dim, hidden * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden * 2, hidden * 2, kernel_size=3, stride=1, padding=1, groups=hidden * 2, bias=bias)
        self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
        self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.mamba = Mamba(d_model=hidden // 2)

    def forward_core(self, x):
        B, C, H, W = x.size()
        x_input = torch.chunk(self.conv_input(x), 2, dim=1)
        feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
        feature_h = feature_h.view(B * H, W, C // 2)
        feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
        feature_v = feature_v.view(B * W, H, C // 2)
        if H == W:
            feature = torch.cat((feature_h, feature_v), dim=0)  # B * H * 2, W, C//2
            scan_output = self.mamba(feature)
            scan_output = torch.chunk(scan_output, 2, dim=0)
            scan_output_h = scan_output[0]
            scan_output_v = scan_output[1]
        else:
            scan_output_h = self.mamba(feature_h)
            scan_output_v = self.mamba(feature_v)
        scan_output_h = scan_output_h.view(B, H, W, C // 2).permute(0, 3, 1, 2).contiguous()
        scan_output_v = scan_output_v.view(B, W, H, C // 2).permute(0, 3, 2, 1).contiguous()
        scan_output = self.fuse_out(torch.cat((scan_output_h, scan_output_v), dim=1))
        return scan_output

    def forward(self, x):
        x = self.proj_in(x)
        x, x_ = self.dwconv(x).chunk(2, dim=1)
        x = self.forward_core(x)
        x = F.silu(x_) * x
        x = self.proj_out(x)
        return x
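

# Illustrative sketch (not part of the original model): Intra_VSSM scans every
# row as a length-W sequence (B*H sequences) and every column as a length-H
# sequence (B*W sequences) with a shared Mamba, then fuses the two directions
# with a 1x1 conv. The sizes are assumptions; Mamba's kernels need a GPU.
def _intra_vssm_shape_sketch():
    block = Intra_VSSM(dim=48, vssm_expansion_factor=1, bias=False).cuda()
    feat = torch.randn(1, 48, 64, 64, device="cuda")
    out = block(feat)            # gated output keeps the (B, C, H, W) shape
    assert out.shape == feat.shape
    return out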
class Inter_VSSM(nn.Module):
    def __init__(self, dim, vssm_expansion_factor, bias):  # gated = True
        super(Inter_VSSM, self).__init__()
        hidden = int(dim * vssm_expansion_factor)
        self.proj_in = nn.Conv2d(dim, hidden * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden * 2, hidden * 2, kernel_size=3, stride=1, padding=1, groups=hidden * 2, bias=bias)
        self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
        self.avg_pool = nn.AdaptiveAvgPool2d((None, 1))
        self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
        self.mamba = Mamba(d_model=hidden // 2)
        self.sigmoid = nn.Sigmoid()

    def forward_core(self, x):
        B, C, H, W = x.size()
        x_input = torch.chunk(self.conv_input(x), 2, dim=1)  # B, C, H, W
        feature_h = x_input[0].permute(0, 2, 1, 3).contiguous()  # B, H, C//2, W
        feature_h_score = self.avg_pool(feature_h)  # B, H, C//2, 1
        feature_h_score = feature_h_score.view(B, H, -1)
        feature_v = x_input[1].permute(0, 3, 1, 2).contiguous()  # B, W, C//2, H
        feature_v_score = self.avg_pool(feature_v)  # B, W, C//2, 1
        feature_v_score = feature_v_score.view(B, W, -1)
        if H == W:
            feature_score = torch.cat((feature_h_score, feature_v_score), dim=0)  # B * 2, W or H, C//2
            scan_score = self.mamba(feature_score)
            scan_score = torch.chunk(scan_score, 2, dim=0)
            scan_score_h = scan_score[0]
            scan_score_v = scan_score[1]
        else:
            scan_score_h = self.mamba(feature_h_score)
            scan_score_v = self.mamba(feature_v_score)
        scan_score_h = self.sigmoid(scan_score_h)
        scan_score_v = self.sigmoid(scan_score_v)
        feature_h = feature_h * scan_score_h[:, :, :, None]
        feature_v = feature_v * scan_score_v[:, :, :, None]
        feature_h = feature_h.view(B, H, C // 2, W).permute(0, 2, 1, 3).contiguous()
        feature_v = feature_v.view(B, W, C // 2, H).permute(0, 2, 3, 1).contiguous()
        output = self.fuse_out(torch.cat((feature_h, feature_v), dim=1))
        return output

    def forward(self, x):
        x = self.proj_in(x)
        x, x_ = self.dwconv(x).chunk(2, dim=1)
        x = self.forward_core(x)
        x = F.silu(x_) * x
        x = self.proj_out(x)
        return x
##########################################################################
class Strip_VSSB(nn.Module):
    def __init__(self, dim, vssm_expansion_factor, ffn_expansion_factor, bias=False, ssm=False, LayerNorm_type='WithBias'):
        super(Strip_VSSB, self).__init__()
        self.ssm = ssm
        if self.ssm:
            self.norm1_ssm = LayerNorm(dim, LayerNorm_type)
            self.norm2_ssm = LayerNorm(dim, LayerNorm_type)
            self.intra = Intra_VSSM(dim, vssm_expansion_factor, bias)
            self.inter = Inter_VSSM(dim, vssm_expansion_factor, bias)
        self.norm1_ffn = LayerNorm(dim, LayerNorm_type)
        self.norm2_ffn = LayerNorm(dim, LayerNorm_type)
        self.ffn1 = GDFN(dim, ffn_expansion_factor, bias)
        self.ffn2 = GDFN(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        if self.ssm:
            x = x + self.intra(self.norm1_ssm(x))
        x = x + self.ffn1(self.norm1_ffn(x))
        if self.ssm:
            x = x + self.inter(self.norm2_ssm(x))
        x = x + self.ffn2(self.norm2_ffn(x))
        return x
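

# Illustrative sketch (not part of the original model): with ssm=True a
# Strip_VSSB block runs intra-strip scan -> GDFN -> inter-strip scan -> GDFN,
# each sub-block pre-normalized and wrapped in a residual connection; with
# ssm=False only the two GDFN sub-blocks run. The sizes below are assumptions
# and the Mamba kernels require a GPU.
def _strip_vssb_shape_sketch():
    block = Strip_VSSB(dim=48, vssm_expansion_factor=1, ffn_expansion_factor=1, ssm=True).cuda()
    feat = torch.randn(1, 48, 64, 64, device="cuda")
    out = block(feat)            # residual design keeps the (B, C, H, W) shape
    assert out.shape == feat.shape
    return out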
##########################################################################
##---------- Cross-level Feature Fusion by Adding Sigmoid(KL-Div) * Multi-Scale Feat -----------------------
class CLFF(nn.Module):
    def __init__(self, dim, dim_n1, dim_n2, bias=False):
        super(CLFF, self).__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.conv_n1 = nn.Conv2d(dim_n1, dim, kernel_size=1, bias=bias)
        self.conv_n2 = nn.Conv2d(dim_n2, dim, kernel_size=1, bias=bias)
        self.fuse_out1 = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)
        self.log_sigmoid = nn.LogSigmoid()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, n1, n2):
        x_ = self.conv(x)
        n1_ = self.conv_n1(n1)
        n2_ = self.conv_n2(n2)
        kl_n1 = F.kl_div(input=self.log_sigmoid(n1_), target=self.log_sigmoid(x_), log_target=True)
        kl_n2 = F.kl_div(input=self.log_sigmoid(n2_), target=self.log_sigmoid(x_), log_target=True)
        # g = self.sigmoid(x_)
        g1 = self.sigmoid(kl_n1)
        g2 = self.sigmoid(kl_n2)
        # x = (1 + g) * x_ + (1 - g) * (g1 * n1_ + g2 * n2_)
        x = self.fuse_out1(torch.cat((x_, g1 * n1_ + g2 * n2_), dim=1))
        return x
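

# Illustrative note (not part of the original model): with its default "mean"
# reduction, F.kl_div returns a scalar, so g1 and g2 above act as single
# scalar gates on the projected neighbour-level features before the 1x1
# fusion conv. The sizes below are assumptions for a quick CPU shape check;
# the neighbour features are assumed to be already resized to the same H, W.
def _clff_shape_sketch():
    fuse = CLFF(dim=96, dim_n1=48, dim_n2=192)
    x = torch.randn(1, 96, 32, 32)     # current level
    n1 = torch.randn(1, 48, 32, 32)    # neighbour level 1
    n2 = torch.randn(1, 192, 32, 32)   # neighbour level 2
    out = fuse(x, n1, n2)              # (1, 96, 32, 32)
    assert out.shape == x.shape
    return out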
##########################################################################
##---------- StripScanNet -----------------------
class XYScanNetP(nn.Module):
    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=144,  # 48, 72, 96, 120, 144
                 num_blocks=[3, 3, 6],
                 vssm_expansion_factor=1,  # 1 or 2
                 ffn_expansion_factor=1,  # 1 or 3
                 bias=False,
                 LayerNorm_type='WithBias',  ## Other option 'BiasFree'
                 ):
        super(XYScanNetP, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(*[Strip_VSSB(dim=dim, vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
        self.encoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim * 2 ** 1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        self.down2_3 = Downsample(int(dim * 2 ** 1))  ## From Level 2 to Level 3
        self.encoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim * 2 ** 2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.decoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim * 2 ** 2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.up3_2 = Upsample(int(dim * 2 ** 2))  ## From Level 3 to Level 2
        self.clff_level2 = CLFF(int(dim * 2 ** 1), dim_n1=int(dim * 2 ** 0), dim_n2=int(dim * 2 ** 2), bias=bias)
        self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim * 2 ** 1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        self.up2_1 = Upsample(int(dim * 2 ** 1))  ## From Level 2 to Level 1
        self.clff_level1 = CLFF(int(dim * 2 ** 0), dim_n1=int(dim * 2 ** 1), dim_n2=int(dim * 2 ** 2), bias=bias)
        self.reduce_chan_level1 = nn.Conv2d(int(dim * 2 ** 1), int(dim * 2 ** 0), kernel_size=1, bias=bias)
        self.decoder_level1 = nn.Sequential(*[Strip_VSSB(dim=int(dim * 2 ** 0), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor=ffn_expansion_factor,
                                                         bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        # self.refinement = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), expansion_factor=expansion_factor, bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])

        self.output = nn.Conv2d(int(dim * 2 ** 0), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
    def forward(self, inp_img):
        # Encoder
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        out_enc_level1_2 = F.interpolate(out_enc_level1, scale_factor=0.5)  # dim channels, lvl1 down-scaled to lvl2 resolution

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)
        out_enc_level2_1 = F.interpolate(out_enc_level2, scale_factor=2)  # dim*2 channels, lvl2 up-scaled to lvl1 resolution

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)
        out_enc_level3_2 = F.interpolate(out_enc_level3, scale_factor=2)  # dim*2**2 channels, lvl3 up-scaled to lvl2 resolution
        out_enc_level3_1 = F.interpolate(out_enc_level3_2, scale_factor=2)  # dim*2**2 channels, lvl3 up-scaled to lvl1 resolution (lvl3->lvl2->lvl1)

        out_enc_level1 = self.clff_level1(out_enc_level1, out_enc_level2_1, out_enc_level3_1)
        out_enc_level2 = self.clff_level2(out_enc_level2, out_enc_level1_2, out_enc_level3_2)

        # Decoder
        out_dec_level3_decomp1 = self.decoder_level3(out_enc_level3)

        inp_dec_level2_decomp1 = self.up3_2(out_dec_level3_decomp1)
        inp_dec_level2_decomp1 = self.reduce_chan_level2(torch.cat((inp_dec_level2_decomp1, out_enc_level2), dim=1))
        out_dec_level2_decomp1 = self.decoder_level2(inp_dec_level2_decomp1)

        inp_dec_level1_decomp1 = self.up2_1(out_dec_level2_decomp1)
        inp_dec_level1_decomp1 = self.reduce_chan_level1(torch.cat((inp_dec_level1_decomp1, out_enc_level1), dim=1))
        out_dec_level1_decomp1 = self.decoder_level1(inp_dec_level1_decomp1)

        out_dec_level1_decomp1 = self.output(out_dec_level1_decomp1)
        out_dec_level1 = out_dec_level1_decomp1 + inp_img
        return out_dec_level1, out_dec_level1_decomp1, None
def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = XYScanNetP().to(device)
    print("Model architecture:\n")
    print(model)
    count_parameters(model)

    # Optionally test with a dummy input.
    # Note: the decoder's Mamba blocks rely on the CUDA selective-scan kernels,
    # so this dummy forward pass assumes a GPU is available.
    dummy_input = torch.randn(1, 3, 256, 256).to(device)
    output, _, _ = model(dummy_input)
    print(f"Output shape: {output.shape}")


if __name__ == "__main__":
    main()