# Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention_processor import Attention, AttentionProcessor
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import RMSNorm
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
                             scale_lora_layers, unscale_lora_layers)
from diffusers.utils.torch_utils import maybe_allow_in_graph

from .z_image_transformer2d import (FinalLayer, ZImageTransformer2DModel,
                                    ZImageTransformerBlock)

ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32


class ZImageControlTransformerBlock(ZImageTransformerBlock):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        block_id=0,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id
        # Zero-initialized projections so the control branch starts as a no-op;
        # only the first control block folds the main-path tokens into the stream.
        if block_id == 0:
            self.before_proj = nn.Linear(self.dim, self.dim)
            nn.init.zeros_(self.before_proj.weight)
            nn.init.zeros_(self.before_proj.bias)
        self.after_proj = nn.Linear(self.dim, self.dim)
        nn.init.zeros_(self.after_proj.weight)
        nn.init.zeros_(self.after_proj.bias)

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            c = self.before_proj(c) + x
            all_c = []
        else:
            # Later blocks receive a stack [skip_0, ..., skip_{k-1}, running_c].
            all_c = list(torch.unbind(c))
            c = all_c.pop(-1)
        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)
        # Append this block's skip tensor and the updated running state.
        all_c += [c_skip, c]
        c = torch.stack(all_c)
        return c


class BaseZImageTransformerBlock(ZImageTransformerBlock):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        block_id=0,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id

    def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs):
        hidden_states = super().forward(hidden_states, **kwargs)
        # Blocks paired with a control block add their matching hint, scaled.
        if self.block_id is not None:
            hidden_states = hidden_states + hints[self.block_id] * context_scale
        return hidden_states


class ZImageControlTransformer2DModel(ZImageTransformer2DModel):
    @register_to_config
    def __init__(
        self,
        control_layers_places=None,
        control_in_dim=None,
        all_patch_size=(2,),
        all_f_patch_size=(1,),
        in_channels=16,
        dim=3840,
        n_layers=30,
        n_refiner_layers=2,
        n_heads=30,
        n_kv_heads=30,
        norm_eps=1e-5,
        qk_norm=True,
        cap_feat_dim=2560,
        rope_theta=256.0,
        t_scale=1000.0,
        axes_dims=[32, 48, 48],
        axes_lens=[1024, 512, 512],
    ):
        super().__init__(
            all_patch_size=all_patch_size,
            all_f_patch_size=all_f_patch_size,
            in_channels=in_channels,
            dim=dim,
            n_layers=n_layers,
            n_refiner_layers=n_refiner_layers,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            norm_eps=norm_eps,
            qk_norm=qk_norm,
            cap_feat_dim=cap_feat_dim,
            rope_theta=rope_theta,
            t_scale=t_scale,
            axes_dims=axes_dims,
            axes_lens=axes_lens,
        )
        self.control_layers_places = (
            [i for i in range(0, self.num_layers, 2)]
            if control_layers_places is None
            else control_layers_places
        )
        self.control_in_dim = self.in_dim if control_in_dim is None else control_in_dim
        assert 0 in self.control_layers_places
        self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)}

        # blocks
        del self.layers
        self.layers = nn.ModuleList(
            [
                BaseZImageTransformerBlock(
                    i,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    block_id=self.control_layers_mapping[i] if i in self.control_layers_places else None,
                )
                for i in range(n_layers)
            ]
        )

        # control blocks
        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
                for i in self.control_layers_places
            ]
        )

        # control patch embeddings
        all_x_embedder = {}
        for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
            x_embedder = nn.Linear(
                f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True
            )
            all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)

        self.control_noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    1000 + layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

    def forward_control(
        self,
        x,
        cap_feats,
        control_context,
        kwargs,
        t=None,
        patch_size=2,
        f_patch_size=1,
    ):
        # embeddings
        bsz = len(control_context)
        device = control_context[0].device
        (
            control_context,
            x_size,
            x_pos_ids,
            x_inner_pad_mask,
        ) = self.patchify(control_context, patch_size, f_patch_size, cap_feats[0].size(0))

        # control_context embed & refine
        x_item_seqlens = [len(_) for _ in control_context]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)

        control_context = torch.cat(control_context, dim=0)
        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)

        # Match t_embedder output dtype to control_context for layerwise casting compatibility
        adaln_input = t.type_as(control_context)

        control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        control_context = list(control_context.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

        control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        # Context Parallel
        if self.sp_world_size > 1:
            control_context = torch.chunk(control_context, self.sp_world_size, dim=1)[self.sp_world_rank]

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.control_noise_refiner:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                control_context = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    control_context,
                    x_attn_mask,
                    x_freqs_cis,
                    adaln_input,
                    **ckpt_kwargs,
                )
        else:
            for layer in self.control_noise_refiner:
                control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input)

        # unified
        cap_item_seqlens = [len(_) for _ in cap_feats]
        control_context_unified = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]]))
        control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
        c = control_context_unified

        # Context Parallel
        if self.sp_world_size > 1:
            c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank]

        # arguments
        new_kwargs = dict(x=x)
        new_kwargs.update(kwargs)

        for layer in self.control_layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, **static_kwargs):
                    def custom_forward(*inputs):
                        return module(*inputs, **static_kwargs)

                    return custom_forward

                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                c = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer, **new_kwargs),
                    c,
                    **ckpt_kwargs,
                )
            else:
                c = layer(c, **new_kwargs)

        # Drop the running control state; keep one skip tensor per control block.
        hints = torch.unbind(c)[:-1]
        return hints

    def forward(
        self,
        x: List[torch.Tensor],
        t,
        cap_feats: List[torch.Tensor],
        patch_size=2,
        f_patch_size=1,
        control_context=None,
        control_context_scale=1.0,
    ):
        assert patch_size in self.all_patch_size
        assert f_patch_size in self.all_f_patch_size

        bsz = len(x)
        device = x[0].device

        t = t * self.t_scale
        t = self.t_embedder(t)

        (
            x,
            cap_feats,
            x_size,
            x_pos_ids,
            cap_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

        # x embed & refine
        x_item_seqlens = [len(_) for _ in x]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)

        x = torch.cat(x, dim=0)
        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)

        # Match t_embedder output dtype to x for layerwise casting compatibility
        adaln_input = t.type_as(x)

        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        x = list(x.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

        x = pad_sequence(x, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        # Context Parallel
        if self.sp_world_size > 1:
            x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.noise_refiner:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    x,
                    x_attn_mask,
                    x_freqs_cis,
                    adaln_input,
                    **ckpt_kwargs,
                )
        else:
            for layer in self.noise_refiner:
                x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)

        # cap embed & refine
        cap_item_seqlens = [len(_) for _ in cap_feats]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
        cap_max_item_seqlen = max(cap_item_seqlens)

        cap_feats = torch.cat(cap_feats, dim=0)
        cap_feats = self.cap_embedder(cap_feats)
        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(cap_item_seqlens):
            cap_attn_mask[i, :seq_len] = 1

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.context_refiner:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                cap_feats = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    cap_feats,
                    cap_attn_mask,
                    cap_freqs_cis,
                    **ckpt_kwargs,
                )
        else:
            for layer in self.context_refiner:
                cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)

        # unified
        unified = []
        unified_freqs_cis = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
        assert unified_item_seqlens == [len(_) for _ in unified]
        unified_max_item_seqlen = max(unified_item_seqlens)

        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(unified_item_seqlens):
            unified_attn_mask[i, :seq_len] = 1

        # Arguments
        kwargs = dict(
            attn_mask=unified_attn_mask,
            freqs_cis=unified_freqs_cis,
            adaln_input=adaln_input,
        )

        hints = self.forward_control(
            unified,
            cap_feats,
            control_context,
            kwargs,
            t=t,
            patch_size=patch_size,
            f_patch_size=f_patch_size,
        )

        for layer in self.layers:
            # Arguments
            kwargs = dict(
                attn_mask=unified_attn_mask,
                freqs_cis=unified_freqs_cis,
                adaln_input=adaln_input,
                hints=hints,
                context_scale=control_context_scale,
            )
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, **static_kwargs):
                    def custom_forward(*inputs):
                        return module(*inputs, **static_kwargs)

                    return custom_forward

                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                unified = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer, **kwargs),
                    unified,
                    **ckpt_kwargs,
                )
            else:
                unified = layer(unified, **kwargs)

        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
        unified = list(unified.unbind(dim=0))
        x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
        if self.sp_world_size > 1:
            x = self.all_gather(x, dim=1)
        x = torch.stack(x)
        return x, {}
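

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the model): the VACE-style
# hint-injection pattern used above, reduced to plain tensors. Each control
# block emits a zero-initialized "skip" projection of its output while the
# running control state rides last in the stack; `forward_control` drops that
# running state so `hints[block_id]` lines up with the block_id assigned to
# each BaseZImageTransformerBlock. The toy dimensions and the tanh stand-ins
# for the real transformer blocks are assumptions made only for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    bsz, seq_len, dim_toy = 2, 8, 16
    num_control_blocks = 3

    x_toy = torch.randn(bsz, seq_len, dim_toy)  # main (unified) sequence
    c_toy = torch.randn(bsz, seq_len, dim_toy)  # control sequence

    before_proj = nn.Linear(dim_toy, dim_toy)
    after_projs = nn.ModuleList([nn.Linear(dim_toy, dim_toy) for _ in range(num_control_blocks)])

    # Control branch: block 0 folds the main tokens into the control stream,
    # then every block appends one skip tensor and updates the running state.
    all_c = []
    c_toy = before_proj(c_toy) + x_toy
    for proj in after_projs:
        c_toy = c_toy + 0.1 * torch.tanh(c_toy)  # stand-in for the transformer block
        all_c.append(proj(c_toy))                # skip handed to the main path
    stacked = torch.stack(all_c + [c_toy])       # (num_control_blocks + 1, bsz, seq, dim)

    # Main branch: drop the running state, then add hints[block_id] * scale
    # after each paired block, mirroring BaseZImageTransformerBlock.forward.
    hints = torch.unbind(stacked)[:-1]
    context_scale = 1.0
    for block_id in range(num_control_blocks):
        x_toy = x_toy + 0.1 * torch.tanh(x_toy)  # stand-in for the main block
        x_toy = x_toy + hints[block_id] * context_scale

    print(x_toy.shape)  # torch.Size([2, 8, 16])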