# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import glob import json import math import os import types import warnings from typing import Any, Dict, Optional, Union import numpy as np import torch import torch.cuda.amp as amp import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders.single_file_model import FromOriginalModelMixin from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import is_torch_version, logging from torch import nn from ..dist import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group, usp_attn_forward, xFuserLongContextAttention) from ..utils import cfg_skip from .attention_utils import attention from .cache_utils import TeaCache from .wan_camera_adapter import SimpleAdapter def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 half = dim // 2 position = position.type(torch.float64) # calculation sinusoid = torch.outer( position, torch.pow(10000, -torch.arange(half).to(position).div(half))) x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) return x @amp.autocast(enabled=False) def rope_params(max_seq_len, dim, theta=10000): assert dim % 2 == 0 freqs = torch.outer( torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))) freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs # modified from https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py @amp.autocast(enabled=False) def get_1d_rotary_pos_embed_riflex( pos: Union[np.ndarray, int], dim: int, theta: float = 10000.0, use_real=False, k: Optional[int] = None, L_test: Optional[int] = None, L_test_scale: Optional[int] = None, ): """ RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. Args: dim (`int`): Dimension of the frequency tensor. pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar theta (`float`, *optional*, defaults to 10000.0): Scaling factor for frequency computation. Defaults to 10000.0. use_real (`bool`, *optional*): If True, return real part and imaginary part separately. Otherwise, return complex numbers. k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE L_test (`int`, *optional*, defaults to None): the number of frames for inference Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] """ assert dim % 2 == 0 if isinstance(pos, int): pos = torch.arange(pos) if isinstance(pos, np.ndarray): pos = torch.from_numpy(pos) # type: ignore # [S] freqs = 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)) # === Riflex modification start === # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)). # Empirical observations show that a few videos may exhibit repetition in the tail frames. # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period. if k is not None: freqs[k-1] = 0.9 * 2 * torch.pi / L_test # === Riflex modification end === if L_test_scale is not None: freqs[k-1] = freqs[k-1] / L_test_scale freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] if use_real: freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin else: # lumina freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] return freqs_cis # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): tw = tgt_width th = tgt_height h, w = src r = h / w if r > (th / tw): resize_height = th resize_width = int(round(th / h * w)) else: resize_width = tw resize_height = int(round(tw / w * h)) crop_top = int(round((th - resize_height) / 2.0)) crop_left = int(round((tw - resize_width) / 2.0)) return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) @amp.autocast(enabled=False) @torch.compiler.disable() def rope_apply(x, grid_sizes, freqs): n, c = x.size(2), x.size(3) // 2 # split freqs freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) # loop over samples output = [] for i, (f, h, w) in enumerate(grid_sizes.tolist()): seq_len = f * h * w # precompute multipliers x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape( seq_len, n, -1, 2)) freqs_i = torch.cat([ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) ], dim=-1).reshape(seq_len, 1, -1) # apply rotary embedding x_i = torch.view_as_real(x_i * freqs_i).flatten(2) x_i = torch.cat([x_i, x[i, seq_len:]]) # append to collection output.append(x_i) return torch.stack(output).to(x.dtype) def rope_apply_qk(q, k, grid_sizes, freqs): q = rope_apply(q, grid_sizes, freqs) k = rope_apply(k, grid_sizes, freqs) return q, k class WanRMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): r""" Args: x(Tensor): Shape [B, L, C] """ return self._norm(x) * self.weight def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(x.dtype) class WanLayerNorm(nn.LayerNorm): def __init__(self, dim, eps=1e-6, elementwise_affine=False): super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) def forward(self, x): r""" Args: x(Tensor): Shape [B, L, C] """ return super().forward(x) class WanSelfAttention(nn.Module): def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): assert dim % num_heads == 0 super().__init__() self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.window_size = window_size self.qk_norm = qk_norm self.eps = eps # layers self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.o = nn.Linear(dim, dim) self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] seq_lens(Tensor): Shape [B] grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim # query, key, value function def qkv_fn(x): q = self.norm_q(self.q(x.to(dtype))).view(b, s, n, d) k = self.norm_k(self.k(x.to(dtype))).view(b, s, n, d) v = self.v(x.to(dtype)).view(b, s, n, d) return q, k, v q, k, v = qkv_fn(x) q, k = rope_apply_qk(q, k, grid_sizes, freqs) x = attention( q.to(dtype), k.to(dtype), v=v.to(dtype), k_lens=seq_lens, window_size=self.window_size) x = x.to(dtype) # output x = x.flatten(2) x = self.o(x) return x class WanT2VCrossAttention(WanSelfAttention): def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] context_lens(Tensor): Shape [B] """ b, n, d = x.size(0), self.num_heads, self.head_dim # compute query, key, value q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d) k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d) v = self.v(context.to(dtype)).view(b, -1, n, d) # compute attention x = attention( q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens ) x = x.to(dtype) # output x = x.flatten(2) x = self.o(x) return x class WanI2VCrossAttention(WanSelfAttention): def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): super().__init__(dim, num_heads, window_size, qk_norm, eps) self.k_img = nn.Linear(dim, dim) self.v_img = nn.Linear(dim, dim) # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] context_lens(Tensor): Shape [B] """ context_img = context[:, :257] context = context[:, 257:] b, n, d = x.size(0), self.num_heads, self.head_dim # compute query, key, value q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d) k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d) v = self.v(context.to(dtype)).view(b, -1, n, d) k_img = self.norm_k_img(self.k_img(context_img.to(dtype))).view(b, -1, n, d) v_img = self.v_img(context_img.to(dtype)).view(b, -1, n, d) img_x = attention( q.to(dtype), k_img.to(dtype), v_img.to(dtype), k_lens=None ) img_x = img_x.to(dtype) # compute attention x = attention( q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens ) x = x.to(dtype) # output x = x.flatten(2) img_x = img_x.flatten(2) x = x + img_x x = self.o(x) return x class WanCrossAttention(WanSelfAttention): def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] context_lens(Tensor): Shape [B] """ b, n, d = x.size(0), self.num_heads, self.head_dim # compute query, key, value q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d) k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d) v = self.v(context.to(dtype)).view(b, -1, n, d) # compute attention x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens) # output x = x.flatten(2) x = self.o(x.to(dtype)) return x WAN_CROSSATTENTION_CLASSES = { 't2v_cross_attn': WanT2VCrossAttention, 'i2v_cross_attn': WanI2VCrossAttention, 'cross_attn': WanCrossAttention, } class WanAttentionBlock(nn.Module): def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, eps=1e-6): super().__init__() self.dim = dim self.ffn_dim = ffn_dim self.num_heads = num_heads self.window_size = window_size self.qk_norm = qk_norm self.cross_attn_norm = cross_attn_norm self.eps = eps # layers self.norm1 = WanLayerNorm(dim, eps) self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) self.norm3 = WanLayerNorm( dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps) self.norm2 = WanLayerNorm(dim, eps) self.ffn = nn.Sequential( nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), nn.Linear(ffn_dim, dim)) # modulation self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) def forward( self, x, e, seq_lens, grid_sizes, freqs, context, context_lens, dtype=torch.bfloat16, t=0, ): r""" Args: x(Tensor): Shape [B, L, C] e(Tensor): Shape [B, 6, C] seq_lens(Tensor): Shape [B], length of each sequence in batch grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ if e.dim() > 3: e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2) e = [e.squeeze(2) for e in e] else: e = (self.modulation + e).chunk(6, dim=1) # self-attention temp_x = self.norm1(x) * (1 + e[1]) + e[0] temp_x = temp_x.to(dtype) y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype, t=t) x = x + y * e[2] # cross-attention & ffn function def cross_attn_ffn(x, context, context_lens, e): # cross-attention x = x + self.cross_attn(self.norm3(x), context, context_lens, dtype, t=t) # ffn function temp_x = self.norm2(x) * (1 + e[4]) + e[3] temp_x = temp_x.to(dtype) y = self.ffn(temp_x) x = x + y * e[5] return x x = cross_attn_ffn(x, context, context_lens, e) return x class Head(nn.Module): def __init__(self, dim, out_dim, patch_size, eps=1e-6): super().__init__() self.dim = dim self.out_dim = out_dim self.patch_size = patch_size self.eps = eps # layers out_dim = math.prod(patch_size) * out_dim self.norm = WanLayerNorm(dim, eps) self.head = nn.Linear(dim, out_dim) # modulation self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) def forward(self, x, e): r""" Args: x(Tensor): Shape [B, L1, C] e(Tensor): Shape [B, C] """ if e.dim() > 2: e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2) e = [e.squeeze(2) for e in e] else: e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) return x class MLPProj(torch.nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.proj = torch.nn.Sequential( torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), torch.nn.LayerNorm(out_dim)) def forward(self, image_embeds): clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens class WanTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" Wan diffusion backbone supporting both text-to-video and image-to-video. """ # ignore_for_config = [ # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' # ] # _no_split_modules = ['WanAttentionBlock'] _supports_gradient_checkpointing = True @register_to_config def __init__( self, model_type='t2v', patch_size=(1, 2, 2), text_len=512, in_dim=16, dim=2048, ffn_dim=8192, freq_dim=256, text_dim=4096, out_dim=16, num_heads=16, num_layers=32, window_size=(-1, -1), qk_norm=True, cross_attn_norm=True, eps=1e-6, in_channels=16, hidden_size=2048, add_control_adapter=False, in_dim_control_adapter=24, downscale_factor_control_adapter=8, add_ref_conv=False, in_dim_ref_conv=16, cross_attn_type=None, ): r""" Initialize the diffusion model backbone. Args: model_type (`str`, *optional*, defaults to 't2v'): Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) text_len (`int`, *optional*, defaults to 512): Fixed length for text embeddings in_dim (`int`, *optional*, defaults to 16): Input video channels (C_in) dim (`int`, *optional*, defaults to 2048): Hidden dimension of the transformer ffn_dim (`int`, *optional*, defaults to 8192): Intermediate dimension in feed-forward network freq_dim (`int`, *optional*, defaults to 256): Dimension for sinusoidal time embeddings text_dim (`int`, *optional*, defaults to 4096): Input dimension for text embeddings out_dim (`int`, *optional*, defaults to 16): Output video channels (C_out) num_heads (`int`, *optional*, defaults to 16): Number of attention heads num_layers (`int`, *optional*, defaults to 32): Number of transformer blocks window_size (`tuple`, *optional*, defaults to (-1, -1)): Window size for local attention (-1 indicates global attention) qk_norm (`bool`, *optional*, defaults to True): Enable query/key normalization cross_attn_norm (`bool`, *optional*, defaults to False): Enable cross-attention normalization eps (`float`, *optional*, defaults to 1e-6): Epsilon value for normalization layers """ super().__init__() # assert model_type in ['t2v', 'i2v', 'ti2v'] self.model_type = model_type self.patch_size = patch_size self.text_len = text_len self.in_dim = in_dim self.dim = dim self.ffn_dim = ffn_dim self.freq_dim = freq_dim self.text_dim = text_dim self.out_dim = out_dim self.num_heads = num_heads self.num_layers = num_layers self.window_size = window_size self.qk_norm = qk_norm self.cross_attn_norm = cross_attn_norm self.eps = eps # embeddings self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) self.text_embedding = nn.Sequential( nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim)) self.time_embedding = nn.Sequential( nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) # blocks if cross_attn_type is None: cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' self.blocks = nn.ModuleList([ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) for _ in range(num_layers) ]) for layer_idx, block in enumerate(self.blocks): block.self_attn.layer_idx = layer_idx block.self_attn.num_layers = self.num_layers # head self.head = Head(dim, out_dim, patch_size, eps) # buffers (don't use register_buffer otherwise dtype will be changed in to()) assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 d = dim // num_heads self.d = d self.dim = dim self.freqs = torch.cat( [ rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)) ], dim=1 ) if model_type == 'i2v': self.img_emb = MLPProj(1280, dim) if add_control_adapter: self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], downscale_factor=downscale_factor_control_adapter) else: self.control_adapter = None if add_ref_conv: self.ref_conv = nn.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) else: self.ref_conv = None self.teacache = None self.cfg_skip_ratio = None self.current_steps = 0 self.num_inference_steps = None self.gradient_checkpointing = False self.all_gather = None self.sp_world_size = 1 self.sp_world_rank = 0 self.init_weights() def _set_gradient_checkpointing(self, *args, **kwargs): if "value" in kwargs: self.gradient_checkpointing = kwargs["value"] if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"): self.motioner.gradient_checkpointing = kwargs["value"] elif "enable" in kwargs: self.gradient_checkpointing = kwargs["enable"] if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"): self.motioner.gradient_checkpointing = kwargs["enable"] else: raise ValueError("Invalid set gradient checkpointing") def enable_teacache( self, coefficients, num_steps: int, rel_l1_thresh: float, num_skip_start_steps: int = 0, offload: bool = True, ): self.teacache = TeaCache( coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload ) def share_teacache( self, transformer = None, ): self.teacache = transformer.teacache def disable_teacache(self): self.teacache = None def enable_cfg_skip(self, cfg_skip_ratio, num_steps): if cfg_skip_ratio != 0: self.cfg_skip_ratio = cfg_skip_ratio self.current_steps = 0 self.num_inference_steps = num_steps else: self.cfg_skip_ratio = None self.current_steps = 0 self.num_inference_steps = None def share_cfg_skip( self, transformer = None, ): self.cfg_skip_ratio = transformer.cfg_skip_ratio self.current_steps = transformer.current_steps self.num_inference_steps = transformer.num_inference_steps def disable_cfg_skip(self): self.cfg_skip_ratio = None self.current_steps = 0 self.num_inference_steps = None def enable_riflex( self, k = 6, L_test = 66, L_test_scale = 4.886, ): device = self.freqs.device self.freqs = torch.cat( [ get_1d_rotary_pos_embed_riflex(1024, self.d - 4 * (self.d // 6), use_real=False, k=k, L_test=L_test, L_test_scale=L_test_scale), rope_params(1024, 2 * (self.d // 6)), rope_params(1024, 2 * (self.d // 6)) ], dim=1 ).to(device) def disable_riflex(self): device = self.freqs.device self.freqs = torch.cat( [ rope_params(1024, self.d - 4 * (self.d // 6)), rope_params(1024, 2 * (self.d // 6)), rope_params(1024, 2 * (self.d // 6)) ], dim=1 ).to(device) def enable_multi_gpus_inference(self,): self.sp_world_size = get_sequence_parallel_world_size() self.sp_world_rank = get_sequence_parallel_rank() self.all_gather = get_sp_group().all_gather # For normal model. for block in self.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) # For vace model. if hasattr(self, 'vace_blocks'): for block in self.vace_blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) @cfg_skip() def forward( self, x, t, context, seq_len, clip_fea=None, y=None, y_camera=None, full_ref=None, subject_ref=None, cond_flag=True, ): r""" Forward pass through the diffusion model Args: x (List[Tensor]): List of input video tensors, each with shape [C_in, F, H, W] t (Tensor): Diffusion timesteps tensor of shape [B] context (List[Tensor]): List of text embeddings each with shape [L, C] seq_len (`int`): Maximum sequence length for positional encoding clip_fea (Tensor, *optional*): CLIP image features for image-to-video mode y (List[Tensor], *optional*): Conditional video inputs for image-to-video mode, same shape as x cond_flag (`bool`, *optional*, defaults to True): Flag to indicate whether to forward the condition input Returns: List[Tensor]: List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] """ # Wan2.2 don't need a clip. # if self.model_type == 'i2v': # assert clip_fea is not None and y is not None # params device = self.patch_embedding.weight.device dtype = x.dtype if self.freqs.device != device and torch.device(type="meta") != device: self.freqs = self.freqs.to(device) if y is not None: x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] # embeddings x = [self.patch_embedding(u.unsqueeze(0)) for u in x] # add control adapter if self.control_adapter is not None and y_camera is not None: y_camera = self.control_adapter(y_camera) x = [u + v for u, v in zip(x, y_camera)] grid_sizes = torch.stack( [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) x = [u.flatten(2).transpose(1, 2) for u in x] if self.ref_conv is not None and full_ref is not None: full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) grid_sizes = torch.stack([torch.tensor([u[0] + 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device) seq_len += full_ref.size(1) x = [torch.concat([_full_ref.unsqueeze(0), u], dim=1) for _full_ref, u in zip(full_ref, x)] if t.dim() != 1 and t.size(1) < seq_len: pad_size = seq_len - t.size(1) last_elements = t[:, -1].unsqueeze(1) padding = last_elements.repeat(1, pad_size) t = torch.cat([padding, t], dim=1) if subject_ref is not None: subject_ref_frames = subject_ref.size(2) subject_ref = self.patch_embedding(subject_ref).flatten(2).transpose(1, 2) grid_sizes = torch.stack([torch.tensor([u[0] + subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device) seq_len += subject_ref.size(1) x = [torch.concat([u, _subject_ref.unsqueeze(0)], dim=1) for _subject_ref, u in zip(subject_ref, x)] if t.dim() != 1 and t.size(1) < seq_len: pad_size = seq_len - t.size(1) last_elements = t[:, -1].unsqueeze(1) padding = last_elements.repeat(1, pad_size) t = torch.cat([t, padding], dim=1) seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) if self.sp_world_size > 1: seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size assert seq_lens.max() <= seq_len x = torch.cat([ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x ]) # time embeddings with amp.autocast(dtype=torch.float32): if t.dim() != 1: if t.size(1) < seq_len: pad_size = seq_len - t.size(1) last_elements = t[:, -1].unsqueeze(1) padding = last_elements.repeat(1, pad_size) t = torch.cat([t, padding], dim=1) bt = t.size(0) ft = t.flatten() e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, ft).unflatten(0, (bt, seq_len)).float()) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) else: e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t).float()) e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # assert e.dtype == torch.float32 and e0.dtype == torch.float32 # e0 = e0.to(dtype) # e = e.to(dtype) # context context_lens = None context = self.text_embedding( torch.stack([ torch.cat( [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context ])) if clip_fea is not None: context_clip = self.img_emb(clip_fea) # bs x 257 x dim context = torch.concat([context_clip, context], dim=1) # Context Parallel if self.sp_world_size > 1: x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank] if t.dim() != 1: e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank] e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank] # TeaCache if self.teacache is not None: if cond_flag: if t.dim() != 1: modulated_inp = e0[:, -1, :] else: modulated_inp = e0 skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps if skip_flag: self.should_calc = True self.teacache.accumulated_rel_l1_distance = 0 else: if cond_flag: rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp) self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance) if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh: self.should_calc = False else: self.should_calc = True self.teacache.accumulated_rel_l1_distance = 0 self.teacache.previous_modulated_input = modulated_inp self.teacache.should_calc = self.should_calc else: self.should_calc = self.teacache.should_calc # TeaCache if self.teacache is not None: if not self.should_calc: previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond x = x + previous_residual.to(x.device)[-x.size()[0]:,] else: ori_x = x.clone().cpu() if self.teacache.offload else x.clone() for block in self.blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, e0, seq_lens, grid_sizes, self.freqs, context, context_lens, dtype, t, **ckpt_kwargs, ) else: # arguments kwargs = dict( e=e0, seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=self.freqs, context=context, context_lens=context_lens, dtype=dtype, t=t ) x = block(x, **kwargs) if cond_flag: self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x else: self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x else: for block in self.blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, e0, seq_lens, grid_sizes, self.freqs, context, context_lens, dtype, t, **ckpt_kwargs, ) else: # arguments kwargs = dict( e=e0, seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=self.freqs, context=context, context_lens=context_lens, dtype=dtype, t=t ) x = block(x, **kwargs) # head if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs) else: x = self.head(x, e) if self.sp_world_size > 1: x = self.all_gather(x, dim=1) if self.ref_conv is not None and full_ref is not None: full_ref_length = full_ref.size(1) x = x[:, full_ref_length:] grid_sizes = torch.stack([torch.tensor([u[0] - 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device) if subject_ref is not None: subject_ref_length = subject_ref.size(1) x = x[:, :-subject_ref_length] grid_sizes = torch.stack([torch.tensor([u[0] - subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device) # unpatchify x = self.unpatchify(x, grid_sizes) x = torch.stack(x) if self.teacache is not None and cond_flag: self.teacache.cnt += 1 if self.teacache.cnt == self.teacache.num_steps: self.teacache.reset() return x def unpatchify(self, x, grid_sizes): r""" Reconstruct video tensors from patch embeddings. Args: x (List[Tensor]): List of patchified features, each with shape [L, C_out * prod(patch_size)] grid_sizes (Tensor): Original spatial-temporal grid dimensions before patching, shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) Returns: List[Tensor]: Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] """ c = self.out_dim out = [] for u, v in zip(x, grid_sizes.tolist()): u = u[:math.prod(v)].view(*v, *self.patch_size, c) u = torch.einsum('fhwpqrc->cfphqwr', u) u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) out.append(u) return out def init_weights(self): r""" Initialize model parameters using Xavier initialization. """ # basic init for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) # init embeddings nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) for m in self.text_embedding.modules(): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=.02) for m in self.time_embedding.modules(): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=.02) # init output layer nn.init.zeros_(self.head.head.weight) @classmethod def from_pretrained( cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 ): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") config_file = os.path.join(pretrained_model_path, 'config.json') if not os.path.isfile(config_file): raise RuntimeError(f"{config_file} does not exist") with open(config_file, "r") as f: config = json.load(f) from diffusers.utils import WEIGHTS_NAME model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) model_file_safetensors = model_file.replace(".bin", ".safetensors") if "dict_mapping" in transformer_additional_kwargs.keys(): for key in transformer_additional_kwargs["dict_mapping"]: transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key] if low_cpu_mem_usage: try: import re from diffusers import __version__ as diffusers_version if diffusers_version >= "0.33.0": from diffusers.models.model_loading_utils import \ load_model_dict_into_meta else: from diffusers.models.modeling_utils import \ load_model_dict_into_meta from diffusers.utils import is_accelerate_available if is_accelerate_available(): import accelerate # Instantiate model with empty weights with accelerate.init_empty_weights(): model = cls.from_config(config, **transformer_additional_kwargs) param_device = "cpu" if os.path.exists(model_file): state_dict = torch.load(model_file, map_location="cpu") elif os.path.exists(model_file_safetensors): from safetensors.torch import load_file, safe_open state_dict = load_file(model_file_safetensors) else: from safetensors.torch import load_file, safe_open model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) state_dict = {} print(model_files_safetensors) for _model_file_safetensors in model_files_safetensors: _state_dict = load_file(_model_file_safetensors) for key in _state_dict: state_dict[key] = _state_dict[key] if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size(): model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :] model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0 state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight'] filtered_state_dict = {} for key in state_dict: if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size(): filtered_state_dict[key] = state_dict[key] else: print(f"Skipping key '{key}' due to size mismatch or absence in model.") model_keys = set(model.state_dict().keys()) loaded_keys = set(filtered_state_dict.keys()) missing_keys = model_keys - loaded_keys def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None): initialized_dict = {} with torch.no_grad(): for key in missing_keys: param_shape = model_state_dict[key].shape param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype if 'weight' in key: if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']): initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) elif 'embedding' in key or 'embed' in key: initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 elif 'head' in key or 'output' in key or 'proj_out' in key: initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) elif len(param_shape) >= 2: initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype) nn.init.xavier_uniform_(initialized_dict[key]) else: initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 elif 'bias' in key: initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) elif 'running_mean' in key: initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) elif 'running_var' in key: initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) elif 'num_batches_tracked' in key: initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long) else: initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) return initialized_dict if missing_keys: print(f"Missing keys will be initialized: {sorted(missing_keys)}") initialized_params = initialize_missing_parameters( missing_keys, model.state_dict(), torch_dtype ) filtered_state_dict.update(initialized_params) if diffusers_version >= "0.33.0": # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. load_model_dict_into_meta( model, filtered_state_dict, dtype=torch_dtype, model_name_or_path=pretrained_model_path, ) else: model._convert_deprecated_attention_blocks(filtered_state_dict) unexpected_keys = load_model_dict_into_meta( model, filtered_state_dict, device=param_device, dtype=torch_dtype, model_name_or_path=pretrained_model_path, ) if cls._keys_to_ignore_on_load_unexpected is not None: for pat in cls._keys_to_ignore_on_load_unexpected: unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] if len(unexpected_keys) > 0: print( f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" ) return model except Exception as e: print( f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." ) model = cls.from_config(config, **transformer_additional_kwargs) if os.path.exists(model_file): state_dict = torch.load(model_file, map_location="cpu") elif os.path.exists(model_file_safetensors): from safetensors.torch import load_file, safe_open state_dict = load_file(model_file_safetensors) else: from safetensors.torch import load_file, safe_open model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) state_dict = {} for _model_file_safetensors in model_files_safetensors: _state_dict = load_file(_model_file_safetensors) for key in _state_dict: state_dict[key] = _state_dict[key] if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size(): model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :] model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0 state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight'] tmp_state_dict = {} for key in state_dict: if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): tmp_state_dict[key] = state_dict[key] else: print(key, "Size don't match, skip") state_dict = tmp_state_dict m, u = model.load_state_dict(state_dict, strict=False) print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") print(m) params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] print(f"### All Parameters: {sum(params) / 1e6} M") params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] print(f"### attn1 Parameters: {sum(params) / 1e6} M") model = model.to(torch_dtype) return model class Wan2_2Transformer3DModel(WanTransformer3DModel): r""" Wan diffusion backbone supporting both text-to-video and image-to-video. """ # ignore_for_config = [ # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' # ] # _no_split_modules = ['WanAttentionBlock'] _supports_gradient_checkpointing = True def __init__( self, model_type='t2v', patch_size=(1, 2, 2), text_len=512, in_dim=16, dim=2048, ffn_dim=8192, freq_dim=256, text_dim=4096, out_dim=16, num_heads=16, num_layers=32, window_size=(-1, -1), qk_norm=True, cross_attn_norm=True, eps=1e-6, in_channels=16, hidden_size=2048, add_control_adapter=False, in_dim_control_adapter=24, downscale_factor_control_adapter=8, add_ref_conv=False, in_dim_ref_conv=16, ): r""" Initialize the diffusion model backbone. Args: model_type (`str`, *optional*, defaults to 't2v'): Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) text_len (`int`, *optional*, defaults to 512): Fixed length for text embeddings in_dim (`int`, *optional*, defaults to 16): Input video channels (C_in) dim (`int`, *optional*, defaults to 2048): Hidden dimension of the transformer ffn_dim (`int`, *optional*, defaults to 8192): Intermediate dimension in feed-forward network freq_dim (`int`, *optional*, defaults to 256): Dimension for sinusoidal time embeddings text_dim (`int`, *optional*, defaults to 4096): Input dimension for text embeddings out_dim (`int`, *optional*, defaults to 16): Output video channels (C_out) num_heads (`int`, *optional*, defaults to 16): Number of attention heads num_layers (`int`, *optional*, defaults to 32): Number of transformer blocks window_size (`tuple`, *optional*, defaults to (-1, -1)): Window size for local attention (-1 indicates global attention) qk_norm (`bool`, *optional*, defaults to True): Enable query/key normalization cross_attn_norm (`bool`, *optional*, defaults to False): Enable cross-attention normalization eps (`float`, *optional*, defaults to 1e-6): Epsilon value for normalization layers """ super().__init__( model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, in_channels=in_channels, hidden_size=hidden_size, add_control_adapter=add_control_adapter, in_dim_control_adapter=in_dim_control_adapter, downscale_factor_control_adapter=downscale_factor_control_adapter, add_ref_conv=add_ref_conv, in_dim_ref_conv=in_dim_ref_conv, cross_attn_type="cross_attn" ) if hasattr(self, "img_emb"): del self.img_emb