import torch
import torch.nn as nn


class SimpleAdapter(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, stride,
                 downscale_factor=8, num_residual_blocks=1):
        super(SimpleAdapter, self).__init__()
        # Pixel unshuffle: reduce spatial dimensions by `downscale_factor` (8 by default),
        # moving the pixels into the channel dimension
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)
        # Convolution: reduces spatial dimensions further, e.g. by a factor of 2
        # (without overlap) when kernel_size=2 and stride=2
        self.conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, out_dim,
                              kernel_size=kernel_size, stride=stride, padding=0)
        # Residual blocks for feature extraction
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
        # Merge the frame dimension into the batch dimension
        bs, c, f, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
        # Pixel unshuffle: trade spatial resolution for channels
        x_unshuffled = self.pixel_unshuffle(x)
        # Strided convolution
        x_conv = self.conv(x_unshuffled)
        # Feature extraction with residual blocks
        out = self.residual_blocks(x_conv)
        # Restore the frame dimension
        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
        # Reorder to (batch, channels, frames, height, width)
        out = out.permute(0, 2, 1, 3, 4)
        return out


class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        out += residual
        return out


# Example usage
# in_dim = 3
# out_dim = 64
# adapter = SimpleAdapter(in_dim, out_dim, kernel_size=2, stride=2)
# x = torch.randn(1, in_dim, 4, 64, 64)  # batch size = 1, channels = 3, frames = 4
# output = adapter(x)
# print(output.shape)  # torch.Size([1, 64, 4, 4, 4])
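

# For reference, a minimal sketch of how shapes evolve through the adapter,
# assuming kernel_size=2 and stride=2 (the values implied by the "factor of 2
# without overlap" comment above; they are not fixed by the code itself):
#
#   input            (1, 3, 4, 64, 64)   # (batch, channels, frames, H, W)
#   merge frames     (4, 3, 64, 64)      # (batch * frames, channels, H, W)
#   pixel unshuffle  (4, 192, 8, 8)      # channels *= 8 * 8; H, W /= 8
#   conv 2x2, s=2    (4, 64, 4, 4)       # channels -> out_dim; H, W /= 2
#   residual blocks  (4, 64, 4, 4)       # shape-preserving
#   split frames     (1, 4, 64, 4, 4)    # restore the frame dimension
#   final permute    (1, 64, 4, 4, 4)    # (batch, out_dim, frames, H', W')

if __name__ == "__main__":
    # Runnable sanity check under the assumptions above
    adapter = SimpleAdapter(in_dim=3, out_dim=64, kernel_size=2, stride=2)
    x = torch.randn(1, 3, 4, 64, 64)
    out = adapter(x)
    assert out.shape == (1, 64, 4, 4, 4), out.shape
    print(out.shape)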