herrscher0's picture
Initial commit: FloodDiffusion text-to-motion generation model
ebc7f2e verified
# This module uses modified code from Alibaba Wan Team
# Original source: https://github.com/Wan-Video/Wan2.2
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# Modified to support 1d features with (B, C, T)
import torch
import torch.nn as nn
import torch.nn.functional as F
CACHE_T = 2
class CausalConv1d(nn.Conv1d):
"""
Causal 1d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (
2 * self.padding[0],
0,
)
self.padding = (0,)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[0] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[0] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, bias=False):
super().__init__()
broadcastable_dims = (1,)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return (
F.normalize(x, dim=(1 if self.channel_first else -1))
* self.scale
* self.gamma
+ self.bias
)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in (
"upsample1d",
"downsample1d",
)
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample1d":
self.time_conv = CausalConv1d(dim, dim * 2, (3,), padding=(1,))
elif mode == "downsample1d":
self.time_conv = CausalConv1d(dim, dim, (3,), stride=(2,), padding=(0,))
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t = x.size()
if self.mode == "upsample1d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:].clone()
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] != "Rep"
):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] == "Rep"
):
cache_x = torch.cat(
[torch.zeros_like(cache_x).to(cache_x.device), cache_x],
dim=2,
)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t)
x = torch.stack((x[:, 0, :, :], x[:, 1, :, :]), 3)
x = x.reshape(b, c, t * 2)
if self.mode == "downsample1d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:].clone()
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim),
nn.SiLU(),
CausalConv1d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim),
nn.SiLU(),
nn.Dropout(dropout),
CausalConv1d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = (
CausalConv1d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv1d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AvgDown1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor = self.factor_t
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (pad_t, 0)
x = F.pad(x, pad)
B, C, T = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
)
x = x.permute(0, 1, 3, 2).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
)
x = x.mean(dim=2)
return x
class DupUp1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor = self.factor_t
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
x.size(2),
)
x = x.permute(0, 1, 3, 2).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
)
if first_chunk:
x = x[
:,
:,
self.factor_t - 1 :,
]
return x
class Down_ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False):
super().__init__()
# Shortcut path with downsample
if temperal_downsample:
self.avg_shortcut = AvgDown1D(
in_dim,
out_dim,
factor_t=2,
)
else:
self.avg_shortcut = None
# Main path with residual blocks and downsample
downsamples = []
for _ in range(mult):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final downsample block
if temperal_downsample:
downsamples.append(Resample(out_dim, mode="downsample1d"))
self.downsamples = nn.Sequential(*downsamples)
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
if self.avg_shortcut is None:
return x
else:
return x + self.avg_shortcut(x_copy)
class Up_ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False):
super().__init__()
# Shortcut path with upsample
if temperal_upsample:
self.avg_shortcut = DupUp1D(
in_dim,
out_dim,
factor_t=2,
)
else:
self.avg_shortcut = None
# Main path with residual blocks and upsample
upsamples = []
for _ in range(mult):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final upsample block
if temperal_upsample:
upsamples.append(Resample(out_dim, mode="upsample1d"))
self.upsamples = nn.Sequential(*upsamples)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x.clone()
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
else:
return x_main
class Encoder1d(nn.Module):
def __init__(
self,
input_dim,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv1d(input_dim, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_down_flag = (
temperal_downsample[i] if i < len(temperal_downsample) else False
)
downsamples.append(
Down_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks,
temperal_downsample=t_down_flag,
)
)
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout),
RMS_norm(out_dim),
CausalConv1d(out_dim, out_dim, 1),
ResidualBlock(out_dim, out_dim, dropout),
)
# # output blocks
self.head = nn.Sequential(
RMS_norm(out_dim),
nn.SiLU(),
CausalConv1d(out_dim, z_dim, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv1d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder1d(nn.Module):
def __init__(
self,
output_dim,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv1 = CausalConv1d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
RMS_norm(dims[0]),
CausalConv1d(dims[0], dims[0], 1),
ResidualBlock(dims[0], dims[0], dropout),
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
upsamples.append(
Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks + 1,
temperal_upsample=t_up_flag,
)
)
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim),
nn.SiLU(),
CausalConv1d(out_dim, output_dim, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv1d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv1d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv1d):
count += 1
return count
class WanVAE_(nn.Module):
def __init__(
self,
input_dim,
dim=160,
dec_dim=256,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=1,
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder1d(
input_dim,
dim,
z_dim * 2,
dim_mult,
num_res_blocks,
self.temperal_downsample,
dropout,
)
self.conv1 = CausalConv1d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv1d(z_dim, z_dim, 1)
self.decoder = Decoder1d(
input_dim,
dec_dim,
z_dim,
dim_mult,
num_res_blocks,
self.temperal_upsample,
dropout,
)
def forward(self, x, scale=[0, 1]):
mu = self.encode(x, scale)
x_recon = self.decode(mu, scale)
return x_recon, mu
def encode(self, x, scale, return_dist=False):
self.clear_cache()
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1)) * scale[1].view(
1, self.z_dim, 1
)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
if return_dist:
return mu, log_var
return mu
def decode(self, z, scale):
self.clear_cache()
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i : i + 1],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i : i + 1],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
@torch.no_grad()
def stream_encode(self, x, first_chunk, scale, return_dist=False):
t = x.shape[2]
if first_chunk:
iter_ = 1 + (t - 1) // 4
else:
iter_ = t // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
if first_chunk:
out = self.encoder(
x[:, :, :1],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
out = self.encoder(
x[:, :, :4],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
if first_chunk:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 4 * i : 4 * (i + 1)],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1)) * scale[1].view(
1, self.z_dim, 1
)
else:
mu = (mu - scale[0]) * scale[1]
if return_dist:
return mu, log_var
else:
return mu
@torch.no_grad()
def stream_decode(self, z, first_chunk, scale):
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i : i + 1],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=first_chunk, # Use the external first_chunk parameter
)
else:
out_ = self.decoder(
x[:, :, i : i + 1],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=False, # Explicitly set to False for subsequent time steps within the same chunk
)
out = torch.cat([out, out_], 2)
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv1d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv1d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num