Spaces:
Running
on
Zero
Running
on
Zero
| # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py | |
| # Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import inspect | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin | |
| from diffusers.loaders.single_file_model import FromOriginalModelMixin | |
| from diffusers.models.attention import FeedForward | |
| from diffusers.models.attention_processor import AttentionProcessor | |
| from diffusers.models.embeddings import ( | |
| CombinedTimestepGuidanceTextProjEmbeddings, | |
| CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed) | |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.models.normalization import (AdaLayerNormContinuous, | |
| AdaLayerNormZero, | |
| AdaLayerNormZeroSingle) | |
| from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, | |
| scale_lora_layers, unscale_lora_layers) | |
| from diffusers.utils.torch_utils import maybe_allow_in_graph | |
| from ..dist import (FluxMultiGPUsAttnProcessor2_0, get_sequence_parallel_rank, | |
| get_sequence_parallel_world_size, get_sp_group) | |
| from .attention_utils import attention | |
# Module-level logger from diffusers' logging utilities; used for the info/warning messages in this module.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    """Project the main stream (and optionally the context stream) to query/key/value.

    Args:
        attn: Attention module exposing ``to_q``/``to_k``/``to_v`` and, when
            ``attn.added_kv_proj_dim`` is not ``None``, ``add_q_proj``/``add_k_proj``/``add_v_proj``.
        hidden_states: Input of the main projections.
        encoder_hidden_states: Optional input of the added (context) projections.

    Returns:
        ``(query, key, value, encoder_query, encoder_key, encoder_value)``. The three
        ``encoder_*`` entries are ``None`` unless both ``encoder_hidden_states`` and
        ``attn.added_kv_proj_dim`` are not ``None``.
    """
    main_qkv = tuple(proj(hidden_states) for proj in (attn.to_q, attn.to_k, attn.to_v))

    wants_context = encoder_hidden_states is not None and attn.added_kv_proj_dim is not None
    if not wants_context:
        return (*main_qkv, None, None, None)

    context_qkv = tuple(
        proj(encoder_hidden_states) for proj in (attn.add_q_proj, attn.add_k_proj, attn.add_v_proj)
    )
    return (*main_qkv, *context_qkv)
def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    """Return the Q/K/V projections used by the Flux attention processor.

    Delegates to ``_get_projections`` and returns its 6-tuple
    ``(query, key, value, encoder_query, encoder_key, encoder_value)`` unchanged.
    """
    projections = _get_projections(attn, hidden_states, encoder_hidden_states)
    return projections
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> torch.Tensor:
    """
    Apply rotary positional embeddings to a query or key tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor, expected to be 4-D. With `sequence_dim=2` the layout is `[B, H, S, D]`;
            with `sequence_dim=1` it is `[B, S, H, D]`.
        freqs_cis (`Tuple[torch.Tensor, torch.Tensor]` or `torch.Tensor`):
            When `use_real=True`, a `(cos, sin)` pair, each of shape `[S, D]`. When `use_real=False`, a
            complex frequency tensor that is unsqueezed at dim 2 and multiplied with the complex view of `x`.
        use_real (`bool`, defaults to `True`):
            Use the real-valued `(cos, sin)` formulation instead of complex multiplication.
        use_real_unbind_dim (`int`, defaults to `-1`):
            How the last dimension of `x` is split into (real, imaginary) parts. `-1`: adjacent channel pairs
            (Flux, CogVideoX, Hunyuan-DiT). `-2`: first half real, second half imaginary (Stable Audio,
            OmniGen, CogView4, Cosmos). Only used when `use_real=True`.
        sequence_dim (`int`, defaults to `2`):
            Axis of `x` indexing the sequence; must be 1 or 2. Only used when `use_real=True`.

    Returns:
        `torch.Tensor`: A tensor with the same shape and dtype as `x` with the rotation applied.

    Raises:
        ValueError: If `sequence_dim` is not 1 or 2, or `use_real_unbind_dim` is not -1 or -2
            (both only checked when `use_real=True`).
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        # Insert broadcast axes so the [S, D] tables line up with x's sequence axis.
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit: (x0, x1) pairs -> (-x1, x0)
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos: [re | im] halves -> [-im | re]
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        # Compute in float32 for accuracy, then cast back to the input dtype.
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
        return out

    # used for lumina: complex multiplication on the pairwise complex view of x
    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.unsqueeze(2)
    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
    return x_out.type_as(x)
class FluxAttnProcessor:
    """Default attention processor for `FluxAttention`.

    Called without `encoder_hidden_states`, it returns the attention output as a single tensor
    and applies no output projection. Called with `encoder_hidden_states`, it prepends the context
    tokens to the sequence (joint attention), then splits the result back and returns
    `(hidden_states, encoder_hidden_states)` after `attn.to_out` / `attn.to_add_out`.
    """

    # Not referenced by this processor.
    _attention_backend = None

    def __init__(self):
        # Only a PyTorch-version guard; the attention itself is computed by `attention` from attention_utils.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        text_seq_len: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run attention for one `FluxAttention` module.

        Args:
            attn: The owning attention module (projections, norms, head count).
            hidden_states: Main-stream tokens; the last dim is split into `attn.heads` heads.
                Presumably `(batch, seq, channels)` — the sequence is concatenated along dim 1.
            encoder_hidden_states: Optional context-stream tokens, prepended along dim 1.
            attention_mask: Passed to `attention` as `attn_mask`.
            image_rotary_emb: `(cos, sin)` rotary tables applied to queries and keys.
            text_seq_len: Accepted but unused here; it is part of the signature so that
                `FluxAttention.forward`'s keyword filter forwards it without warning.

        Returns:
            A single tensor when `encoder_hidden_states` is None, otherwise the pair
            `(hidden_states, encoder_hidden_states)`.
        """
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        # Split channels into heads: (..., heads * head_dim) -> (..., heads, head_dim).
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        # Per-head RMS normalisation of queries and keys (values are not normalised).
        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            # NOTE(review): assumes encoder_hidden_states is always passed when the module has
            # context projections; otherwise encoder_query is None and this raises.
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            # Joint attention: context tokens come first along the sequence axis.
            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            # Tensors are (batch, seq, heads, head_dim) here, hence sequence_dim=1.
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = attention(
            query, key, value, attn_mask=attention_mask,
        )
        # Merge heads back: (..., heads, head_dim) -> (..., heads * head_dim).
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is None:
            return hidden_states

        context_len = encoder_hidden_states.shape[1]
        encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
            [context_len, hidden_states.shape[1] - context_len], dim=1
        )
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
        return hidden_states, encoder_hidden_states
class FluxAttention(torch.nn.Module):
    """Multi-head attention module used by the Flux transformer blocks.

    Owns the Q/K/V projections, per-head RMS norms and output projections. The attention
    computation itself is delegated to `self.processor` (`FluxAttnProcessor` by default).
    When `added_kv_proj_dim` is set, a second set of projections for a context stream is created.
    """

    _default_processor_cls = FluxAttnProcessor
    _available_processors = [
        FluxAttnProcessor,
    ]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: Optional[int] = None,
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
        elementwise_affine: bool = True,
        processor=None,
    ):
        """
        Args:
            query_dim: Channel dim of `hidden_states` (input of `to_q`/`to_k`/`to_v`).
            heads: Number of heads; ignored when `out_dim` is given (then `out_dim // dim_head`).
            dim_head: Per-head channel dim; also the size normalised by the RMS norms.
            dropout: Dropout probability after the output projection.
            bias: Whether `to_q`/`to_k`/`to_v` have a bias.
            added_kv_proj_dim: Channel dim of the context stream; enables the `add_*` projections.
            added_proj_bias: Whether the `add_*` projections have a bias.
            out_bias: Whether `to_out[0]` and `to_add_out` have a bias.
            eps: Epsilon of the RMS norms.
            out_dim: Overrides both the projection width and the output width.
            context_pre_only: Stored on the module; not otherwise used here.
            pre_only: When True, `to_out` is not created.
            elementwise_affine: Affine flag for `norm_q`/`norm_k` (`norm_added_q`/`norm_added_k`
                use the RMSNorm default).
            processor: Attention processor; defaults to `_default_processor_cls()`.
        """
        super().__init__()
        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        # Submodules are created in a fixed order; keep it so parameter registration and
        # random-initialisation order stay stable.
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.pre_only:
            self.to_out = torch.nn.ModuleList([])
            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
            self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            self.processor = self._default_processor_cls()
        else:
            self.processor = processor

    def set_processor(self, processor: "AttentionProcessor") -> None:
        r"""
        Set the attention processor to use.

        If the current processor is an `nn.Module` (possibly with trained weights) and the new one
        is not, the old one is removed from `self._modules` and an info message is logged.

        Args:
            processor (`AttentionProcessor`):
                The attention processor to use.
        """
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")
        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False) -> Optional["AttentionProcessor"]:
        r"""
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to request the deprecated LoRA attention processor. This module keeps no
                such processor, so `None` is returned in that case.

        Returns:
            The attention processor in use, or `None` when `return_deprecated_lora` is True.
        """
        if not return_deprecated_lora:
            return self.processor
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Dispatch to `self.processor`.

        Extra keyword arguments are forwarded only if the processor's `__call__` declares them.
        Others are dropped, with a warning unless they are one of the IP-adapter keys
        (`ip_adapter_masks`, `ip_hidden_states`).

        Returns:
            Whatever the processor returns. For `FluxAttnProcessor` that is a tensor, or a
            `(hidden_states, encoder_hidden_states)` pair when `encoder_hidden_states` is given.
        """
        # NOTE: the signature is inspected on every call because `set_processor` may swap processors.
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
        unused_kwargs = [k for k in kwargs if k not in attn_parameters and k not in quiet_attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
class FluxSingleTransformerBlock(nn.Module):
    """Single-stream Flux block.

    Text and image tokens are concatenated and share one attention branch and one MLP branch.
    The two branch outputs are concatenated channel-wise, fused by `proj_out`, gated, and added
    back to the residual. The joint sequence is then split into text and image parts again.
    """

    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        mlp_width = int(dim * mlp_ratio)
        self.mlp_hidden_dim = mlp_width

        # Creation order matters for parameter registration / init order; do not reorder.
        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, mlp_width)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + mlp_width, dim)
        self.attn = FluxAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            out_dim=dim,
            bias=True,
            eps=1e-6,
            pre_only=True,
            processor=FluxAttnProcessor(),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return `(encoder_hidden_states, hidden_states)` split back apart.

        The text length is taken from `encoder_hidden_states.shape[1]`. It is also passed to the
        attention as `text_seq_len`.
        """
        n_text = encoder_hidden_states.shape[1]
        joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        normed, gate = self.norm(joint, emb=temb)
        mlp_branch = self.act_mlp(self.proj_mlp(normed))
        attn_branch = self.attn(
            hidden_states=normed,
            image_rotary_emb=image_rotary_emb,
            text_seq_len=n_text,
            **(joint_attention_kwargs or {}),
        )

        fused = self.proj_out(torch.cat([attn_branch, mlp_branch], dim=2))
        joint = joint + gate.unsqueeze(1) * fused

        # Keep fp16 activations inside the finite float16 range.
        if joint.dtype == torch.float16:
            joint = joint.clip(-65504, 65504)

        return joint[:, :n_text], joint[:, n_text:]
class FluxTransformerBlock(nn.Module):
    """Dual-stream Flux block.

    Image tokens (`hidden_states`) and text tokens (`encoder_hidden_states`) each have their own
    AdaLN-Zero modulation, post-attention LayerNorm and feed-forward. They share one joint
    attention (`FluxAttention` with `added_kv_proj_dim`).
    """

    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        """
        Args:
            dim: Channel dim of both streams.
            num_attention_heads: Number of attention heads.
            attention_head_dim: Per-head dim.
            qk_norm: Accepted for interface compatibility; not used by this block.
            eps: Epsilon passed to `FluxAttention`.
        """
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        self.attn = FluxAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=FluxAttnProcessor(),
            eps=eps,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block.

        Returns:
            `(encoder_hidden_states, hidden_states)`.

        Raises:
            ValueError: If the attention processor does not return 2 or 3 outputs.
        """
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
        joint_attention_kwargs = joint_attention_kwargs or {}

        # Attention.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        # Processors return (image, context) or (image, context, ip_adapter) outputs.
        if len(attention_outputs) == 2:
            attn_output, context_attn_output = attention_outputs
        elif len(attention_outputs) == 3:
            attn_output, context_attn_output, ip_attn_output = attention_outputs
        else:
            raise ValueError(
                f"Expected the attention processor to return 2 or 3 outputs, got {len(attention_outputs)}."
            )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output
        if len(attention_outputs) == 3:
            hidden_states = hidden_states + ip_attn_output

        # Process attention outputs for the `encoder_hidden_states`.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        # Only the text stream is clipped to the finite float16 range here.
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
class FluxPosEmbed(nn.Module):
    """Rotary position embedding over several position axes.

    Column `i` of `ids` is embedded with `get_1d_rotary_pos_embed` using width `axes_dim[i]`.
    The per-axis cos/sin tables are concatenated along the last dimension.
    """

    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        """
        Args:
            theta: Passed as `theta` to `get_1d_rotary_pos_embed`.
            axes_dim: Embedding width per position axis; needs at least `ids.shape[-1]` entries.
        """
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            ids: 2-D position ids, one row per token and one column per axis.

        Returns:
            `(freqs_cos, freqs_sin)`: the per-axis tables concatenated along the last dim and
            moved to `ids.device`.
        """
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        # float32 on MPS/NPU devices, float64 elsewhere (presumably float64 is not well supported there).
        is_mps = ids.device.type == "mps"
        is_npu = ids.device.type == "npu"
        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
class FluxTransformer2DModel(
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    This version can optionally shard the image tokens across sequence-parallel ranks
    (see `enable_multi_gpus_inference`); the projected output is all-gathered at the end of `forward`.

    Args:
        patch_size (`int`, defaults to `1`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `19`):
            The number of layers of dual stream DiT blocks to use.
        num_single_layers (`int`, defaults to `38`):
            The number of layers of single stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `4096`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        pooled_projection_dim (`int`, defaults to `768`):
            The number of dimensions to use for the pooled projection.
        guidance_embeds (`bool`, defaults to `False`):
            Whether to use guidance embeddings for guidance-distilled variant of the model.
        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True
    # _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
    # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    # _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

    # `register_to_config` records the init arguments so `self.config` is populated
    # (needed by save/load and by callers reading e.g. `config.in_channels`).
    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

        # Sequence-parallel settings; `enable_multi_gpus_inference` overrides them.
        self.sp_world_size = 1
        self.sp_world_rank = 0

    def _set_gradient_checkpointing(self, *args, **kwargs):
        """Toggle gradient checkpointing.

        The flag is read from the `value=` or `enable=` keyword (presumably to support both
        older and newer diffusers call conventions); positional arguments are ignored.

        Raises:
            ValueError: If neither keyword is given.
        """
        if "value" in kwargs:
            self.gradient_checkpointing = kwargs["value"]
        elif "enable" in kwargs:
            self.gradient_checkpointing = kwargs["enable"]
        else:
            raise ValueError("Invalid set gradient checkpointing")

    def enable_multi_gpus_inference(self,):
        """Switch to sequence-parallel inference.

        Reads the sequence-parallel world size, rank and `all_gather` from `..dist`, and installs
        `FluxMultiGPUsAttnProcessor2_0` on every attention layer.
        """
        self.sp_world_size = get_sequence_parallel_world_size()
        self.sp_world_rank = get_sequence_parallel_rank()
        self.all_gather = get_sp_group().all_gather
        self.set_attn_processor(FluxMultiGPUsAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose length differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # NOTE: pops from the caller's dict.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def _shard_rotary_emb(self, image_rotary_emb, text_seq_len):
        """Keep all text rotary rows plus this rank's chunk of the image rotary rows.

        Mirrors the `torch.chunk` sharding applied to the image tokens in `forward`.
        Returns a list `[cos, sin]`.
        """
        txt_rotary_emb = (
            image_rotary_emb[0][:text_seq_len],
            image_rotary_emb[1][:text_seq_len],
        )
        img_rotary_emb = (
            torch.chunk(image_rotary_emb[0][text_seq_len:], self.sp_world_size, dim=0)[self.sp_world_rank],
            torch.chunk(image_rotary_emb[1][text_seq_len:], self.sp_world_size, dim=0)[self.sp_world_rank],
        )
        return [
            torch.cat([_txt_rotary_emb, _img_rotary_emb], dim=0)
            for _txt_rotary_emb, _img_rotary_emb in zip(txt_rotary_emb, img_rotary_emb)
        ]

    def _run_block(self, block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs):
        """Run one transformer block, under gradient checkpointing when enabled and grad is on.

        Returns:
            `(encoder_hidden_states, hidden_states)` as returned by `block`.
        """
        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def custom_forward(*inputs):
                return block(*inputs)

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            return torch.utils.checkpoint.checkpoint(
                custom_forward,
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
                joint_attention_kwargs,
                **ckpt_kwargs,
            )

        return block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            joint_attention_kwargs=joint_attention_kwargs,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step. Scaled by 1000 before embedding.
            img_ids (`torch.Tensor`):
                Position ids of the image tokens (2-D). A 3-D tensor is accepted with a deprecation warning and
                only its first batch entry is used.
            txt_ids (`torch.Tensor`):
                Position ids of the text tokens; same conventions as `img_ids`.
            guidance (`torch.Tensor`, *optional*):
                Guidance value, scaled by 1000 and passed to the time/text embedder together with the timestep.
                This three-argument call presumably requires a model built with `guidance_embeds=True`.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                A `"scale"` entry is used as the LoRA scale (PEFT backend only), and an
                `"ip_adapter_image_embeds"` entry is projected into `ip_hidden_states`.
            controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
                Residuals added to `hidden_states` after the dual-stream blocks. Block `i` uses sample
                `i // ceil(num_blocks / num_samples)`, or `i % num_samples` when `controlnet_blocks_repeat`.
            controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
                Residuals added after the single-stream blocks, using `i // ceil(num_blocks / num_samples)`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            controlnet_blocks_repeat (`bool`, *optional*, defaults to `False`):
                Cycle through `controlnet_block_samples` (used for XLabs ControlNet).

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            # Copy so the pops below do not mutate the caller's dict.
            joint_attention_kwargs = joint_attention_kwargs.copy()
            # Record this before popping so the non-PEFT warning below can actually fire.
            scale_was_passed = joint_attention_kwargs.get("scale", None) is not None
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            scale_was_passed = False
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        # Text ids come first, matching the text-then-image order used in attention.
        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            # NOTE(review): `encoder_hid_proj` is not created in `__init__`; presumably attached by an
            # IP-Adapter loader — confirm.
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        # Context Parallel: keep this rank's slice of image tokens and the matching rotary rows.
        # NOTE(review): controlnet residuals are not sharded here — confirm they are not combined with SP.
        if self.sp_world_size > 1:
            hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
            if image_rotary_emb is not None:
                image_rotary_emb = self._shard_rotary_emb(image_rotary_emb, encoder_hidden_states.shape[1])

        for index_block, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = self._run_block(
                block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs
            )

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                # For Xlabs ControlNet.
                if controlnet_blocks_repeat:
                    hidden_states = (
                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                    )
                else:
                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        for index_block, block in enumerate(self.single_transformer_blocks):
            encoder_hidden_states, hidden_states = self._run_block(
                block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs
            )

            # controlnet residual
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        # Reassemble the full image sequence from all sequence-parallel ranks.
        if self.sp_world_size > 1:
            output = self.all_gather(output, dim=1)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)