import importlib.util

from diffusers import AutoencoderKL
# from transformers import (AutoProcessor, AutoTokenizer, CLIPImageProcessor,
#                           CLIPTextModel, CLIPTokenizer,
#                           CLIPVisionModelWithProjection, LlamaModel,
#                           LlamaTokenizerFast, LlavaForConditionalGeneration,
#                           Mistral3ForConditionalGeneration, PixtralProcessor,
#                           Qwen3ForCausalLM, T5EncoderModel, T5Tokenizer,
#                           T5TokenizerFast)

# try:
#     from transformers import (Qwen2_5_VLConfig,
#                               Qwen2_5_VLForConditionalGeneration,
#                               Qwen2Tokenizer, Qwen2VLProcessor)
# except ImportError:
#     Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None
#     Qwen2VLProcessor, Qwen2_5_VLConfig = None, None
#     print("Your transformers version is too old to load Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish to use QwenImage, please upgrade your transformers package to the latest version.")

# from .cogvideox_transformer3d import CogVideoXTransformer3DModel
# from .cogvideox_vae import AutoencoderKLCogVideoX
# from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder
# from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel
# from .flux2_image_processor import Flux2ImageProcessor
# from .flux2_transformer2d import Flux2Transformer2DModel
# from .flux2_transformer2d_control import Flux2ControlTransformer2DModel
# from .flux2_vae import AutoencoderKLFlux2
# from .flux_transformer2d import FluxTransformer2DModel
# from .hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel
# from .hunyuanvideo_vae import AutoencoderKLHunyuanVideo
# from .qwenimage_transformer2d import QwenImageTransformer2DModel
# from .qwenimage_vae import AutoencoderKLQwenImage
# from .wan_audio_encoder import WanAudioEncoder
# from .wan_image_encoder import CLIPModel
# from .wan_text_encoder import WanT5EncoderModel
# from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm,
#                                 WanSelfAttention, WanTransformer3DModel)
# from .wan_transformer3d_animate import Wan2_2Transformer3DModel_Animate
# from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V
# from .wan_transformer3d_vace import VaceWanTransformer3DModel
# from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_
# from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8
from .z_image_transformer2d import ZImageTransformer2DModel
from .z_image_transformer2d_control import ZImageControlTransformer2DModel
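
# Optional: an explicit export list. This is a minimal sketch that only names the
# imports left active above; extend it if the commented-out model imports are restored.
__all__ = [
    "AutoencoderKL",
    "ZImageTransformer2DModel",
    "ZImageControlTransformer2DModel",
]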

# paifuser is an internally developed acceleration package that can be used on PAI.
# if importlib.util.find_spec("paifuser") is not None:
#     # --------------------------------------------------------------- #
#     #   simple_wrapper works around conflicts between
#     #   Cython-compiled code and torch.compile
#     # --------------------------------------------------------------- #
#     def simple_wrapper(func):
#         def inner(*args, **kwargs):
#             return func(*args, **kwargs)
#         return inner
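
#     # Note: re-wrapping in a plain Python closure presumably gives torch.compile an
#     # ordinary Python frame to trace instead of a Cython-compiled callable, which is
#     # why this no-op wrapper avoids the conflict mentioned above.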

#     # --------------------------------------------------------------- #
#     #   VAE Parallel Kernel
#     # --------------------------------------------------------------- #
#     from ..dist import parallel_magvit_vae
#     AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode))
#     AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode))

#     # --------------------------------------------------------------- #
#     #   Sparse Attention
#     # --------------------------------------------------------------- #
#     import torch
#     from paifuser.ops import wan_sparse_attention_wrapper
    
#     WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward))
#     print("Imported Sparse Attention")

#     WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward)

#     # --------------------------------------------------------------- #
#     #   CFG Skip Turbo
#     # --------------------------------------------------------------- #
#     import os

#     if importlib.util.find_spec("paifuser.accelerator") is not None:
#         from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip,
#                                           enable_cfg_skip, share_cfg_skip)
#     else:
#         from paifuser import (cfg_skip_turbo, disable_cfg_skip,
#                               enable_cfg_skip, share_cfg_skip)

#     WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip)
#     WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip)
#     WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip)

#     QwenImageTransformer2DModel.enable_cfg_skip = enable_cfg_skip()(QwenImageTransformer2DModel.enable_cfg_skip)
#     QwenImageTransformer2DModel.disable_cfg_skip = disable_cfg_skip()(QwenImageTransformer2DModel.disable_cfg_skip)
#     print("Imported CFG Skip Turbo")

#     # --------------------------------------------------------------- #
#     #   RMS Norm Kernel
#     # --------------------------------------------------------------- #
#     from paifuser.ops import rms_norm_forward
#     WanRMSNorm.forward = rms_norm_forward
#     print("Imported PAI RMS Fuse")

#     # --------------------------------------------------------------- #
#     #   Fast Rope Kernel
#     # --------------------------------------------------------------- #
#     import types

#     import torch
#     from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk,
#                               rope_apply_real_qk)

#     from . import wan_transformer3d

#     def deepcopy_function(f):
#         return types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
#                                   argdefs=f.__defaults__, closure=f.__closure__)

#     local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk)
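#     # local_rope_apply_qk preserves the original rope_apply_qk (copied above so it
#     # survives the monkey-patch below). The fused kernel is presumably inference-only,
#     # so it is used only when autograd is disabled; otherwise the preserved original
#     # (or rope_apply_real_qk when the kernel is unavailable) handles the call.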

#     if ENABLE_KERNEL:
#         def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
#             if torch.is_grad_enabled():
#                 return local_rope_apply_qk(q, k, grid_sizes, freqs)
#             else:
#                 return fast_rope_apply_qk(q, k, grid_sizes, freqs)
#     else:
#         def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
#             return rope_apply_real_qk(q, k, grid_sizes, freqs)
            
#     wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk
#     rope_apply_qk = adaptive_fast_rope_apply_qk
#     print("Imported PAI Fast Rope")