Spaces:
Running
on
Zero
Running
on
Zero
File size: 30,648 Bytes
b56342d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 |
import numbers
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
def to_3d(x):
return rearrange(x, 'b c h w -> b (h w) c')
def to_4d(x, h, w):
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(BiasFree_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
sigma = x.var(-1, keepdim=True, unbiased=False)
return x / torch.sqrt(sigma + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(WithBias_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
mu = x.mean(-1, keepdim=True)
sigma = x.var(-1, keepdim=True, unbiased=False)
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
def __init__(self, dim, LayerNorm_type):
super(LayerNorm, self).__init__()
if LayerNorm_type == 'BiasFree':
self.body = BiasFree_LayerNorm(dim)
else:
self.body = WithBias_LayerNorm(dim)
def forward(self, x):
h, w = x.shape[-2:]
return to_4d(self.body(to_3d(x)), h, w)
##########################################################################
def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), bias=bias, stride = stride)
"""
Borrow from "https://github.com/state-spaces/mamba.git"
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
"""
class Mamba(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True, # Fused kernel options
layer_idx=None,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1,
**factory_kwargs,
)
self.activation = "silu"
self.act = nn.SiLU()
self.x_proj = nn.Linear(
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = self.dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_proj.bias._no_reinit = True
# S4D real initialization
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
# D "skip" parameter
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
self.D._no_weight_decay = True
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
def forward(self, hidden_states, inference_params=None):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
batch, seqlen, dim = hidden_states.shape
conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
# We do matmul and transpose BLH -> HBL at the same time
xz = rearrange(
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
"d (b l) -> b d l",
l=seqlen,
)
if self.in_proj.bias is not None:
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# In the backward pass we write dx and dz next to each other to avoid torch.cat
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
out = mamba_inner_fn(
xz,
self.conv1d.weight,
self.conv1d.bias,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias,
A,
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
else:
x, z = xz.chunk(2, dim=1)
# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.t()
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
assert self.activation in ["silu", "swish"]
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
)
if ssm_state is not None:
y, last_state = y
ssm_state.copy_(last_state)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
return out
def step(self, hidden_states, conv_state, ssm_state):
dtype = hidden_states.dtype
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
x, z = xz.chunk(2, dim=-1) # (B D)
# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
x = self.act(x).to(dtype=dtype)
else:
x = causal_conv1d_update(
x,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
# Don't add dt_bias here
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# SSM step
if selective_state_update is None:
# Discretize A and B
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
dB = torch.einsum("bd,bn->bdn", dt, B)
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
y = y + self.D.to(dtype) * x
y = y * self.act(z) # (B D)
else:
y = selective_state_update(
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
)
out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
device = self.out_proj.weight.device
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
conv_state = torch.zeros(
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
)
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
# ssm_dtype = torch.float32
ssm_state = torch.zeros(
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
)
return conv_state, ssm_state
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
assert self.layer_idx is not None
if self.layer_idx not in inference_params.key_value_memory_dict:
batch_shape = (batch_size,)
conv_state = torch.zeros(
batch_size,
self.d_model * self.expand,
self.d_conv,
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.d_model * self.expand,
self.d_state,
device=self.dt_proj.weight.device,
dtype=self.dt_proj.weight.dtype,
# dtype=torch.float32,
)
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
else:
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
# TODO: What if batch size changes between generation, and we reuse the same states?
if initialize_states:
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state
##########################################################################
## Feed-forward Network
class FFN(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FFN, self).__init__()
hidden_features = int(dim*ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
self.win_size = 8
self.modulator = nn.Parameter(torch.ones(self.win_size, self.win_size, dim*2)) # modulator
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
b, c, h, w = x.shape
h1, w1 = h//self.win_size, w//self.win_size
x = self.project_in(x)
x = self.dwconv(x)
x_win = rearrange(x, 'b c (wsh h1) (wsw w1) -> b h1 w1 wsh wsw c', wsh=self.win_size, wsw=self.win_size)
x_win = x_win * self.modulator
x = rearrange(x_win, 'b h1 w1 wsh wsw c -> b c (wsh h1) (wsw w1)', wsh=self.win_size, wsw=self.win_size, h1=h1, w1=w1)
x1, x2 = x.chunk(2, dim=1)
x = x1 * x2
x = self.project_out(x)
return x
##########################################################################
## Gated Depth-wise Feed-forward Network (GDFN)
class GDFN(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(GDFN, self).__init__()
hidden_features = int(dim*ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias, dilation=1)
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x = self.dwconv(x)
x1, x2 = x.chunk(2, dim=1)
x = F.silu(x1) * x2
x = self.project_out(x)
return x
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
def __init__(self, in_c=3, embed_dim=48, bias=False):
super(OverlapPatchEmbed, self).__init__()
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
def forward(self, x):
x = self.proj(x)
return x
##########################################################################
## Resizing modules
class Downsample(nn.Module):
def __init__(self, n_feat):
super(Downsample, self).__init__()
self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False))
def forward(self, x):
return self.body(x)
class Upsample(nn.Module):
def __init__(self, n_feat):
super(Upsample, self).__init__()
self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False))
def forward(self, x):
return self.body(x)
"""
Borrow from "https://github.com/pp00704831/Stripformer-ECCV-2022-.git"
@inproceedings{Tsai2022Stripformer,
author = {Fu-Jen Tsai and Yan-Tsung Peng and Yen-Yu Lin and Chung-Chi Tsai and Chia-Wen Lin},
title = {Stripformer: Strip Transformer for Fast Image Deblurring},
booktitle = {ECCV},
year = {2022}
}
"""
class Intra_VSSM(nn.Module):
def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
super(Intra_VSSM, self).__init__()
hidden = int(dim*vssm_expansion_factor)
self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
self.mamba = Mamba(d_model=hidden // 2)
def forward_core(self, x):
B, C, H, W = x.size()
x_input = torch.chunk(self.conv_input(x), 2, dim=1)
feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
feature_h = feature_h.view(B * H, W, C//2)
feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
feature_v = feature_v.view(B * W, H, C//2)
if H == W:
feature = torch.cat((feature_h, feature_v), dim=0) # B * H * 2, W, C//2
scan_output = self.mamba(feature)
scan_output = torch.chunk(scan_output, 2, dim=0)
scan_output_h = scan_output[0]
scan_output_v = scan_output[1]
else:
scan_output_h = self.mamba(feature_h)
scan_output_v = self.mamba(feature_v)
scan_output_h = scan_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
scan_output_v = scan_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
scan_output = self.fuse_out(torch.cat((scan_output_h, scan_output_v), dim=1))
return scan_output
def forward(self, x):
x = self.proj_in(x)
x, x_ = self.dwconv(x).chunk(2, dim=1)
x = self.forward_core(x)
x = F.silu(x_) * x
x = self.proj_out(x)
return x
class Inter_VSSM(nn.Module):
def __init__(self, dim, vssm_expansion_factor, bias): # gated = True
super(Inter_VSSM, self).__init__()
hidden = int(dim*vssm_expansion_factor)
self.proj_in = nn.Conv2d(dim, hidden*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden*2, hidden*2, kernel_size=3, stride=1, padding=1, groups=hidden*2, bias=bias)
self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)
self.avg_pool = nn.AdaptiveAvgPool2d((None,1))
self.conv_input = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
self.fuse_out = nn.Conv2d(hidden, hidden, kernel_size=1, padding=0, bias=bias)
self.mamba = Mamba(d_model=hidden // 2)
self.sigmoid = nn.Sigmoid()
def forward_core(self, x):
B, C, H, W = x.size()
x_input = torch.chunk(self.conv_input(x), 2, dim=1) # B, C, H, W
feature_h = x_input[0].permute(0, 2, 1, 3).contiguous() # B, H, C//2, W
feature_h_score = self.avg_pool(feature_h) # B, H, C//2, 1
feature_h_score = feature_h_score.view(B, H, -1)
feature_v = x_input[1].permute(0, 3, 1, 2).contiguous() # B, W, C//2, H
feature_v_score = self.avg_pool(feature_v) # B, W, C//2, 1
feature_v_score = feature_v_score.view(B, W, -1)
if H == W:
feature_score = torch.cat((feature_h_score, feature_v_score), dim=0) # B * 2, W or H, C//2
scan_score = self.mamba(feature_score)
scan_score = torch.chunk(scan_score, 2, dim=0)
scan_score_h = scan_score[0]
scan_score_v = scan_score[1]
else:
scan_score_h = self.mamba(feature_h_score)
scan_score_v = self.mamba(feature_v_score)
scan_score_h = self.sigmoid(scan_score_h)
scan_score_v = self.sigmoid(scan_score_v)
feature_h = feature_h*scan_score_h[:,:,:,None]
feature_v = feature_v*scan_score_v[:,:,:,None]
feature_h = feature_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
feature_v = feature_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
output = self.fuse_out(torch.cat((feature_h, feature_v), dim=1))
return output
def forward(self, x):
x = self.proj_in(x)
x, x_ = self.dwconv(x).chunk(2, dim=1)
x = self.forward_core(x)
x = F.silu(x_) * x
x = self.proj_out(x)
return x
##########################################################################
class Strip_VSSB(nn.Module):
def __init__(self, dim, vssm_expansion_factor, ffn_expansion_factor, bias=False, ssm=False, LayerNorm_type='WithBias'):
super(Strip_VSSB, self).__init__()
self.ssm = ssm
if self.ssm == True:
self.norm1_ssm = LayerNorm(dim, LayerNorm_type)
self.norm2_ssm = LayerNorm(dim, LayerNorm_type)
self.intra = Intra_VSSM(dim, vssm_expansion_factor, bias)
self.inter = Inter_VSSM(dim, vssm_expansion_factor, bias)
self.norm1_ffn = LayerNorm(dim, LayerNorm_type)
self.norm2_ffn = LayerNorm(dim, LayerNorm_type)
self.ffn1 = GDFN(dim, ffn_expansion_factor, bias)
self.ffn2 = GDFN(dim, ffn_expansion_factor, bias)
def forward(self, x):
if self.ssm == True:
x = x + self.intra(self.norm1_ssm(x))
x = x + self.ffn1(self.norm1_ffn(x))
if self.ssm == True:
x = x + self.inter(self.norm2_ssm(x))
x = x + self.ffn2(self.norm2_ffn(x))
return x
##########################################################################
##---------- Cross-level Feature Fusion by Adding Sigmoid(KL-Div) * Multi-Scale Feat -----------------------
class CLFF(nn.Module):
def __init__(self, dim, dim_n1, dim_n2, bias=False):
super(CLFF, self).__init__()
self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self.conv_n1 = nn.Conv2d(dim_n1, dim, kernel_size=1, bias=bias)
self.conv_n2 = nn.Conv2d(dim_n2, dim, kernel_size=1, bias=bias)
self.fuse_out1 = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)
self.log_sigmoid = nn.LogSigmoid()
self.sigmoid = nn.Sigmoid()
def forward(self, x, n1, n2):
x_ = self.conv(x)
n1_ = self.conv_n1(n1)
n2_ = self.conv_n2(n2)
kl_n1 = F.kl_div(input=self.log_sigmoid(n1_), target=self.log_sigmoid(x_), log_target=True)
kl_n2 = F.kl_div(input=self.log_sigmoid(n2_), target=self.log_sigmoid(x_), log_target=True)
#g = self.sigmoid(x_)
g1 = self.sigmoid(kl_n1)
g2 = self.sigmoid(kl_n2)
#x = (1 + g) * x_ + (1 - g) * (g1 * n1_ + g2 * n2_)
x = self.fuse_out1(torch.cat((x_, g1 * n1_ + g2 * n2_), dim=1))
return x
##########################################################################
##---------- StripScanNet -----------------------
class XYScanNetP(nn.Module):
def __init__(self,
inp_channels=3,
out_channels=3,
dim = 144, # 48, 72, 96, 120, 144
num_blocks = [3,3,6],
vssm_expansion_factor = 1, # 1 or 2
ffn_expansion_factor = 1, # 1 or 3
bias = False,
LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
):
super(XYScanNetP, self).__init__()
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
self.encoder_level1 = nn.Sequential(*[Strip_VSSB(dim=dim, vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
self.encoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
self.encoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
bias=bias, ssm=False, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
self.decoder_level3 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**2), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
self.clff_level2 = CLFF(int(dim*2**1), dim_n1=int(dim*2**0), dim_n2=(dim*2**2), bias=bias)
self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
self.decoder_level2 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**1), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1
self.clff_level1 = CLFF(int(dim*2**0), dim_n1=int(dim*2**1), dim_n2=(dim*2**2), bias=bias)
self.reduce_chan_level1 = nn.Conv2d(int(dim*2**1), int(dim*2**0), kernel_size=1, bias=bias)
self.decoder_level1 = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), vssm_expansion_factor=vssm_expansion_factor, ffn_expansion_factor = ffn_expansion_factor,
bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
# self.refinement = nn.Sequential(*[Strip_VSSB(dim=int(dim*2**0), expansion_factor=expansion_factor, bias=bias, ssm=True, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
self.output = nn.Conv2d(int(dim*2**0), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
def forward(self, inp_img):
# Encoder
inp_enc_level1 = self.patch_embed(inp_img)
out_enc_level1 = self.encoder_level1(inp_enc_level1)
out_enc_level1_2 = F.interpolate(out_enc_level1, scale_factor=0.5) # dim*2, lvl1 down-scaled to lvl2
inp_enc_level2 = self.down1_2(out_enc_level1)
out_enc_level2 = self.encoder_level2(inp_enc_level2)
out_enc_level2_1 = F.interpolate(out_enc_level2, scale_factor=2) # dim*2, lvl2 up-scaled to lvl1
inp_enc_level3 = self.down2_3(out_enc_level2)
out_enc_level3 = self.encoder_level3(inp_enc_level3)
out_enc_level3_2 = F.interpolate(out_enc_level3, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl2 (lvl3->lvl2)
out_enc_level3_1 = F.interpolate(out_enc_level3_2, scale_factor=2) # dim*2**2, lvl3 up-scaled to lvl1 (lvl3->lvl2->lvl1)
out_enc_level1 = self.clff_level1(out_enc_level1, out_enc_level2_1, out_enc_level3_1)
out_enc_level2 = self.clff_level2(out_enc_level2, out_enc_level1_2, out_enc_level3_2)
# Decoder
out_dec_level3_decomp1 = self.decoder_level3(out_enc_level3)
inp_dec_level2_decomp1 = self.up3_2(out_dec_level3_decomp1)
inp_dec_level2_decomp1 = self.reduce_chan_level2(torch.cat((inp_dec_level2_decomp1, out_enc_level2), dim=1))
out_dec_level2_decomp1 = self.decoder_level2(inp_dec_level2_decomp1)
inp_dec_level1_decomp1 = self.up2_1(out_dec_level2_decomp1)
inp_dec_level1_decomp1 = self.reduce_chan_level1(torch.cat((inp_dec_level1_decomp1, out_enc_level1), dim=1))
out_dec_level1_decomp1 = self.decoder_level1(inp_dec_level1_decomp1)
out_dec_level1_decomp1 = self.output(out_dec_level1_decomp1)
out_dec_level1 = out_dec_level1_decomp1 + inp_img
return out_dec_level1, out_dec_level1_decomp1, None
def count_parameters(model):
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total:,}")
print(f"Trainable parameters: {trainable:,}")
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = XYScanNetP().to(device)
print("Model architecture:\n")
print(model)
count_parameters(model)
# Optionally test with a dummy input
dummy_input = torch.randn(1, 3, 256, 256).to(device)
output, _, _ = model(dummy_input)
print(f"Output shape: {output.shape}")
if __name__ == "__main__":
main() |