# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_qwenimage.py # Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools import inspect import glob import json import math import os import types import warnings from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.cuda.amp as amp import torch.nn as nn import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.loaders.single_file_model import FromOriginalModelMixin from diffusers.models.attention import Attention, FeedForward from diffusers.models.attention_processor import ( Attention, AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0) from diffusers.models.embeddings import (CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed) from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormContinuous, CogVideoXLayerNormZero, RMSNorm) from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers) from diffusers.utils.torch_utils import maybe_allow_in_graph from torch import nn from ..dist import (QwenImageMultiGPUsAttnProcessor2_0, get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) from .attention_utils import attention from .cache_utils import TeaCache from ..utils import cfg_skip logger = logging.get_logger(__name__) # pylint: disable=invalid-name def get_timestep_embedding( timesteps: torch.Tensor, embedding_dim: int, flip_sin_to_cos: bool = False, downscale_freq_shift: float = 1, scale: float = 1, max_period: int = 10000, ) -> torch.Tensor: """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. Args timesteps (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. embedding_dim (int): the dimension of the output. flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) downscale_freq_shift (float): Controls the delta between frequencies between dimensions scale (float): Scaling factor applied to the embeddings. max_period (int): Controls the maximum frequency of the embeddings Returns torch.Tensor: an [N x dim] Tensor of positional embeddings. """ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange( start=0, end=half_dim, dtype=torch.float32, device=timesteps.device ) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent).to(timesteps.dtype) emb = timesteps[:, None].float() * emb[None, :] # scale embeddings emb = scale * emb # concat sine and cosine embeddings emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # flip sine and cosine embeddings if flip_sin_to_cos: emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) # zero pad if embedding_dim % 2 == 1: emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def apply_rotary_emb_qwen( x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], use_real: bool = True, use_real_unbind_dim: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors. Args: x (`torch.Tensor`): Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) Returns: Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ if use_real: cos, sin = freqs_cis # [S, D] cos = cos[None, None] sin = sin[None, None] cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: # Used for flux, cogvideox, hunyuan-dit x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) elif use_real_unbind_dim == -2: # Used for Stable Audio, OmniGen, CogView4 and Cosmos x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out else: x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(1) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) return x_out.type_as(x) class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) def forward(self, timestep, hidden_states): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) conditioning = timesteps_emb return conditioning class QwenEmbedRope(nn.Module): def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): super().__init__() self.theta = theta self.axes_dim = axes_dim pos_index = torch.arange(4096) neg_index = torch.arange(4096).flip(0) * -1 - 1 self.pos_freqs = torch.cat( [ self.rope_params(pos_index, self.axes_dim[0], self.theta), self.rope_params(pos_index, self.axes_dim[1], self.theta), self.rope_params(pos_index, self.axes_dim[2], self.theta), ], dim=1, ) self.neg_freqs = torch.cat( [ self.rope_params(neg_index, self.axes_dim[0], self.theta), self.rope_params(neg_index, self.axes_dim[1], self.theta), self.rope_params(neg_index, self.axes_dim[2], self.theta), ], dim=1, ) self.rope_cache = {} # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART self.scale_rope = scale_rope def rope_params(self, index, dim, theta=10000): """ Args: index: [0, 1, 2, 3] 1D Tensor representing the position index of the token """ assert dim % 2 == 0 freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs def forward(self, video_fhw, txt_seq_lens, device): """ Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: txt_length: [bs] a list of 1 integers representing the length of the text """ if self.pos_freqs.device != device: self.pos_freqs = self.pos_freqs.to(device) self.neg_freqs = self.neg_freqs.to(device) if isinstance(video_fhw, list): video_fhw = video_fhw[0] if not isinstance(video_fhw, list): video_fhw = [video_fhw] vid_freqs = [] max_vid_index = 0 for idx, fhw in enumerate(video_fhw): frame, height, width = fhw rope_key = f"{idx}_{frame}_{height}_{width}" if not torch.compiler.is_compiling(): if rope_key not in self.rope_cache: self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) video_freq = self.rope_cache[rope_key] else: video_freq = self._compute_video_freqs(frame, height, width, idx) video_freq = video_freq.to(device) vid_freqs.append(video_freq) if self.scale_rope: max_vid_index = max(height // 2, width // 2, max_vid_index) else: max_vid_index = max(height, width, max_vid_index) max_len = max(txt_seq_lens) txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=None) def _compute_video_freqs(self, frame, height, width, idx=0): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) else: freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) return freqs.clone().contiguous() class QwenDoubleStreamAttnProcessor2_0: """ Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor implements joint attention computation where text and image streams are processed together. """ _attention_backend = None def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, # Image stream encoder_hidden_states: torch.FloatTensor = None, # Text stream encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: if encoder_hidden_states is None: raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") seq_txt = encoder_hidden_states.shape[1] # Compute QKV for image stream (sample projections) img_query = attn.to_q(hidden_states) img_key = attn.to_k(hidden_states) img_value = attn.to_v(hidden_states) # Compute QKV for text stream (context projections) txt_query = attn.add_q_proj(encoder_hidden_states) txt_key = attn.add_k_proj(encoder_hidden_states) txt_value = attn.add_v_proj(encoder_hidden_states) # Reshape for multi-head attention img_query = img_query.unflatten(-1, (attn.heads, -1)) img_key = img_key.unflatten(-1, (attn.heads, -1)) img_value = img_value.unflatten(-1, (attn.heads, -1)) txt_query = txt_query.unflatten(-1, (attn.heads, -1)) txt_key = txt_key.unflatten(-1, (attn.heads, -1)) txt_value = txt_value.unflatten(-1, (attn.heads, -1)) # Apply QK normalization if attn.norm_q is not None: img_query = attn.norm_q(img_query) if attn.norm_k is not None: img_key = attn.norm_k(img_key) if attn.norm_added_q is not None: txt_query = attn.norm_added_q(txt_query) if attn.norm_added_k is not None: txt_key = attn.norm_added_k(txt_key) # Apply RoPE if image_rotary_emb is not None: img_freqs, txt_freqs = image_rotary_emb img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False) img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False) txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False) txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False) # Concatenate for joint attention # Order: [text, image] joint_query = torch.cat([txt_query, img_query], dim=1) joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) joint_hidden_states = attention( joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, causal=False ) # Reshape back joint_hidden_states = joint_hidden_states.flatten(2, 3) joint_hidden_states = joint_hidden_states.to(joint_query.dtype) # Split attention outputs back txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part # Apply output projections img_attn_output = attn.to_out[0](img_attn_output) if len(attn.to_out) > 1: img_attn_output = attn.to_out[1](img_attn_output) # dropout txt_attn_output = attn.to_add_out(txt_attn_output) return img_attn_output, txt_attn_output @maybe_allow_in_graph class QwenImageTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 ): super().__init__() self.dim = dim self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim # Image processing modules self.img_mod = nn.Sequential( nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 ) self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.attn = Attention( query_dim=dim, cross_attention_dim=None, # Enable cross attention for joint computation added_kv_proj_dim=dim, # Enable added KV projections for text stream dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, context_pre_only=False, bias=True, processor=QwenDoubleStreamAttnProcessor2_0(), qk_norm=qk_norm, eps=eps, ) self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") # Text processing modules self.txt_mod = nn.Sequential( nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2 ) self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) # Text doesn't need separate attention - it's handled by img_attn joint computation self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") def _modulate(self, x, mod_params): """Apply modulation to input tensor""" shift, scale, gate = mod_params.chunk(3, dim=-1) return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Get modulation parameters for both streams img_mod_params = self.img_mod(temb) # [B, 6*dim] txt_mod_params = self.txt_mod(temb) # [B, 6*dim] # Split modulation parameters for norm1 and norm2 img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] # Process image stream - norm1 + modulation img_normed = self.img_norm1(hidden_states) img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) # Process text stream - norm1 + modulation txt_normed = self.txt_norm1(encoder_hidden_states) txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) # Use QwenAttnProcessor2_0 for joint attention computation # This directly implements the DoubleStreamLayerMegatron logic: # 1. Computes QKV for both streams # 2. Applies QK normalization and RoPE # 3. Concatenates and runs joint attention # 4. Splits results back to separate streams joint_attention_kwargs = joint_attention_kwargs or {} attn_output = self.attn( hidden_states=img_modulated, # Image stream (will be processed as "sample") encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") encoder_hidden_states_mask=encoder_hidden_states_mask, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs, ) # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided img_attn_output, txt_attn_output = attn_output # Apply attention gates and add residual (like in Megatron) hidden_states = hidden_states + img_gate1 * img_attn_output encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output # Process image stream - norm2 + MLP img_normed2 = self.img_norm2(hidden_states) img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) img_mlp_output = self.img_mlp(img_modulated2) hidden_states = hidden_states + img_gate2 * img_mlp_output # Process text stream - norm2 + MLP txt_normed2 = self.txt_norm2(encoder_hidden_states) txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) txt_mlp_output = self.txt_mlp(txt_modulated2) encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output # Clip to prevent overflow for fp16 if encoder_hidden_states.dtype == torch.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) return encoder_hidden_states, hidden_states class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): """ The Transformer model introduced in Qwen. Args: patch_size (`int`, defaults to `2`): Patch size to turn the input data into small patches. in_channels (`int`, defaults to `64`): The number of channels in the input. out_channels (`int`, *optional*, defaults to `None`): The number of channels in the output. If not specified, it defaults to `in_channels`. num_layers (`int`, defaults to `60`): The number of layers of dual stream DiT blocks to use. attention_head_dim (`int`, defaults to `128`): The number of dimensions to use for each attention head. num_attention_heads (`int`, defaults to `24`): The number of attention heads to use. joint_attention_dim (`int`, defaults to `3584`): The number of dimensions to use for the joint attention (embedding/channel dimension of `encoder_hidden_states`). guidance_embeds (`bool`, defaults to `False`): Whether to use guidance embeddings for guidance-distilled variant of the model. axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): The dimensions to use for the rotary positional embeddings. """ # _supports_gradient_checkpointing = True # _no_split_modules = ["QwenImageTransformerBlock"] # _skip_layerwise_casting_patterns = ["pos_embed", "norm"] # _repeated_blocks = ["QwenImageTransformerBlock"] _supports_gradient_checkpointing = True @register_to_config def __init__( self, patch_size: int = 2, in_channels: int = 64, out_channels: Optional[int] = 16, num_layers: int = 60, attention_head_dim: int = 128, num_attention_heads: int = 24, joint_attention_dim: int = 3584, guidance_embeds: bool = False, # TODO: this should probably be removed axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), ): super().__init__() self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim) self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) self.img_in = nn.Linear(in_channels, self.inner_dim) self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) self.transformer_blocks = nn.ModuleList( [ QwenImageTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, ) for _ in range(num_layers) ] ) self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.teacache = None self.cfg_skip_ratio = None self.current_steps = 0 self.num_inference_steps = None self.gradient_checkpointing = False self.sp_world_size = 1 self.sp_world_rank = 0 def _set_gradient_checkpointing(self, *args, **kwargs): if "value" in kwargs: self.gradient_checkpointing = kwargs["value"] elif "enable" in kwargs: self.gradient_checkpointing = kwargs["enable"] else: raise ValueError("Invalid set gradient checkpointing") def enable_multi_gpus_inference(self,): self.sp_world_size = get_sequence_parallel_world_size() self.sp_world_rank = get_sequence_parallel_rank() self.all_gather = get_sp_group().all_gather self.set_attn_processor(QwenImageMultiGPUsAttnProcessor2_0()) @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def enable_cfg_skip(self, cfg_skip_ratio, num_steps): if cfg_skip_ratio != 0: self.cfg_skip_ratio = cfg_skip_ratio self.current_steps = 0 self.num_inference_steps = num_steps else: self.cfg_skip_ratio = None self.current_steps = 0 self.num_inference_steps = None def share_cfg_skip( self, transformer = None, ): self.cfg_skip_ratio = transformer.cfg_skip_ratio self.current_steps = transformer.current_steps self.num_inference_steps = transformer.num_inference_steps def disable_cfg_skip(self): self.cfg_skip_ratio = None self.current_steps = 0 self.num_inference_steps = None def enable_teacache( self, coefficients, num_steps: int, rel_l1_thresh: float, num_skip_start_steps: int = 0, offload: bool = True, ): self.teacache = TeaCache( coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload ) def share_teacache( self, transformer = None, ): self.teacache = transformer.teacache def disable_teacache(self): self.teacache = None @cfg_skip() def forward_bs(self, x, *args, **kwargs): func = self.forward sig = inspect.signature(func) bs = len(x) bs_half = int(bs // 2) if bs >= 2: # cond x_i = x[bs_half:] args_i = [ arg[bs_half:] if isinstance(arg, (torch.Tensor, list, tuple, np.ndarray)) and len(arg) == bs else arg for arg in args ] kwargs_i = { k: (v[bs_half:] if isinstance(v, (torch.Tensor, list, tuple, np.ndarray)) and len(v) == bs else v ) for k, v in kwargs.items() } if 'cond_flag' in sig.parameters: kwargs_i["cond_flag"] = True cond_out = func(x_i, *args_i, **kwargs_i) # uncond uncond_x_i = x[:bs_half] uncond_args_i = [ arg[:bs_half] if isinstance(arg, (torch.Tensor, list, tuple, np.ndarray)) and len(arg) == bs else arg for arg in args ] uncond_kwargs_i = { k: (v[:bs_half] if isinstance(v, (torch.Tensor, list, tuple, np.ndarray)) and len(v) == bs else v ) for k, v in kwargs.items() } if 'cond_flag' in sig.parameters: uncond_kwargs_i["cond_flag"] = False uncond_out = func(uncond_x_i, *uncond_args_i, **uncond_kwargs_i) x = torch.cat([uncond_out, cond_out], dim=0) else: x = func(x, *args, **kwargs) return x def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, encoder_hidden_states_mask: torch.Tensor = None, timestep: torch.LongTensor = None, img_shapes: Optional[List[Tuple[int, int, int]]] = None, txt_seq_lens: Optional[List[int]] = None, guidance: torch.Tensor = None, # TODO: this should probably be removed attention_kwargs: Optional[Dict[str, Any]] = None, cond_flag: bool = True, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: """ The [`QwenTransformer2DModel`] forward method. Args: hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): Mask of the input conditions. timestep ( `torch.LongTensor`): Used to indicate denoising step. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) if isinstance(encoder_hidden_states, list): encoder_hidden_states = torch.stack(encoder_hidden_states) encoder_hidden_states_mask = torch.stack(encoder_hidden_states_mask) hidden_states = self.img_in(hidden_states) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 temb = ( self.time_text_embed(timestep, hidden_states) if guidance is None else self.time_text_embed(timestep, guidance, hidden_states) ) image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) # Context Parallel if self.sp_world_size > 1: hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank] if image_rotary_emb is not None: image_rotary_emb = ( torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank], image_rotary_emb[1] ) # TeaCache if self.teacache is not None: if cond_flag: inp = hidden_states.clone() temb_ = temb.clone() encoder_hidden_states_ = encoder_hidden_states.clone() img_mod_params_ = self.transformer_blocks[0].img_mod(temb_) img_mod1_, img_mod2_ = img_mod_params_.chunk(2, dim=-1) img_normed_ = self.transformer_blocks[0].img_norm1(inp) modulated_inp, img_gate1_ = self.transformer_blocks[0]._modulate(img_normed_, img_mod1_) skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps if skip_flag: self.should_calc = True self.teacache.accumulated_rel_l1_distance = 0 else: if cond_flag: rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp) self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance) if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh: self.should_calc = False else: self.should_calc = True self.teacache.accumulated_rel_l1_distance = 0 self.teacache.previous_modulated_input = modulated_inp self.teacache.should_calc = self.should_calc else: self.should_calc = self.teacache.should_calc # TeaCache if self.teacache is not None: if not self.should_calc: previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond hidden_states = hidden_states + previous_residual.to(hidden_states.device)[-hidden_states.size()[0]:,] else: ori_hidden_states = hidden_states.clone().cpu() if self.teacache.offload else hidden_states.clone() # 4. Transformer blocks for i, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, **ckpt_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=attention_kwargs, ) if cond_flag: self.teacache.previous_residual_cond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states else: self.teacache.previous_residual_uncond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states del ori_hidden_states else: for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, **ckpt_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=attention_kwargs, ) # Use only the image part (hidden_states) from the dual-stream blocks hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) if self.sp_world_size > 1: output = self.all_gather(output, dim=1) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if self.teacache is not None and cond_flag: self.teacache.cnt += 1 if self.teacache.cnt == self.teacache.num_steps: self.teacache.reset() return output @classmethod def from_pretrained( cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 ): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") config_file = os.path.join(pretrained_model_path, 'config.json') if not os.path.isfile(config_file): raise RuntimeError(f"{config_file} does not exist") with open(config_file, "r") as f: config = json.load(f) from diffusers.utils import WEIGHTS_NAME model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) model_file_safetensors = model_file.replace(".bin", ".safetensors") if "dict_mapping" in transformer_additional_kwargs.keys(): for key in transformer_additional_kwargs["dict_mapping"]: transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key] if low_cpu_mem_usage: try: import re from diffusers import __version__ as diffusers_version if diffusers_version >= "0.33.0": from diffusers.models.model_loading_utils import \ load_model_dict_into_meta else: from diffusers.models.modeling_utils import \ load_model_dict_into_meta from diffusers.utils import is_accelerate_available if is_accelerate_available(): import accelerate # Instantiate model with empty weights with accelerate.init_empty_weights(): model = cls.from_config(config, **transformer_additional_kwargs) param_device = "cpu" if os.path.exists(model_file): state_dict = torch.load(model_file, map_location="cpu") elif os.path.exists(model_file_safetensors): from safetensors.torch import load_file, safe_open state_dict = load_file(model_file_safetensors) else: from safetensors.torch import load_file, safe_open model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) state_dict = {} print(model_files_safetensors) for _model_file_safetensors in model_files_safetensors: _state_dict = load_file(_model_file_safetensors) for key in _state_dict: state_dict[key] = _state_dict[key] filtered_state_dict = {} for key in state_dict: if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size(): filtered_state_dict[key] = state_dict[key] else: print(f"Skipping key '{key}' due to size mismatch or absence in model.") model_keys = set(model.state_dict().keys()) loaded_keys = set(filtered_state_dict.keys()) missing_keys = model_keys - loaded_keys def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None): initialized_dict = {} with torch.no_grad(): for key in missing_keys: param_shape = model_state_dict[key].shape param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype if 'weight' in key: if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']): initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) elif 'embedding' in key or 'embed' in key: initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 elif 'head' in key or 'output' in key or 'proj_out' in key: initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) elif len(param_shape) >= 2: initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype) nn.init.xavier_uniform_(initialized_dict[key]) else: initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 elif 'bias' in key: initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) elif 'running_mean' in key: initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) elif 'running_var' in key: initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) elif 'num_batches_tracked' in key: initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long) else: initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) return initialized_dict if missing_keys: print(f"Missing keys will be initialized: {sorted(missing_keys)}") initialized_params = initialize_missing_parameters( missing_keys, model.state_dict(), torch_dtype ) filtered_state_dict.update(initialized_params) if diffusers_version >= "0.33.0": # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. load_model_dict_into_meta( model, filtered_state_dict, dtype=torch_dtype, model_name_or_path=pretrained_model_path, ) else: model._convert_deprecated_attention_blocks(filtered_state_dict) unexpected_keys = load_model_dict_into_meta( model, filtered_state_dict, device=param_device, dtype=torch_dtype, model_name_or_path=pretrained_model_path, ) if cls._keys_to_ignore_on_load_unexpected is not None: for pat in cls._keys_to_ignore_on_load_unexpected: unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] if len(unexpected_keys) > 0: print( f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" ) return model except Exception as e: print( f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." ) model = cls.from_config(config, **transformer_additional_kwargs) if os.path.exists(model_file): state_dict = torch.load(model_file, map_location="cpu") elif os.path.exists(model_file_safetensors): from safetensors.torch import load_file, safe_open state_dict = load_file(model_file_safetensors) else: from safetensors.torch import load_file, safe_open model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) state_dict = {} for _model_file_safetensors in model_files_safetensors: _state_dict = load_file(_model_file_safetensors) for key in _state_dict: state_dict[key] = _state_dict[key] tmp_state_dict = {} for key in state_dict: if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): tmp_state_dict[key] = state_dict[key] else: print(key, "Size don't match, skip") state_dict = tmp_state_dict m, u = model.load_state_dict(state_dict, strict=False) print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") print(m) params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] print(f"### All Parameters: {sum(params) / 1e6} M") params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] print(f"### attn1 Parameters: {sum(params) / 1e6} M") model = model.to(torch_dtype) return model