import functools import glob import json import math import os import types import warnings from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.cuda.amp as amp import torch.nn as nn import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.loaders.single_file_model import FromOriginalModelMixin from diffusers.models.attention import FeedForward from diffusers.models.attention_processor import Attention from diffusers.models.embeddings import TimestepEmbedding, Timesteps from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers) from diffusers.utils.torch_utils import maybe_allow_in_graph from torch import nn from .fuser import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group, init_distributed_environment, initialize_model_parallel, xFuserLongContextAttention) def apply_rotary_emb_qwen( x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], use_real: bool = True, use_real_unbind_dim: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors. Args: x (`torch.Tensor`): Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) Returns: Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ if use_real: cos, sin = freqs_cis # [S, D] cos = cos[None, None] sin = sin[None, None] cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: # Used for flux, cogvideox, hunyuan-dit x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) elif use_real_unbind_dim == -2: # Used for Stable Audio, OmniGen, CogView4 and Cosmos x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out else: x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(1) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) return x_out.type_as(x) class QwenImageMultiGPUsAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on query and key vectors, but does not include spatial normalization. """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, # Image stream encoder_hidden_states: torch.FloatTensor = None, # Text stream encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: if encoder_hidden_states is None: raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") seq_txt = encoder_hidden_states.shape[1] # Compute QKV for image stream (sample projections) img_query = attn.to_q(hidden_states) img_key = attn.to_k(hidden_states) img_value = attn.to_v(hidden_states) # Compute QKV for text stream (context projections) txt_query = attn.add_q_proj(encoder_hidden_states) txt_key = attn.add_k_proj(encoder_hidden_states) txt_value = attn.add_v_proj(encoder_hidden_states) # Reshape for multi-head attention img_query = img_query.unflatten(-1, (attn.heads, -1)) img_key = img_key.unflatten(-1, (attn.heads, -1)) img_value = img_value.unflatten(-1, (attn.heads, -1)) txt_query = txt_query.unflatten(-1, (attn.heads, -1)) txt_key = txt_key.unflatten(-1, (attn.heads, -1)) txt_value = txt_value.unflatten(-1, (attn.heads, -1)) # Apply QK normalization if attn.norm_q is not None: img_query = attn.norm_q(img_query) if attn.norm_k is not None: img_key = attn.norm_k(img_key) if attn.norm_added_q is not None: txt_query = attn.norm_added_q(txt_query) if attn.norm_added_k is not None: txt_key = attn.norm_added_k(txt_key) # Apply RoPE if image_rotary_emb is not None: img_freqs, txt_freqs = image_rotary_emb img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False) img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False) txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False) txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False) # Concatenate for joint attention # Order: [text, image] # joint_query = torch.cat([txt_query, img_query], dim=1) # joint_key = torch.cat([txt_key, img_key], dim=1) # joint_value = torch.cat([txt_value, img_value], dim=1) half_dtypes = (torch.float16, torch.bfloat16) def half(x): return x if x.dtype in half_dtypes else x.to(dtype) joint_hidden_states = xFuserLongContextAttention()( None, half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False, joint_tensor_query=half(txt_query), joint_tensor_key=half(txt_key), joint_tensor_value=half(txt_value), joint_strategy='front', ) # Reshape back joint_hidden_states = joint_hidden_states.flatten(2, 3) joint_hidden_states = joint_hidden_states.to(img_query.dtype) # Split attention outputs back txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part # Apply output projections img_attn_output = attn.to_out[0](img_attn_output) if len(attn.to_out) > 1: img_attn_output = attn.to_out[1](img_attn_output) # dropout txt_attn_output = attn.to_add_out(txt_attn_output) return img_attn_output, txt_attn_output