# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
from typing import Any, Dict, List

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version, logging
from einops import rearrange

from .attention_utils import attention
from .wan_animate_adapter import FaceAdapter, FaceEncoder
from .wan_animate_motion_encoder import Generator
from .wan_transformer3d import (Head, MLPProj, WanAttentionBlock, WanLayerNorm,
                                WanRMSNorm, WanSelfAttention,
                                WanTransformer3DModel, rope_apply,
                                sinusoidal_embedding_1d)
from ..utils import cfg_skip


class Wan2_2Transformer3DModel_Animate(WanTransformer3DModel):
    # _no_split_modules = ['WanAnimateAttentionBlock']
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size=(1, 2, 2),
        text_len=512,
        in_dim=36,
        dim=5120,
        ffn_dim=13824,
        freq_dim=256,
        text_dim=4096,
        out_dim=16,
        num_heads=40,
        num_layers=40,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=True,
        eps=1e-6,
        motion_encoder_dim=512,
        use_img_emb=True,
    ):
        model_type = "i2v"  # TODO: hard-coded for both the preview and official versions.
        super().__init__(model_type, patch_size, text_len, in_dim, dim,
                         ffn_dim, freq_dim, text_dim, out_dim, num_heads,
                         num_layers, window_size, qk_norm, cross_attn_norm,
                         eps)

        self.motion_encoder_dim = motion_encoder_dim
        self.use_img_emb = use_img_emb

        # Pose conditioning is patchified with the same kernel/stride as the
        # main latent video so the two token grids line up one-to-one.
        self.pose_patch_embedding = nn.Conv3d(
            16, dim, kernel_size=patch_size, stride=patch_size
        )

        # initialize weights
        self.init_weights()

        self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
        # One fuser block for every fifth transformer block; see
        # `after_transformer_block` below.
        self.face_adapter = FaceAdapter(
            heads_num=self.num_heads,
            hidden_dim=self.dim,
            num_adapter_layers=self.num_layers // 5,
        )
        self.face_encoder = FaceEncoder(
            in_dim=motion_encoder_dim,
            hidden_dim=self.dim,
            num_heads=4,
        )

    def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
        # Add the patchified pose latents onto every latent frame except the
        # first; the leading (reference) frame receives no pose residual.
        pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
        for x_, pose_latents_ in zip(x, pose_latents):
            x_[:, :, 1:] += pose_latents_

        # Encode the driving face frames into per-frame motion codes in
        # mini-batches of `encode_bs` to bound peak memory.
        b, c, T, h, w = face_pixel_values.shape
        face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
        encode_bs = 8
        face_pixel_values_tmp = []
        for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
            face_pixel_values_tmp.append(
                self.motion_encoder.get_motion(face_pixel_values[i * encode_bs:(i + 1) * encode_bs]))
        motion_vec = torch.cat(face_pixel_values_tmp)

        motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
        motion_vec = self.face_encoder(motion_vec)

        # Prepend a zero "face" token for the reference frame, which has no
        # driving motion.
        B, L, H, C = motion_vec.shape
        pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
        motion_vec = torch.cat([pad_face, motion_vec], dim=1)
        return x, motion_vec
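
    # Illustrative shape walkthrough for the face-motion path above. The
    # concrete numbers are assumptions for a 512x512 face crop with the
    # default config, not values asserted by this file:
    #
    #   face_pixel_values   (1, 3, T, 512, 512)   driving face frames
    #   -> (b t) c h w      (T, 3, 512, 512)      time folded into batch
    #   -> get_motion       (T, C_m)              per-frame motion code
    #                                             (C_m set by the motion encoder)
    #   -> face_encoder     (1, L, H, C)          motion tokens, possibly
    #                                             temporally downsampled (L <= T)
    #   -> pad + cat        (1, L + 1, H, C)      zero token for the reference
    #                                             frame prepended at index 0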

    def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
        # Fuse face motion into the hidden states after every fifth
        # transformer block, via the matching face-adapter fuser block, as a
        # residual update.
        if block_idx % 5 == 0:
            use_context_parallel = self.sp_world_size > 1
            adapter_args = [x, motion_vec, motion_masks, use_context_parallel,
                            self.all_gather, self.sp_world_size, self.sp_world_rank]
            residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
            x = residual_out + x
        return x

    @cfg_skip()
    def forward(
        self,
        x,
        t,
        clip_fea,
        context,
        seq_len,
        y=None,
        pose_latents=None,
        face_pixel_values=None,
        cond_flag=True,
    ):
        # params
        device = self.patch_embedding.weight.device
        dtype = x.dtype
        if self.freqs.device != device and torch.device(type="meta") != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)

        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        if self.sp_world_size > 1:
            # Round the padded length up to a multiple of the sequence-parallel
            # world size so every rank receives an equal chunk.
            seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if self.use_img_emb:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)

        # Context Parallel
        if self.sp_world_size > 1:
            x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
            if t.dim() != 1:
                e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
                e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]

        # TeaCache: decide whether this step can reuse the cached residual.
        if self.teacache is not None:
            if cond_flag:
                if t.dim() != 1:
                    modulated_inp = e0[0][:, -1, :]
                else:
                    modulated_inp = e0[0]
                skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
                if skip_flag:
                    # Always compute the first few warm-up steps.
                    self.should_calc = True
                    self.teacache.accumulated_rel_l1_distance = 0
                else:
                    rel_l1_distance = self.teacache.compute_rel_l1_distance(
                        self.teacache.previous_modulated_input, modulated_inp)
                    self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
                    if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
                        self.should_calc = False
                    else:
                        self.should_calc = True
                        self.teacache.accumulated_rel_l1_distance = 0
                self.teacache.previous_modulated_input = modulated_inp
                self.teacache.should_calc = self.should_calc
            else:
                # The unconditional pass reuses the decision made on the
                # conditional pass.
                self.should_calc = self.teacache.should_calc
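
        # Sketch of the skip rule above (for reference; the exact
        # `compute_rel_l1_distance` and `rescale_func` live on the TeaCache
        # helper, not in this file, so this is the usual TeaCache criterion
        # rather than code copied from that helper):
        #
        #   rel_l1 = mean(|m_t - m_{t-1}|) / mean(|m_{t-1}|)
        #   acc   += rescale_func(rel_l1)
        #   skip the transformer blocks while acc < rel_l1_thresh, then reset
        #
        # where m_t is the timestep-modulated input captured above.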

        # TeaCache: either replay the cached residual or run all blocks and
        # cache a fresh residual for the next steps.
        if self.teacache is not None:
            if not self.should_calc:
                previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
                x = x + previous_residual.to(x.device)[-x.size(0):]
            else:
                ori_x = x.clone().cpu() if self.teacache.offload else x.clone()

                for idx, block in enumerate(self.blocks):
                    if torch.is_grad_enabled() and self.gradient_checkpointing:

                        def create_custom_forward(module):
                            def custom_forward(*inputs):
                                return module(*inputs)
                            return custom_forward

                        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x,
                            e0,
                            seq_lens,
                            grid_sizes,
                            self.freqs,
                            context,
                            context_lens,
                            dtype,
                            t,
                            **ckpt_kwargs,
                        )
                        x, motion_vec = x.to(dtype), motion_vec.to(dtype)
                        x = self.after_transformer_block(idx, x, motion_vec)
                    else:
                        # arguments
                        kwargs = dict(
                            e=e0,
                            seq_lens=seq_lens,
                            grid_sizes=grid_sizes,
                            freqs=self.freqs,
                            context=context,
                            context_lens=context_lens,
                            dtype=dtype,
                            t=t,
                        )
                        x = block(x, **kwargs)
                        x, motion_vec = x.to(dtype), motion_vec.to(dtype)
                        x = self.after_transformer_block(idx, x, motion_vec)

                residual = x.cpu() - ori_x if self.teacache.offload else x - ori_x
                if cond_flag:
                    self.teacache.previous_residual_cond = residual
                else:
                    self.teacache.previous_residual_uncond = residual
        else:
            for idx, block in enumerate(self.blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        e0,
                        seq_lens,
                        grid_sizes,
                        self.freqs,
                        context,
                        context_lens,
                        dtype,
                        t,
                        **ckpt_kwargs,
                    )
                    x, motion_vec = x.to(dtype), motion_vec.to(dtype)
                    x = self.after_transformer_block(idx, x, motion_vec)
                else:
                    # arguments
                    kwargs = dict(
                        e=e0,
                        seq_lens=seq_lens,
                        grid_sizes=grid_sizes,
                        freqs=self.freqs,
                        context=context,
                        context_lens=context_lens,
                        dtype=dtype,
                        t=t,
                    )
                    x = block(x, **kwargs)
                    x, motion_vec = x.to(dtype), motion_vec.to(dtype)
                    x = self.after_transformer_block(idx, x, motion_vec)

        # head
        x = self.head(x, e)

        # Context Parallel
        if self.sp_world_size > 1:
            x = self.all_gather(x.contiguous(), dim=1)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        x = torch.stack(x)
        return x
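

# Minimal usage sketch (illustrative only; not part of the original module).
# All shapes below are assumptions for the default config with a short clip;
# a real pipeline derives x / context / clip_fea from the VAE, the T5 text
# encoder, and the CLIP image encoder, and runs on GPU:
#
#   model = Wan2_2Transformer3DModel_Animate()
#   x        = [torch.randn(36, 21, 60, 104)]            # latent video, C x T x H x W
#   t        = torch.tensor([999])                       # diffusion timestep
#   clip_fea = torch.randn(1, 257, 1280)                 # CLIP image tokens
#   context  = [torch.randn(512, 4096)]                  # padded T5 text tokens
#   pose_latents      = [torch.randn(16, 20, 60, 104)]   # one frame fewer than x
#   face_pixel_values = torch.randn(1, 3, 80, 512, 512)  # driving face frames
#   out = model(x, t, clip_fea, context, seq_len=32760,  # 21 * 30 * 52 tokens
#               pose_latents=pose_latents, face_pixel_values=face_pixel_values)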