# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_qwenimage.py
# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import glob
import json
import math
import os
import types
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import (
Attention, AttentionProcessor, CogVideoXAttnProcessor2_0,
FusedCogVideoXAttnProcessor2_0)
from diffusers.models.embeddings import (CogVideoXPatchEmbed,
TimestepEmbedding, Timesteps,
get_3d_sincos_pos_embed)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import (AdaLayerNorm,
AdaLayerNormContinuous,
CogVideoXLayerNormZero, RMSNorm)
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
scale_lora_layers, unscale_lora_layers)
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn
from ..dist import (QwenImageMultiGPUsAttnProcessor2_0,
get_sequence_parallel_rank,
get_sequence_parallel_world_size, get_sp_group)
from .attention_utils import attention
from .cache_utils import TeaCache
from ..utils import cfg_skip
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
) -> torch.Tensor:
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
    Args:
timesteps (torch.Tensor):
a 1-D Tensor of N indices, one per batch element. These may be fractional.
embedding_dim (int):
the dimension of the output.
flip_sin_to_cos (bool):
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
downscale_freq_shift (float):
Controls the delta between frequencies between dimensions
scale (float):
Scaling factor applied to the embeddings.
max_period (int):
Controls the maximum frequency of the embeddings
    Returns:
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent).to(timesteps.dtype)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
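
# A minimal usage sketch (editorial addition, not part of the original module; values are illustrative):
# for N timesteps the helper returns an [N, embedding_dim] tensor of concatenated sin/cos features
# (cos first when `flip_sin_to_cos=True`).
#
#   timesteps = torch.tensor([0.0, 250.5, 999.0])
#   emb = get_timestep_embedding(timesteps, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
#   assert emb.shape == (3, 256)
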
def apply_rotary_emb_qwen(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, of shape `[B, S, H, D]`.
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
        torch.Tensor: The query or key tensor with rotary embeddings applied.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(1)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
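
# Hedged example (editorial, shapes assumed): with `use_real=False`, as used by the attention
# processor below, `freqs_cis` is a complex tensor of shape [S, D // 2] built via `torch.polar`,
# and the last (head) dimension of `x` must be even.
#
#   B, S, H, D = 1, 16, 4, 128
#   x = torch.randn(B, S, H, D)
#   angles = torch.randn(S, D // 2)
#   freqs_cis = torch.polar(torch.ones_like(angles), angles)   # complex64, [S, D // 2]
#   out = apply_rotary_emb_qwen(x, freqs_cis, use_real=False)  # -> [B, S, H, D], same dtype as x
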
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, hidden_states):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
conditioning = timesteps_emb
return conditioning
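
# Sketch (editorial): the projector maps scalar timesteps to the transformer's inner dimension;
# `hidden_states` is only consulted for its dtype. With the default config below the inner
# dimension is 24 * 128 = 3072 (an assumption taken from those defaults).
#
#   embedder = QwenTimestepProjEmbeddings(embedding_dim=3072)
#   temb = embedder(torch.tensor([500.0]), torch.zeros(1, 10, 3072))  # -> [1, 3072]
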
class QwenEmbedRope(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.rope_cache = {}
        # Do not use register_buffer here: it would cause the complex frequency tensors to lose their imaginary part
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
            index: 1-D tensor of token position indices (e.g., `[0, 1, 2, 3]`)
"""
assert dim % 2 == 0
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(self, video_fhw, txt_seq_lens, device):
"""
        Args:
            video_fhw: A `(frame, height, width)` tuple, or a list of such tuples, giving the latent video shape(s).
            txt_seq_lens: A list of integers, one per batch element, giving the text sequence lengths.
            device: The device on which to place the returned frequency tensors.
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{frame}_{height}_{width}"
if not torch.compiler.is_compiling():
if rope_key not in self.rope_cache:
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
video_freq = self.rope_cache[rope_key]
else:
video_freq = self._compute_video_freqs(frame, height, width, idx)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
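
# Hedged usage sketch (editorial; shapes follow from the axes_dim=(16, 56, 56) defaults used by the
# model below): the module returns complex rotary frequencies for the flattened image tokens and for
# the text tokens.
#
#   rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
#   vid_freqs, txt_freqs = rope(video_fhw=[(1, 32, 32)], txt_seq_lens=[77], device="cpu")
#   # vid_freqs: [1 * 32 * 32, 64] complex; txt_freqs: [77, 64] complex
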
class QwenDoubleStreamAttnProcessor2_0:
"""
Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
implements joint attention computation where text and image streams are processed together.
"""
_attention_backend = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
if encoder_hidden_states is None:
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
seq_txt = encoder_hidden_states.shape[1]
# Compute QKV for image stream (sample projections)
img_query = attn.to_q(hidden_states)
img_key = attn.to_k(hidden_states)
img_value = attn.to_v(hidden_states)
# Compute QKV for text stream (context projections)
txt_query = attn.add_q_proj(encoder_hidden_states)
txt_key = attn.add_k_proj(encoder_hidden_states)
txt_value = attn.add_v_proj(encoder_hidden_states)
# Reshape for multi-head attention
img_query = img_query.unflatten(-1, (attn.heads, -1))
img_key = img_key.unflatten(-1, (attn.heads, -1))
img_value = img_value.unflatten(-1, (attn.heads, -1))
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
# Apply QK normalization
if attn.norm_q is not None:
img_query = attn.norm_q(img_query)
if attn.norm_k is not None:
img_key = attn.norm_k(img_key)
if attn.norm_added_q is not None:
txt_query = attn.norm_added_q(txt_query)
if attn.norm_added_k is not None:
txt_key = attn.norm_added_k(txt_key)
# Apply RoPE
if image_rotary_emb is not None:
img_freqs, txt_freqs = image_rotary_emb
img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
# Concatenate for joint attention
# Order: [text, image]
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_hidden_states = attention(
joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, causal=False
)
# Reshape back
joint_hidden_states = joint_hidden_states.flatten(2, 3)
joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
# Split attention outputs back
txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
# Apply output projections
img_attn_output = attn.to_out[0](img_attn_output)
if len(attn.to_out) > 1:
img_attn_output = attn.to_out[1](img_attn_output) # dropout
txt_attn_output = attn.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
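
# Editorial note: the joint sequence is ordered [text, image], i.e. the processor computes
#   joint_query = torch.cat([txt_query, img_query], dim=1)   # [B, S_txt + S_img, H, D]
# and splits the attention output back at `seq_txt`, so `image_rotary_emb` must provide the
# (image, text) frequency pair in the same convention as `QwenEmbedRope.forward` above.
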
@maybe_allow_in_graph
class QwenImageTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
# Image processing modules
self.img_mod = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
)
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.attn = Attention(
query_dim=dim,
            cross_attention_dim=None,  # No separate cross-attention dim; the text stream uses the added KV projections below
added_kv_proj_dim=dim, # Enable added KV projections for text stream
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=QwenDoubleStreamAttnProcessor2_0(),
qk_norm=qk_norm,
eps=eps,
)
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
# Text processing modules
self.txt_mod = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
)
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        # Text does not need a separate attention module; it is handled jointly by `self.attn`
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def _modulate(self, x, mod_params):
"""Apply modulation to input tensor"""
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Get modulation parameters for both streams
img_mod_params = self.img_mod(temb) # [B, 6*dim]
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
# Split modulation parameters for norm1 and norm2
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Process image stream - norm1 + modulation
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
# Process text stream - norm1 + modulation
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
        # Use QwenDoubleStreamAttnProcessor2_0 for joint attention computation
# This directly implements the DoubleStreamLayerMegatron logic:
# 1. Computes QKV for both streams
# 2. Applies QK normalization and RoPE
# 3. Concatenates and runs joint attention
# 4. Splits results back to separate streams
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=img_modulated, # Image stream (will be processed as "sample")
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
        # QwenDoubleStreamAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
img_attn_output, txt_attn_output = attn_output
# Apply attention gates and add residual (like in Megatron)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# Process image stream - norm2 + MLP
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
# Process text stream - norm2 + MLP
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
txt_mlp_output = self.txt_mlp(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
# Clip to prevent overflow for fp16
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
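
# Hedged shape sketch (editorial; dimensions assumed from the default config below): a dual-stream
# block keeps both sequences at the model width and returns them in (text, image) order.
#
#   block = QwenImageTransformerBlock(dim=3072, num_attention_heads=24, attention_head_dim=128)
#   # hidden_states:         [B, S_img, 3072]   (image stream)
#   # encoder_hidden_states: [B, S_txt, 3072]   (text stream)
#   # temb:                  [B, 3072]
#   # returns (encoder_hidden_states, hidden_states) with unchanged shapes
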
class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
    The Transformer model introduced in Qwen-Image.
Args:
patch_size (`int`, defaults to `2`):
Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `64`):
The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output. If `None`, it defaults to `in_channels`.
num_layers (`int`, defaults to `60`):
The number of layers of dual stream DiT blocks to use.
attention_head_dim (`int`, defaults to `128`):
The number of dimensions to use for each attention head.
num_attention_heads (`int`, defaults to `24`):
The number of attention heads to use.
joint_attention_dim (`int`, defaults to `3584`):
The number of dimensions to use for the joint attention (embedding/channel dimension of
`encoder_hidden_states`).
guidance_embeds (`bool`, defaults to `False`):
Whether to use guidance embeddings for guidance-distilled variant of the model.
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions to use for the rotary positional embeddings.
"""
# _supports_gradient_checkpointing = True
# _no_split_modules = ["QwenImageTransformerBlock"]
# _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
# _repeated_blocks = ["QwenImageTransformerBlock"]
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 2,
in_channels: int = 64,
out_channels: Optional[int] = 16,
num_layers: int = 60,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
guidance_embeds: bool = False, # TODO: this should probably be removed
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
self.img_in = nn.Linear(in_channels, self.inner_dim)
self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
QwenImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.teacache = None
self.cfg_skip_ratio = None
self.current_steps = 0
self.num_inference_steps = None
self.gradient_checkpointing = False
self.sp_world_size = 1
self.sp_world_rank = 0
def _set_gradient_checkpointing(self, *args, **kwargs):
if "value" in kwargs:
self.gradient_checkpointing = kwargs["value"]
elif "enable" in kwargs:
self.gradient_checkpointing = kwargs["enable"]
else:
raise ValueError("Invalid set gradient checkpointing")
def enable_multi_gpus_inference(self,):
self.sp_world_size = get_sequence_parallel_world_size()
self.sp_world_rank = get_sequence_parallel_rank()
self.all_gather = get_sp_group().all_gather
self.set_attn_processor(QwenImageMultiGPUsAttnProcessor2_0())
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def enable_cfg_skip(self, cfg_skip_ratio, num_steps):
if cfg_skip_ratio != 0:
self.cfg_skip_ratio = cfg_skip_ratio
self.current_steps = 0
self.num_inference_steps = num_steps
else:
self.cfg_skip_ratio = None
self.current_steps = 0
self.num_inference_steps = None
def share_cfg_skip(
self,
transformer = None,
):
self.cfg_skip_ratio = transformer.cfg_skip_ratio
self.current_steps = transformer.current_steps
self.num_inference_steps = transformer.num_inference_steps
def disable_cfg_skip(self):
self.cfg_skip_ratio = None
self.current_steps = 0
self.num_inference_steps = None
def enable_teacache(
self,
coefficients,
num_steps: int,
rel_l1_thresh: float,
num_skip_start_steps: int = 0,
offload: bool = True,
):
self.teacache = TeaCache(
coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
)
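
    # Usage sketch (editorial; the coefficient list and step counts below are placeholders, not
    # validated values for this model):
    #
    #   transformer.enable_teacache(coefficients=[...], num_steps=30, rel_l1_thresh=0.2)
    #   transformer.enable_cfg_skip(cfg_skip_ratio=0.25, num_steps=30)
    #   ...run sampling...
    #   transformer.disable_teacache()
    #   transformer.disable_cfg_skip()
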
def share_teacache(
self,
transformer = None,
):
self.teacache = transformer.teacache
def disable_teacache(self):
self.teacache = None
@cfg_skip()
def forward_bs(self, x, *args, **kwargs):
func = self.forward
sig = inspect.signature(func)
bs = len(x)
bs_half = int(bs // 2)
if bs >= 2:
# cond
x_i = x[bs_half:]
args_i = [
arg[bs_half:] if
isinstance(arg,
(torch.Tensor, list, tuple, np.ndarray)) and
len(arg) == bs else arg for arg in args
]
kwargs_i = {
k: (v[bs_half:] if
isinstance(v,
(torch.Tensor, list, tuple,
np.ndarray)) and len(v) == bs else v
) for k, v in kwargs.items()
}
if 'cond_flag' in sig.parameters:
kwargs_i["cond_flag"] = True
cond_out = func(x_i, *args_i, **kwargs_i)
# uncond
uncond_x_i = x[:bs_half]
uncond_args_i = [
arg[:bs_half] if
isinstance(arg,
(torch.Tensor, list, tuple, np.ndarray)) and
len(arg) == bs else arg for arg in args
]
uncond_kwargs_i = {
k: (v[:bs_half] if
isinstance(v,
(torch.Tensor, list, tuple,
np.ndarray)) and len(v) == bs else v
) for k, v in kwargs.items()
}
if 'cond_flag' in sig.parameters:
uncond_kwargs_i["cond_flag"] = False
uncond_out = func(uncond_x_i, *uncond_args_i,
**uncond_kwargs_i)
x = torch.cat([uncond_out, cond_out], dim=0)
else:
x = func(x, *args, **kwargs)
return x
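
    # Editorial sketch of the batching convention used above: with classifier-free guidance the
    # batch is packed as [uncond, cond] along dim 0; the second half runs with cond_flag=True, the
    # first half with cond_flag=False, and the halves are re-concatenated in the same order.
    #
    #   latents = torch.cat([uncond_latents, cond_latents], dim=0)   # [2 * B, ...]
    #   noise_pred = transformer.forward_bs(latents, timestep=timesteps, ...)  # remaining kwargs mirror `forward`
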
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
encoder_hidden_states_mask: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
txt_seq_lens: Optional[List[int]] = None,
guidance: torch.Tensor = None, # TODO: this should probably be removed
attention_kwargs: Optional[Dict[str, Any]] = None,
cond_flag: bool = True,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
        The [`QwenImageTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
Mask of the input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
        Returns:
            `torch.Tensor`: The output sample tensor. Note that the current implementation always returns the
            plain tensor; `return_dict` is accepted only for interface compatibility.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
if isinstance(encoder_hidden_states, list):
encoder_hidden_states = torch.stack(encoder_hidden_states)
encoder_hidden_states_mask = torch.stack(encoder_hidden_states_mask)
hidden_states = self.img_in(hidden_states)
timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
# Context Parallel
if self.sp_world_size > 1:
hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
if image_rotary_emb is not None:
image_rotary_emb = (
torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
image_rotary_emb[1]
)
# TeaCache
if self.teacache is not None:
if cond_flag:
inp = hidden_states.clone()
temb_ = temb.clone()
encoder_hidden_states_ = encoder_hidden_states.clone()
img_mod_params_ = self.transformer_blocks[0].img_mod(temb_)
img_mod1_, img_mod2_ = img_mod_params_.chunk(2, dim=-1)
img_normed_ = self.transformer_blocks[0].img_norm1(inp)
modulated_inp, img_gate1_ = self.transformer_blocks[0]._modulate(img_normed_, img_mod1_)
skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
if skip_flag:
self.should_calc = True
self.teacache.accumulated_rel_l1_distance = 0
else:
if cond_flag:
rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
self.should_calc = False
else:
self.should_calc = True
self.teacache.accumulated_rel_l1_distance = 0
self.teacache.previous_modulated_input = modulated_inp
self.teacache.should_calc = self.should_calc
else:
self.should_calc = self.teacache.should_calc
# TeaCache
if self.teacache is not None:
if not self.should_calc:
previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
hidden_states = hidden_states + previous_residual.to(hidden_states.device)[-hidden_states.size()[0]:,]
else:
ori_hidden_states = hidden_states.clone().cpu() if self.teacache.offload else hidden_states.clone()
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
encoder_hidden_states_mask,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
)
if cond_flag:
self.teacache.previous_residual_cond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states
else:
self.teacache.previous_residual_uncond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states
del ori_hidden_states
else:
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
encoder_hidden_states_mask,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
)
# Use only the image part (hidden_states) from the dual-stream blocks
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if self.sp_world_size > 1:
output = self.all_gather(output, dim=1)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if self.teacache is not None and cond_flag:
self.teacache.cnt += 1
if self.teacache.cnt == self.teacache.num_steps:
self.teacache.reset()
return output
@classmethod
def from_pretrained(
cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if "dict_mapping" in transformer_additional_kwargs.keys():
for key in transformer_additional_kwargs["dict_mapping"]:
transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
if low_cpu_mem_usage:
try:
import re
from diffusers import __version__ as diffusers_version
if diffusers_version >= "0.33.0":
from diffusers.models.model_loading_utils import \
load_model_dict_into_meta
else:
from diffusers.models.modeling_utils import \
load_model_dict_into_meta
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
import accelerate
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model = cls.from_config(config, **transformer_additional_kwargs)
param_device = "cpu"
if os.path.exists(model_file):
state_dict = torch.load(model_file, map_location="cpu")
elif os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
from safetensors.torch import load_file, safe_open
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
state_dict = {}
print(model_files_safetensors)
for _model_file_safetensors in model_files_safetensors:
_state_dict = load_file(_model_file_safetensors)
for key in _state_dict:
state_dict[key] = _state_dict[key]
filtered_state_dict = {}
for key in state_dict:
if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
filtered_state_dict[key] = state_dict[key]
else:
print(f"Skipping key '{key}' due to size mismatch or absence in model.")
model_keys = set(model.state_dict().keys())
loaded_keys = set(filtered_state_dict.keys())
missing_keys = model_keys - loaded_keys
def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
initialized_dict = {}
with torch.no_grad():
for key in missing_keys:
param_shape = model_state_dict[key].shape
param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
if 'weight' in key:
if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
elif 'embedding' in key or 'embed' in key:
initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
elif 'head' in key or 'output' in key or 'proj_out' in key:
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
elif len(param_shape) >= 2:
initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
nn.init.xavier_uniform_(initialized_dict[key])
else:
initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
elif 'bias' in key:
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
elif 'running_mean' in key:
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
elif 'running_var' in key:
initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
elif 'num_batches_tracked' in key:
initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
else:
initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
return initialized_dict
if missing_keys:
print(f"Missing keys will be initialized: {sorted(missing_keys)}")
initialized_params = initialize_missing_parameters(
missing_keys,
model.state_dict(),
torch_dtype
)
filtered_state_dict.update(initialized_params)
if diffusers_version >= "0.33.0":
# Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
# https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
load_model_dict_into_meta(
model,
filtered_state_dict,
dtype=torch_dtype,
model_name_or_path=pretrained_model_path,
)
else:
model._convert_deprecated_attention_blocks(filtered_state_dict)
unexpected_keys = load_model_dict_into_meta(
model,
filtered_state_dict,
device=param_device,
dtype=torch_dtype,
model_name_or_path=pretrained_model_path,
)
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
print(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
return model
except Exception as e:
print(
f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
)
model = cls.from_config(config, **transformer_additional_kwargs)
if os.path.exists(model_file):
state_dict = torch.load(model_file, map_location="cpu")
elif os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
from safetensors.torch import load_file, safe_open
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
state_dict = {}
for _model_file_safetensors in model_files_safetensors:
_state_dict = load_file(_model_file_safetensors)
for key in _state_dict:
state_dict[key] = _state_dict[key]
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
else:
print(key, "Size don't match, skip")
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
print(m)
params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
print(f"### All Parameters: {sum(params) / 1e6} M")
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
model = model.to(torch_dtype)
return model
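
# Hedged end-to-end usage sketch (editorial; the checkpoint path, latent shapes and variable names
# are placeholders, not an official example from this repository):
#
#   model = QwenImageTransformer2DModel.from_pretrained(
#       "path/to/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
#   )
#   noise_pred = model(
#       hidden_states=latents,                    # [B, S_img, in_channels], packed latent tokens
#       encoder_hidden_states=prompt_embeds,      # [B, S_txt, joint_attention_dim]
#       encoder_hidden_states_mask=prompt_mask,   # [B, S_txt]
#       timestep=timesteps,                       # [B]
#       img_shapes=[(1, latent_h // 2, latent_w // 2)],   # (frame, height, width) of the packed latent grid
#       txt_seq_lens=[prompt_mask.shape[1]],
#   )                                             # -> [B, S_img, patch_size**2 * out_channels]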