# Image2Video / app_quant.py
import torch
import spaces
import gradio as gr
import sys
import platform
import diffusers
import transformers
import os
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
# ============================================================
# LOGGING BUFFER
# ============================================================
LOGS = ""
def log(msg):
global LOGS
print(msg)
LOGS += msg + "\n"
return msg
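# Note: LOGS grows without bound for the lifetime of the process. A minimal
# capped variant (a sketch; MAX_LOG_CHARS is a hypothetical budget):
#
#   MAX_LOG_CHARS = 64_000
#   def log(msg):
#       global LOGS
#       print(msg)
#       LOGS = (LOGS + msg + "\n")[-MAX_LOG_CHARS:]
#       return msg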
# ============================================================
# ENVIRONMENT INFO
# ============================================================
log("===================================================")
log("πŸ” Z-IMAGE-TURBO DEBUGGING + ROBUST TRANSFORMER INSPECTION")
log("===================================================\n")
log(f"πŸ“Œ PYTHON VERSION : {sys.version.replace(chr(10), ' ')}")
log(f"πŸ“Œ PLATFORM : {platform.platform()}")
log(f"πŸ“Œ TORCH VERSION : {torch.__version__}")
log(f"πŸ“Œ TRANSFORMERS VERSION : {transformers.__version__}")
log(f"πŸ“Œ DIFFUSERS VERSION : {diffusers.__version__}")
log(f"πŸ“Œ CUDA AVAILABLE : {torch.cuda.is_available()}")
if torch.cuda.is_available():
log(f"πŸ“Œ GPU NAME : {torch.cuda.get_device_name(0)}")
log(f"πŸ“Œ GPU CAPABILITY : {torch.cuda.get_device_capability(0)}")
log(f"πŸ“Œ GPU MEMORY (TOTAL) : {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB")
log(f"πŸ“Œ FLASH ATTENTION : {torch.backends.cuda.flash_sdp_enabled()}")
else:
raise RuntimeError("❌ CUDA is REQUIRED but not available.")
device = "cuda"
gpu_id = 0
# ============================================================
# MODEL SETTINGS
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
USE_CPU_OFFLOAD = False
log("\n===================================================")
log("🧠 MODEL CONFIGURATION")
log("===================================================")
log(f"Model ID : {model_id}")
log(f"Model Cache Directory : {model_cache}")
log(f"torch_dtype : {torch_dtype}")
log(f"USE_CPU_OFFLOAD : {USE_CPU_OFFLOAD}")
# ============================================================
# ROBUST TRANSFORMER INSPECTION FUNCTION
# ============================================================
def inspect_transformer(model, model_name="Transformer"):
    log(f"\n🔍 {model_name} Architecture Details:")
    try:
        # Different model classes expose their block list under different names.
        block_attrs = ["transformer_blocks", "blocks", "layers", "encoder_blocks", "model"]
        blocks = None
        for attr in block_attrs:
            blocks = getattr(model, attr, None)
            if blocks is not None:
                break
        if blocks is None:
            log(f"⚠️ Could not find transformer blocks in {model_name}, skipping detailed block info")
        else:
            try:
                log(f"Number of Transformer Modules : {len(blocks)}")
                for i, block in enumerate(blocks):
                    log(f"  Block {i}: {block.__class__.__name__}")
                    attn = getattr(block, "attn", None)
                    if attn is not None:
                        log(f"    Attention: {attn.__class__.__name__}")
                        flash_enabled = getattr(attn, "flash", None)
                        log(f"    FlashAttention Enabled? : {flash_enabled}")
            except Exception as e:
                log(f"⚠️ Error inspecting blocks: {e}")
        config = getattr(model, "config", None)
        if config is not None:
            log(f"Hidden size: {getattr(config, 'hidden_size', 'N/A')}")
            log(f"Number of attention heads: {getattr(config, 'num_attention_heads', 'N/A')}")
            log(f"Number of layers: {getattr(config, 'num_hidden_layers', 'N/A')}")
            log(f"Intermediate size: {getattr(config, 'intermediate_size', 'N/A')}")
        else:
            log(f"⚠️ No config attribute found in {model_name}")
    except Exception as e:
        log(f"⚠️ Failed to inspect {model_name}: {e}")
# ============================================================
# LOAD TRANSFORMER BLOCK
# ============================================================
log("\n===================================================")
log("πŸ”§ LOADING TRANSFORMER BLOCK")
log("===================================================")
quantization_config = DiffusersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
log("4-bit Quantization Config (Transformer):")
log(str(quantization_config))
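# nf4 stores weights as 4-bit NormalFloat; double quantization also quantizes
# the quantization constants, and compute runs in bfloat16. The skip list
# keeps transformer_blocks.0.img_mod unquantized, presumably because that
# modulation layer is precision-sensitive (an assumption; the source does not
# say why it is excluded).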
transformer = AutoModel.from_pretrained(
model_id,
cache_dir=model_cache,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch_dtype,
device_map=device,
)
log("βœ… Transformer block loaded successfully.")
inspect_transformer(transformer, "Transformer")
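# Rough VRAM footprint after loading the 4-bit transformer (torch built-in;
# exact numbers vary by GPU, driver, and allocator state):
log(f"📌 VRAM allocated after transformer load : {torch.cuda.memory_allocated(gpu_id) / 1e9:.2f} GB")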
if USE_CPU_OFFLOAD:
    # Note: moving a bitsandbytes-quantized model with .to() may be rejected
    # by some transformers/diffusers versions; enable_model_cpu_offload below
    # is the path that actually manages device placement.
    transformer = transformer.to("cpu")
# ============================================================
# LOAD TEXT ENCODER
# ============================================================
log("\n===================================================")
log("πŸ”§ LOADING TEXT ENCODER")
log("===================================================")
quantization_config = TransformersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
)
log("4-bit Quantization Config (Text Encoder):")
log(str(quantization_config))
text_encoder = AutoModel.from_pretrained(
model_id,
cache_dir=model_cache,
subfolder="text_encoder",
quantization_config=quantization_config,
torch_dtype=torch_dtype,
device_map=device,
)
log("βœ… Text encoder loaded successfully.")
inspect_transformer(text_encoder, "Text Encoder")
if USE_CPU_OFFLOAD:
text_encoder = text_encoder.to("cpu")
# ============================================================
# BUILD PIPELINE
# ============================================================
log("\n===================================================")
log("πŸ”§ BUILDING Z-IMAGE-TURBO PIPELINE")
log("===================================================")
pipe = ZImagePipeline.from_pretrained(
model_id,
transformer=transformer,
text_encoder=text_encoder,
torch_dtype=torch_dtype,
)
if USE_CPU_OFFLOAD:
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)
    log("⚙ CPU OFFLOAD ENABLED")
else:
    pipe.to(device)
    log("⚙ Pipeline moved to GPU")
log("✅ Pipeline ready.")
# ============================================================
# INFERENCE FUNCTION
# ============================================================
@spaces.GPU
def generate_image(prompt, height, width, steps, seed):
    global LOGS
    LOGS = ""  # reset so the panel shows only this run (startup logs are discarded)
log("===================================================")
log("🎨 RUNNING INFERENCE")
log("===================================================")
log(f"Prompt : {prompt}")
log(f"Resolution : {width} x {height}")
log(f"Steps : {steps}")
log(f"Seed : {seed}")
    # Gradio sliders can deliver floats; the pipeline and manual_seed expect ints.
    height, width, steps, seed = int(height), int(width), int(steps), int(seed)
    generator = torch.Generator(device).manual_seed(seed)
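    # guidance_scale=0.0 disables classifier-free guidance, the usual setting
    # for distilled "Turbo" checkpoints that sample in a few steps.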
out = pipe(
prompt=prompt,
height=height,
width=width,
num_inference_steps=steps,
guidance_scale=0.0,
generator=generator,
)
log("βœ… Inference Finished")
return out.images[0], LOGS
# ============================================================
# GRADIO UI
# ============================================================
with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
    gr.Markdown("# **Z-Image-Turbo: 4-bit Quantized Image Generator**")
with gr.Row():
with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", value="Realistic middle-aged male portrait")
height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
btn = gr.Button("Generate", variant="primary")
with gr.Column(scale=1):
output_image = gr.Image(label="Output Image")
            logs_panel = gr.Textbox(label="📜 Transformer Logs", lines=25, interactive=False)
btn.click(
generate_image,
inputs=[prompt, height, width, steps, seed],
outputs=[output_image, logs_panel],
)
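# On Spaces, enabling Gradio's request queue helps avoid HTTP timeouts on
# long generations (a suggestion, not part of the original app):
#
#   demo.queue(max_size=20)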
demo.launch()