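# NOTE: only the Z-Image models are imported in this build; the remaining
# model imports below are intentionally kept commented out.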
import importlib.util
from diffusers import AutoencoderKL
# from transformers import (AutoProcessor, AutoTokenizer, CLIPImageProcessor,
#                           CLIPTextModel, CLIPTokenizer,
#                           CLIPVisionModelWithProjection, LlamaModel,
#                           LlamaTokenizerFast, LlavaForConditionalGeneration,
#                           Mistral3ForConditionalGeneration, PixtralProcessor,
#                           Qwen3ForCausalLM, T5EncoderModel, T5Tokenizer,
#                           T5TokenizerFast)
# try:
#     from transformers import (Qwen2_5_VLConfig,
#                               Qwen2_5_VLForConditionalGeneration,
#                               Qwen2Tokenizer, Qwen2VLProcessor)
# except ImportError:
#     Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None
#     Qwen2VLProcessor, Qwen2_5_VLConfig = None, None
#     print("Your transformers version is too old to load Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish to use QwenImage, please upgrade your transformers package to the latest version.")
# from .cogvideox_transformer3d import CogVideoXTransformer3DModel
# from .cogvideox_vae import AutoencoderKLCogVideoX
# from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder
# from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel
# from .flux2_image_processor import Flux2ImageProcessor
# from .flux2_transformer2d import Flux2Transformer2DModel
# from .flux2_transformer2d_control import Flux2ControlTransformer2DModel
# from .flux2_vae import AutoencoderKLFlux2
# from .flux_transformer2d import FluxTransformer2DModel
# from .hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel
# from .hunyuanvideo_vae import AutoencoderKLHunyuanVideo
# from .qwenimage_transformer2d import QwenImageTransformer2DModel
# from .qwenimage_vae import AutoencoderKLQwenImage
# from .wan_audio_encoder import WanAudioEncoder
# from .wan_image_encoder import CLIPModel
# from .wan_text_encoder import WanT5EncoderModel
# from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm,
#                                 WanSelfAttention, WanTransformer3DModel)
# from .wan_transformer3d_animate import Wan2_2Transformer3DModel_Animate
# from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V
# from .wan_transformer3d_vace import VaceWanTransformer3DModel
# from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_
# from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8
from .z_image_transformer2d import ZImageTransformer2DModel
from .z_image_transformer2d_control import ZImageControlTransformer2DModel
# paifuser is an internally developed acceleration package available on PAI.
# if importlib.util.find_spec("paifuser") is not None:
#     # --------------------------------------------------------------- #
#     # simple_wrapper works around conflicts between Cython and
#     # torch.compile.
#     # --------------------------------------------------------------- #
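#     # (Presumably the extra plain-Python closure gives torch.compile a
#     # regular Python frame to trace instead of a Cython-compiled callable.)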
#     def simple_wrapper(func):
#         def inner(*args, **kwargs):
#             return func(*args, **kwargs)
#         return inner
#     # --------------------------------------------------------------- #
#     # VAE Parallel Kernel
#     # --------------------------------------------------------------- #
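#     # decode is monkey-patched at the class level, so every pipeline that
#     # uses these VAE classes picks up the parallel kernel automatically.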
#     from ..dist import parallel_magvit_vae
#     AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode))
#     AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode))
#     # --------------------------------------------------------------- #
#     # Sparse Attention
#     # --------------------------------------------------------------- #
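#     # Wrapping forward on WanSelfAttention applies sparse attention to
#     # every attention layer instance without touching the model code.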
#     import torch
#     from paifuser.ops import wan_sparse_attention_wrapper
#     WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward))
#     print("Import Sparse Attention")
#     WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward)
#     # --------------------------------------------------------------- #
#     # CFG Skip Turbo
#     # --------------------------------------------------------------- #
#     import os
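#     # "CFG skip" presumably skips a share of the classifier-free-guidance
#     # forward passes. The hooks live under paifuser.accelerator in newer
#     # builds, so both import paths are tried.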
#     if importlib.util.find_spec("paifuser.accelerator") is not None:
#         from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip,
#                                           enable_cfg_skip, share_cfg_skip)
#     else:
#         from paifuser import (cfg_skip_turbo, disable_cfg_skip,
#                               enable_cfg_skip, share_cfg_skip)
#     WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip)
#     WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip)
#     WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip)
#     QwenImageTransformer2DModel.enable_cfg_skip = enable_cfg_skip()(QwenImageTransformer2DModel.enable_cfg_skip)
#     QwenImageTransformer2DModel.disable_cfg_skip = disable_cfg_skip()(QwenImageTransformer2DModel.disable_cfg_skip)
#     print("Import CFG Skip Turbo")
#     # --------------------------------------------------------------- #
#     # RMS Norm Kernel
#     # --------------------------------------------------------------- #
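#     # rms_norm_forward is bound directly (no simple_wrapper), replacing
#     # WanRMSNorm.forward with the fused kernel as a drop-in.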
#     from paifuser.ops import rms_norm_forward
#     WanRMSNorm.forward = rms_norm_forward
#     print("Import PAI RMS Fuse")
#     # --------------------------------------------------------------- #
#     # Fast Rope Kernel
#     # --------------------------------------------------------------- #
#     import types
#     import torch
#     from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk,
#                               rope_apply_real_qk)
#     from . import wan_transformer3d
#     def deepcopy_function(f):
#         return types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
#                                   argdefs=f.__defaults__, closure=f.__closure__)
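#     # Keep a copy of the original rope_apply_qk before the module attribute
#     # is rebound below; the grad-enabled fallback still needs it.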
#     local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk)
#     if ENABLE_KERNEL:
#         def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
#             if torch.is_grad_enabled():
#                 return local_rope_apply_qk(q, k, grid_sizes, freqs)
#             else:
#                 return fast_rope_apply_qk(q, k, grid_sizes, freqs)
#     else:
#         def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
#             return rope_apply_real_qk(q, k, grid_sizes, freqs)
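#     # Rebind both the wan_transformer3d attribute and this module's name so
#     # all call sites resolve to the adaptive version.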
#     wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk
#     rope_apply_qk = adaptive_fast_rope_apply_qk
#     print("Import PAI Fast rope")