# videox_fun/models/wan_camera_adapter.py
import torch
import torch.nn as nn


class SimpleAdapter(nn.Module):
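    """Folds the frame axis into the batch, downsamples each frame with
    PixelUnshuffle and a strided conv, refines the result with residual
    blocks, and restores the (batch, channels, frames, H, W) layout."""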
def __init__(self, in_dim, out_dim, kernel_size, stride, downscale_factor=8, num_residual_blocks=1):
        super().__init__()
        # Pixel Unshuffle: fold each downscale_factor x downscale_factor pixel
        # block into channels, shrinking H and W by downscale_factor
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)
        # Strided convolution: project the unshuffled channels down to out_dim
        # and downsample further; with padding=0 the windows do not overlap
        # when kernel_size == stride
        self.conv = nn.Conv2d(
            in_dim * downscale_factor * downscale_factor, out_dim,
            kernel_size=kernel_size, stride=stride, padding=0,
        )
# Residual blocks for feature extraction
self.residual_blocks = nn.Sequential(
*[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
# Reshape to merge the frame dimension into batch
bs, c, f, h, w = x.size()
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
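        # x is now (bs * f, c, h, w): each frame is a separate batch item for
        # the 2D ops below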
# Pixel Unshuffle operation
x_unshuffled = self.pixel_unshuffle(x)
# Convolution operation
x_conv = self.conv(x_unshuffled)
# Feature extraction with residual blocks
out = self.residual_blocks(x_conv)
        # Restore the separate batch and frame dimensions
        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
        # Move channels back in front of frames: (bs, out_dim, f, h', w')
        out = out.permute(0, 2, 1, 3, 4)
        return out


class ResidualBlock(nn.Module):
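    """Two 3x3 convolutions with a ReLU in between and an identity skip."""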
def __init__(self, dim):
        super().__init__()
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
residual = x
out = self.relu(self.conv1(x))
out = self.conv2(out)
out += residual
return out
# Example usage: a minimal smoke test of the shape contract.
# kernel_size=stride=2 matches the "factor of 2, no overlap" comment above.
if __name__ == "__main__":
    in_dim = 3
    out_dim = 64
    adapter = SimpleAdapter(in_dim, out_dim, kernel_size=2, stride=2)
    x = torch.randn(1, in_dim, 4, 64, 64)  # batch = 1, channels = 3, frames = 4
    output = adapter(x)
    print(output.shape)  # torch.Size([1, 64, 4, 4, 4]): 64 -> 8 (unshuffle) -> 4 (conv)