# videox_fun/models/z_image_transformer2d_control.py
# Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

from diffusers.configuration_utils import register_to_config
from diffusers.utils import is_torch_version

from .z_image_transformer2d import ZImageTransformer2DModel, ZImageTransformerBlock

ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
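# Patchified sequence lengths are expected to be multiples of SEQ_MULTI_OF
# (asserted in forward / forward_control below).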


class ZImageControlTransformerBlock(ZImageTransformerBlock):
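    """Transformer block for the control branch.

    Block 0 projects the incoming control features through a zero-initialized
    linear layer and adds the base hidden states; every control block appends a
    zero-initialized skip projection of its output (a "hint") to a running stack
    that the base layers later consume.
    """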
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
norm_eps: float,
qk_norm: bool,
modulation=True,
block_id=0
):
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
self.block_id = block_id
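        # Zero-initialized projections so the control branch contributes nothing
        # at initialization (ControlNet-style).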
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.before_proj.weight)
nn.init.zeros_(self.before_proj.bias)
self.after_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)

    def forward(self, c, x, **kwargs):
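        # `c` carries the running control state; for blocks after the first it is a
        # stack of [previously emitted hints..., running state] along dim 0.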
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = list(torch.unbind(c))
c = all_c.pop(-1)
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
c = torch.stack(all_c)
return c


class BaseZImageTransformerBlock(ZImageTransformerBlock):
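    """Base transformer block that, when assigned a `block_id`, adds the matching
    control hint (scaled by `context_scale`) to its output."""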
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
norm_eps: float,
qk_norm: bool,
modulation=True,
block_id=0
):
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
self.block_id = block_id

    def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs):
hidden_states = super().forward(hidden_states, **kwargs)
if self.block_id is not None:
hidden_states = hidden_states + hints[self.block_id] * context_scale
return hidden_states


class ZImageControlTransformer2DModel(ZImageTransformer2DModel):
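    """ControlNet-style variant of `ZImageTransformer2DModel`.

    A parallel stack of control blocks (attached to every other base layer by
    default) turns a control latent into per-layer hints that are added to the
    outputs of the corresponding base blocks.
    """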
@register_to_config
def __init__(
self,
control_layers_places=None,
control_in_dim=None,
all_patch_size=(2,),
all_f_patch_size=(1,),
in_channels=16,
dim=3840,
n_layers=30,
n_refiner_layers=2,
n_heads=30,
n_kv_heads=30,
norm_eps=1e-5,
qk_norm=True,
cap_feat_dim=2560,
rope_theta=256.0,
t_scale=1000.0,
axes_dims=[32, 48, 48],
axes_lens=[1024, 512, 512],
):
super().__init__(
all_patch_size=all_patch_size,
all_f_patch_size=all_f_patch_size,
in_channels=in_channels,
dim=dim,
n_layers=n_layers,
n_refiner_layers=n_refiner_layers,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
norm_eps=norm_eps,
qk_norm=qk_norm,
cap_feat_dim=cap_feat_dim,
rope_theta=rope_theta,
t_scale=t_scale,
axes_dims=axes_dims,
axes_lens=axes_lens,
)
self.control_layers_places = [i for i in range(0, self.num_layers, 2)] if control_layers_places is None else control_layers_places
self.control_in_dim = self.in_dim if control_in_dim is None else control_in_dim
assert 0 in self.control_layers_places
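        # Map each controlled base-layer index to the position of its hint in the
        # control branch output.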
self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)}
# blocks
del self.layers
self.layers = nn.ModuleList(
[
BaseZImageTransformerBlock(
i,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
block_id=self.control_layers_mapping[i] if i in self.control_layers_places else None
)
for i in range(n_layers)
]
)
# control blocks
self.control_layers = nn.ModuleList(
[
ZImageControlTransformerBlock(
i,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
block_id=i
)
for i in self.control_layers_places
]
)
# control patch embeddings
all_x_embedder = {}
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
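        # Dedicated refiner stack for the control stream, mirroring the base
        # noise_refiner (layer ids offset by 1000).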
self.control_noise_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
1000 + layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=True,
)
for layer_id in range(n_refiner_layers)
]
)

    def forward_control(
self,
x,
cap_feats,
control_context,
kwargs,
t=None,
patch_size=2,
f_patch_size=1,
):
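        """Run the control branch: patchify and embed `control_context`, refine it
        with `control_noise_refiner`, fuse it with the caption features, and pass
        the result through `control_layers` to produce one hint per mapped base
        layer."""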
# embeddings
bsz = len(control_context)
device = control_context[0].device
(
control_context,
x_size,
x_pos_ids,
x_inner_pad_mask,
) = self.patchify(control_context, patch_size, f_patch_size, cap_feats[0].size(0))
# control_context embed & refine
x_item_seqlens = [len(_) for _ in control_context]
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
x_max_item_seqlen = max(x_item_seqlens)
control_context = torch.cat(control_context, dim=0)
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
# Match t_embedder output dtype to control_context for layerwise casting compatibility
adaln_input = t.type_as(control_context)
control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token
control_context = list(control_context.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
x_attn_mask[i, :seq_len] = 1
# Context Parallel
if self.sp_world_size > 1:
control_context = torch.chunk(control_context, self.sp_world_size, dim=1)[self.sp_world_rank]
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.control_noise_refiner:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
control_context = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
control_context, x_attn_mask, x_freqs_cis, adaln_input,
**ckpt_kwargs,
)
else:
for layer in self.control_noise_refiner:
control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input)
# unified
cap_item_seqlens = [len(_) for _ in cap_feats]
control_context_unified = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]]))
control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
c = control_context_unified
# Context Parallel
if self.sp_world_size > 1:
c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank]
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
for layer in self.control_layers:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, **static_kwargs):
def custom_forward(*inputs):
return module(*inputs, **static_kwargs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer, **new_kwargs),
c,
**ckpt_kwargs,
)
else:
c = layer(c, **new_kwargs)
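        # `c` now stacks [hint_0, ..., hint_{n-1}, running_state]; keep only the hints.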
hints = torch.unbind(c)[:-1]
return hints

    def forward(
self,
x: List[torch.Tensor],
t,
cap_feats: List[torch.Tensor],
patch_size=2,
f_patch_size=1,
control_context=None,
control_context_scale=1.0,
):
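        """Standard ZImage forward pass with control injection: the unified
        (latent + caption) sequence is first passed to `forward_control` to obtain
        per-layer hints, which the base layers then add to their outputs, scaled
        by `control_context_scale`."""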
assert patch_size in self.all_patch_size
assert f_patch_size in self.all_f_patch_size
bsz = len(x)
device = x[0].device
t = t * self.t_scale
t = self.t_embedder(t)
(
x,
cap_feats,
x_size,
x_pos_ids,
cap_pos_ids,
x_inner_pad_mask,
cap_inner_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
# x embed & refine
x_item_seqlens = [len(_) for _ in x]
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
x_max_item_seqlen = max(x_item_seqlens)
x = torch.cat(x, dim=0)
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
# Match t_embedder output dtype to x for layerwise casting compatibility
adaln_input = t.type_as(x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
x = pad_sequence(x, batch_first=True, padding_value=0.0)
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
x_attn_mask[i, :seq_len] = 1
# Context Parallel
if self.sp_world_size > 1:
x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.noise_refiner:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
x, x_attn_mask, x_freqs_cis, adaln_input,
**ckpt_kwargs,
)
else:
for layer in self.noise_refiner:
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
# cap embed & refine
cap_item_seqlens = [len(_) for _ in cap_feats]
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
cap_max_item_seqlen = max(cap_item_seqlens)
cap_feats = torch.cat(cap_feats, dim=0)
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
cap_attn_mask[i, :seq_len] = 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.context_refiner:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
cap_feats = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
cap_feats,
cap_attn_mask,
cap_freqs_cis,
**ckpt_kwargs,
)
else:
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
# unified
unified = []
unified_freqs_cis = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
assert unified_item_seqlens == [len(_) for _ in unified]
unified_max_item_seqlen = max(unified_item_seqlens)
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
unified_attn_mask[i, :seq_len] = 1
# Arguments
kwargs = dict(
attn_mask=unified_attn_mask,
freqs_cis=unified_freqs_cis,
adaln_input=adaln_input,
)
hints = self.forward_control(
unified, cap_feats, control_context, kwargs, t=t, patch_size=patch_size, f_patch_size=f_patch_size,
)
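        # Main transformer stack: each block with an assigned block_id adds its
        # control hint, scaled by control_context_scale.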
for layer in self.layers:
# Arguments
kwargs = dict(
attn_mask=unified_attn_mask,
freqs_cis=unified_freqs_cis,
adaln_input=adaln_input,
hints=hints,
context_scale=control_context_scale
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, **static_kwargs):
def custom_forward(*inputs):
return module(*inputs, **static_kwargs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
unified = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer, **kwargs),
unified,
**ckpt_kwargs,
)
else:
unified = layer(unified, **kwargs)
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
unified = list(unified.unbind(dim=0))
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
if self.sp_world_size > 1:
x = self.all_gather(x, dim=1)
x = torch.stack(x)
return x, {}
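

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `latents`, `timesteps`, `caption_feats`
# and `control_latents` are hypothetical inputs prepared by the surrounding
# pipeline):
#
#   model = ZImageControlTransformer2DModel()        # default config
#   sample, _ = model(
#       x=latents,                        # list of per-sample latent tensors
#       t=timesteps,                      # diffusion timesteps, shape (bsz,)
#       cap_feats=caption_feats,          # list of per-sample caption embeddings
#       control_context=control_latents,  # list of per-sample control latents
#       control_context_scale=1.0,
#   )
# ---------------------------------------------------------------------------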