import importlib.util

from diffusers import AutoencoderKL

# from transformers import (AutoProcessor, AutoTokenizer, CLIPImageProcessor,
#                           CLIPTextModel, CLIPTokenizer,
#                           CLIPVisionModelWithProjection, LlamaModel,
#                           LlamaTokenizerFast, LlavaForConditionalGeneration,
#                           Mistral3ForConditionalGeneration, PixtralProcessor,
#                           Qwen3ForCausalLM, T5EncoderModel, T5Tokenizer,
#                           T5TokenizerFast)

# try:
#     from transformers import (Qwen2_5_VLConfig,
#                               Qwen2_5_VLForConditionalGeneration,
#                               Qwen2Tokenizer, Qwen2VLProcessor)
# except ImportError:
#     Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None
#     Qwen2VLProcessor, Qwen2_5_VLConfig = None, None
#     print("Your transformers version is too old to load Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish to use QwenImage, please upgrade your transformers package to the latest version.")

# from .cogvideox_transformer3d import CogVideoXTransformer3DModel
# from .cogvideox_vae import AutoencoderKLCogVideoX
# from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder
# from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel
# from .flux2_image_processor import Flux2ImageProcessor
# from .flux2_transformer2d import Flux2Transformer2DModel
# from .flux2_transformer2d_control import Flux2ControlTransformer2DModel
# from .flux2_vae import AutoencoderKLFlux2
# from .flux_transformer2d import FluxTransformer2DModel
# from .hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel
# from .hunyuanvideo_vae import AutoencoderKLHunyuanVideo
# from .qwenimage_transformer2d import QwenImageTransformer2DModel
# from .qwenimage_vae import AutoencoderKLQwenImage
# from .wan_audio_encoder import WanAudioEncoder
# from .wan_image_encoder import CLIPModel
# from .wan_text_encoder import WanT5EncoderModel
# from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm,
#                                 WanSelfAttention, WanTransformer3DModel)
# from .wan_transformer3d_animate import Wan2_2Transformer3DModel_Animate
# from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V
# from .wan_transformer3d_vace import VaceWanTransformer3DModel
# from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_
# from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8
from .z_image_transformer2d import ZImageTransformer2DModel
from .z_image_transformer2d_control import ZImageControlTransformer2DModel
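
# ---------------------------------------------------------------
#   Usage sketch (illustrative only, never executed on import).
#   It assumes the Z-Image classes follow diffusers'
#   ModelMixin.from_pretrained convention; that assumption, the
#   checkpoint path, and the subfolder names are placeholders,
#   not verified API.
# ---------------------------------------------------------------
# transformer = ZImageTransformer2DModel.from_pretrained(
#     "path/to/checkpoint", subfolder="transformer"
# )
# vae = AutoencoderKL.from_pretrained("path/to/checkpoint", subfolder="vae")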
# paifuser is an internally developed acceleration package, available on PAI.
# The integration below is currently disabled; it monkey-patches the Wan and
# QwenImage models whose imports are also commented out above, so both blocks
# must be re-enabled together.
# if importlib.util.find_spec("paifuser") is not None:
#     # ---------------------------------------------------------------
#     #   The simple_wrapper is used to solve the problem
#     #   about conflicts between cython and torch.compile
#     # ---------------------------------------------------------------
#     def simple_wrapper(func):
#         def inner(*args, **kwargs):
#             return func(*args, **kwargs)
#         return inner
#
#     # ---------------------------------------------------------------
#     #   VAE Parallel Kernel
#     # ---------------------------------------------------------------
#     from ..dist import parallel_magvit_vae
#     AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode))
#     AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode))
#
#     # ---------------------------------------------------------------
#     #   Sparse Attention
#     # ---------------------------------------------------------------
#     import torch
#     from paifuser.ops import wan_sparse_attention_wrapper
#     WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward))
#     print("Import Sparse Attention")
#
#     WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward)
#
#     # ---------------------------------------------------------------
#     #   CFG Skip Turbo
#     # ---------------------------------------------------------------
#     import os
#     if importlib.util.find_spec("paifuser.accelerator") is not None:
#         from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip,
#                                           enable_cfg_skip, share_cfg_skip)
#     else:
#         from paifuser import (cfg_skip_turbo, disable_cfg_skip,
#                               enable_cfg_skip, share_cfg_skip)
#
#     WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip)
#     WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip)
#     WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip)
#
#     QwenImageTransformer2DModel.enable_cfg_skip = enable_cfg_skip()(QwenImageTransformer2DModel.enable_cfg_skip)
#     QwenImageTransformer2DModel.disable_cfg_skip = disable_cfg_skip()(QwenImageTransformer2DModel.disable_cfg_skip)
#     print("Import CFG Skip Turbo")
#
#     # ---------------------------------------------------------------
#     #   RMS Norm Kernel
#     # ---------------------------------------------------------------
#     from paifuser.ops import rms_norm_forward
#     WanRMSNorm.forward = rms_norm_forward
#     print("Import PAI RMS Fuse")
#
#     # ---------------------------------------------------------------
#     #   Fast Rope Kernel
#     # ---------------------------------------------------------------
#     import types
#
#     import torch
#     from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk,
#                               rope_apply_real_qk)
#
#     from . import wan_transformer3d
#
#     def deepcopy_function(f):
#         # Rebuild the function object so a snapshot of the original rope
#         # implementation survives the monkey patch applied below.
#         return types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
#                                   argdefs=f.__defaults__, closure=f.__closure__)
#
#     local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk)
#
#     if ENABLE_KERNEL:
#         def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
#             # The fused rope kernel is used for inference only; fall back to
#             # the original implementation while gradients are being recorded.
#             if torch.is_grad_enabled():
#                 return local_rope_apply_qk(q, k, grid_sizes, freqs)
#             else:
#                 return fast_rope_apply_qk(q, k, grid_sizes, freqs)
#     else:
#         def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
#             return rope_apply_real_qk(q, k, grid_sizes, freqs)
#
#     wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk
#     rope_apply_qk = adaptive_fast_rope_apply_qk
#     print("Import PAI Fast rope")
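
# ---------------------------------------------------------------
#   CFG Skip Turbo usage sketch (hypothetical): the enable/disable
#   methods patched above already exist on WanTransformer3DModel,
#   but their exact signatures are defined by paifuser, so the
#   argument-free calls and the checkpoint path below are
#   assumptions for illustration only.
# ---------------------------------------------------------------
# model = WanTransformer3DModel.from_pretrained("path/to/wan/transformer")
# model.enable_cfg_skip()   # assumed no-arg call; skips part of the CFG passes
# ...                       # run the denoising loop here
# model.disable_cfg_skip()  # restore standard classifier-free guidance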