# videox_fun/models/fantasytalking_transformer3d.py
# Modified from https://github.com/Fantasy-AMAP/fantasy-talking/blob/main/diffsynth/models
# Copyright Alibaba Inc. All Rights Reserved.
import math
import os
from typing import Any, Dict
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import register_to_config
from diffusers.utils import is_torch_version
from ..dist import sequence_parallel_all_gather, sequence_parallel_chunk
from ..utils import cfg_skip
from .attention_utils import attention
from .wan_transformer3d import (WanAttentionBlock, WanLayerNorm, WanRMSNorm,
WanSelfAttention, WanTransformer3DModel,
sinusoidal_embedding_1d)


class AudioProjModel(nn.Module):
    """Projects raw audio features (e.g. wav2vec hidden states) to the cross-attention dimension."""
def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, audio_embeds):
context_tokens = self.proj(audio_embeds)
context_tokens = self.norm(context_tokens)
return context_tokens # [B,L,C]
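
# Example usage of AudioProjModel (a minimal sketch; the sizes are illustrative and match the
# wav2vec feature dimension used by the main model further below):
#   proj = AudioProjModel(audio_in_dim=768, cross_attention_dim=2048)
#   audio_tokens = proj(torch.randn(1, 173, 768))  # -> [1, 173, 2048]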


class AudioCrossAttentionProcessor(nn.Module):
    """Attention processor that adds a zero-initialized audio cross-attention branch to the Wan i2v cross-attention."""
def __init__(self, context_dim, hidden_dim):
super().__init__()
self.context_dim = context_dim
self.hidden_dim = hidden_dim
self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
nn.init.zeros_(self.k_proj.weight)
nn.init.zeros_(self.v_proj.weight)
self.sp_world_size = 1
self.sp_world_rank = 0
self.all_gather = None

    def __call__(
self,
attn: nn.Module,
x: torch.Tensor,
context: torch.Tensor,
context_lens: torch.Tensor,
        audio_proj: torch.Tensor = None,
        audio_context_lens: torch.Tensor = None,
latents_num_frames: int = 21,
audio_scale: float = 1.0,
) -> torch.Tensor:
"""
x: [B, L1, C].
context: [B, L2, C].
context_lens: [B].
audio_proj: [B, 21, L3, C]
audio_context_lens: [B*21].
"""
        # The first 257 context tokens are the CLIP image tokens prepended in `forward`;
        # the remaining tokens are the text embeddings.
        context_img = context[:, :257]
        context = context[:, 257:]
b, n, d = x.size(0), attn.num_heads, attn.head_dim
# Compute query, key, value
q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
v = attn.v(context).view(b, -1, n, d)
k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
v_img = attn.v_img(context_img).view(b, -1, n, d)
img_x = attention(q, k_img, v_img, k_lens=None)
# Compute attention
x = attention(q, k, v, k_lens=context_lens)
x = x.flatten(2)
img_x = img_x.flatten(2)
        # Audio branch; when no audio projection is provided, the audio contribution is zero.
        audio_x = 0.0
        if audio_proj is not None and len(audio_proj.shape) == 4:
if self.sp_world_size > 1:
q = self.all_gather(q, dim=1)
length = int(np.floor(q.size()[1] / latents_num_frames) * latents_num_frames)
origin_length = q.size()[1]
if origin_length > length:
q_pad = q[:, length:]
q = q[:, :length]
            audio_q = q.view(b * latents_num_frames, -1, n, d)  # [B * latents_num_frames, L1 / latents_num_frames, N, D]
ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
audio_x = attention(
audio_q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL"
)
audio_x = audio_x.view(b, q.size(1), n, d)
            if self.sp_world_size > 1:
                # Restore the tail trimmed before the per-frame view so the length matches the
                # gathered queries, then take this rank's chunk again (undoing the all_gather).
                if origin_length > length:
                    audio_x = torch.cat([audio_x, q_pad], dim=1)
                audio_x = torch.chunk(audio_x, self.sp_world_size, dim=1)[self.sp_world_rank]
audio_x = audio_x.flatten(2)
        elif audio_proj is not None and len(audio_proj.shape) == 3:
ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
audio_x = attention(q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL")
audio_x = audio_x.flatten(2)
        # Output: combine text, image and (scaled) audio attention results, then apply the output projection.
        if isinstance(audio_scale, torch.Tensor):
            audio_scale = audio_scale[:, None, None]
        x = x + img_x + audio_x * audio_scale
        x = attn.o(x)
return x


class AudioCrossAttention(WanSelfAttention):
    """Cross-attention over text and CLIP image tokens, extended with an audio branch via its processor."""
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
super().__init__(dim, num_heads, window_size, qk_norm, eps)
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        # context_dim=2048 matches the cross_attention_dim of AudioProjModel created by the main model.
        self.processor = AudioCrossAttentionProcessor(2048, dim)

    def forward(
self,
x,
context,
context_lens,
audio_proj,
audio_context_lens,
latents_num_frames,
audio_scale: float = 1.0,
**kwargs,
):
"""
x: [B, L1, C].
context: [B, L2, C].
context_lens: [B].
"""
if audio_proj is None:
return self.processor(self, x, context, context_lens)
else:
return self.processor(
self,
x,
context,
context_lens,
audio_proj,
audio_context_lens,
latents_num_frames,
audio_scale,
)


class AudioAttentionBlock(nn.Module):
    """Wan attention block whose cross-attention additionally attends to per-frame audio tokens."""
def __init__(
self,
        cross_attn_type,  # Unused; kept for signature compatibility with WanAttentionBlock
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# Layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
self.norm3 = (
WanLayerNorm(dim, eps, elementwise_affine=True)
if cross_attn_norm
else nn.Identity()
)
self.cross_attn = AudioCrossAttention(
dim, num_heads, (-1, -1), qk_norm, eps
)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim),
nn.GELU(approximate="tanh"),
nn.Linear(ffn_dim, dim),
)
# Modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
audio_proj=None,
audio_context_lens=None,
audio_scale=1,
dtype=torch.bfloat16,
t=0,
):
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
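        # Modulation layout (a descriptive note, inferred from how `e` is used below):
        #   e[0], e[1], e[2] -> shift, scale and gate for the self-attention branch,
        #   e[3], e[4], e[5] -> shift, scale and gate for the FFN branch.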
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs, dtype, t=t
)
with amp.autocast(dtype=torch.float32):
x = x + y * e[2]
# Cross-attention & FFN function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(
self.norm3(x), context, context_lens, dtype=dtype, t=t,
audio_proj=audio_proj, audio_context_lens=audio_context_lens, audio_scale=audio_scale,
latents_num_frames=grid_sizes[0][0],
)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
with amp.autocast(dtype=torch.float32):
x = x + y * e[5]
return x
x = cross_attn_ffn(x, context, context_lens, e)
return x


class FantasyTalkingTransformer3DModel(WanTransformer3DModel):
    """Wan 3D transformer with FantasyTalking-style audio conditioning injected through per-block cross-attention."""
@register_to_config
def __init__(self,
model_type='i2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
cross_attn_type=None,
audio_in_dim=768):
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
if cross_attn_type is None:
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
AudioAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps)
for _ in range(num_layers)
])
for layer_idx, block in enumerate(self.blocks):
block.self_attn.layer_idx = layer_idx
block.self_attn.num_layers = self.num_layers
        # Project raw audio features to the 2048-dim context consumed by the audio cross-attention in every block.
        self.proj_model = AudioProjModel(audio_in_dim, 2048)

    def split_audio_sequence(self, audio_proj_length, num_frames=81):
"""
Map the audio feature sequence to corresponding latent frame slices.
Args:
audio_proj_length (int): The total length of the audio feature sequence
(e.g., 173 in audio_proj[1, 173, 768]).
num_frames (int): The number of video frames in the training data (default: 81).
Returns:
list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
(within the audio feature sequence) corresponding to a latent frame.
"""
# Average number of tokens per original video frame
tokens_per_frame = audio_proj_length / num_frames
# Each latent frame covers 4 video frames, and we want the center
tokens_per_latent_frame = tokens_per_frame * 4
half_tokens = int(tokens_per_latent_frame / 2)
pos_indices = []
for i in range(int((num_frames - 1) / 4) + 1):
if i == 0:
pos_indices.append(0)
else:
start_token = tokens_per_frame * ((i - 1) * 4 + 1)
end_token = tokens_per_frame * (i * 4 + 1)
center_token = int((start_token + end_token) / 2) - 1
pos_indices.append(center_token)
# Build index ranges centered around each position
pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
# Adjust the first range to avoid negative start index
pos_idx_ranges[0] = [
-(half_tokens * 2 - pos_idx_ranges[1][0]),
pos_idx_ranges[1][0],
]
return pos_idx_ranges

    def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
"""
Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
if the range exceeds the input boundaries.
Args:
input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
expand_length (int): Number of tokens to expand on both sides of each subsequence.
Returns:
sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
Each element is a padded subsequence.
k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
Useful for ignoring padding tokens in attention masks.
"""
pos_idx_ranges = [
[idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
]
sub_sequences = []
seq_len = input_tensor.size(1) # 173
max_valid_idx = seq_len - 1 # 172
k_lens_list = []
for start, end in pos_idx_ranges:
# Calculate the fill amount
pad_front = max(-start, 0)
pad_back = max(end - max_valid_idx, 0)
# Calculate the start and end indices of the valid part
valid_start = max(start, 0)
valid_end = min(end, max_valid_idx)
# Extract the valid part
if valid_start <= valid_end:
valid_part = input_tensor[:, valid_start : valid_end + 1, :]
else:
valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))
            # Right-pad along the sequence dimension (dim 1) so every subsequence has the same length
padded_subseq = F.pad(
valid_part,
(0, 0, 0, pad_back + pad_front, 0, 0),
mode="constant",
value=0,
)
k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
sub_sequences.append(padded_subseq)
return torch.stack(sub_sequences, dim=1), torch.tensor(
k_lens_list, dtype=torch.long
)

    def enable_multi_gpus_inference(self):
        # Propagate the sequence-parallel settings to every audio cross-attention processor.
        super().enable_multi_gpus_inference()
for name, module in self.named_modules():
if module.__class__.__name__ == 'AudioCrossAttentionProcessor':
module.sp_world_size = self.sp_world_size
module.sp_world_rank = self.sp_world_rank
module.all_gather = self.all_gather

    @cfg_skip()
def forward(
self,
x,
t,
context,
seq_len,
audio_wav2vec_fea=None,
clip_fea=None,
y=None,
audio_scale=1,
cond_flag=True
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
            audio_wav2vec_fea (Tensor, *optional*):
                Wav2vec audio features of shape [B, L_audio, audio_in_dim]; projected by `proj_model` and
                injected through the per-block audio cross-attention
            clip_fea (Tensor, *optional*):
                CLIP image features for image-to-video mode
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x
            audio_scale (float or Tensor, *optional*, defaults to 1):
                Scale applied to the audio cross-attention output
            cond_flag (bool, *optional*, defaults to True):
                Whether this call is the conditional branch; used for TeaCache bookkeeping
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
        # Wan2.2 does not need a CLIP encoder.
# if self.model_type == 'i2v':
# assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
dtype = x.dtype
if self.freqs.device != device and torch.device(type="meta") != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
if self.sp_world_size > 1:
seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
if t.dim() != 1:
if t.size(1) < seq_len:
pad_size = seq_len - t.size(1)
last_elements = t[:, -1].unsqueeze(1)
padding = last_elements.repeat(1, pad_size)
t = torch.cat([t, padding], dim=1)
bt = t.size(0)
ft = t.flatten()
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim,
ft).unflatten(0, (bt, seq_len)).float())
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
else:
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# assert e.dtype == torch.float32 and e0.dtype == torch.float32
# e0 = e0.to(dtype)
# e = e.to(dtype)
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
num_frames = (grid_sizes[0][0] - 1) * 4 + 1
audio_proj_fea = self.proj_model(audio_wav2vec_fea)
pos_idx_ranges = self.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames)
audio_proj, audio_context_lens = self.split_tensor_with_padding(
audio_proj_fea, pos_idx_ranges, expand_length=4
)
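        # At this point audio_proj has shape [B, F_latent, L_sub, 2048] -- one padded audio window per
        # latent frame -- and audio_context_lens holds each window's unpadded length.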
# Context Parallel
if self.sp_world_size > 1:
x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
if t.dim() != 1:
e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
        # TeaCache: track how much the time-modulated input changed since the previous step to decide
        # whether the transformer blocks can be skipped for this step.
if self.teacache is not None:
if cond_flag:
if t.dim() != 1:
modulated_inp = e0[:, -1, :]
else:
modulated_inp = e0
skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
if skip_flag:
self.should_calc = True
self.teacache.accumulated_rel_l1_distance = 0
else:
if cond_flag:
rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
self.should_calc = False
else:
self.should_calc = True
self.teacache.accumulated_rel_l1_distance = 0
self.teacache.previous_modulated_input = modulated_inp
self.teacache.should_calc = self.should_calc
else:
self.should_calc = self.teacache.should_calc
        # TeaCache: either reuse the cached block residual or run all blocks and cache the new residual.
if self.teacache is not None:
if not self.should_calc:
previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
x = x + previous_residual.to(x.device)[-x.size()[0]:,]
else:
ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
for block in self.blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
e0,
seq_lens,
grid_sizes,
self.freqs,
context,
context_lens,
audio_proj,
audio_context_lens,
audio_scale,
dtype,
t,
**ckpt_kwargs,
)
else:
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens,
audio_proj=audio_proj,
audio_context_lens=audio_context_lens,
audio_scale=audio_scale,
dtype=dtype,
t=t
)
x = block(x, **kwargs)
if cond_flag:
self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
else:
self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
else:
for block in self.blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
e0,
seq_lens,
grid_sizes,
self.freqs,
context,
context_lens,
audio_proj,
audio_context_lens,
audio_scale,
dtype,
t,
**ckpt_kwargs,
)
else:
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens,
audio_proj=audio_proj,
audio_context_lens=audio_context_lens,
audio_scale=audio_scale,
dtype=dtype,
t=t
)
x = block(x, **kwargs)
# head
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
else:
x = self.head(x, e)
if self.sp_world_size > 1:
x = self.all_gather(x, dim=1)
# Unpatchify
x = self.unpatchify(x, grid_sizes)
x = torch.stack(x)
if self.teacache is not None and cond_flag:
self.teacache.cnt += 1
if self.teacache.cnt == self.teacache.num_steps:
self.teacache.reset()
return x