import sys
import platform

import torch
import spaces
import gradio as gr
import diffusers
import transformers
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# ============================================================
# LOGGING BUFFER
# ============================================================
LOGS = ""


def log(msg):
    """Print a message and append it to the global log buffer shown in the UI."""
    global LOGS
    print(msg)
    LOGS += msg + "\n"
    return msg


# ============================================================
# ENVIRONMENT INFO
# ============================================================
log("===================================================")
log("šŸ” Z-IMAGE-TURBO DEBUGGING + ROBUST TRANSFORMER INSPECTION")
log("===================================================\n")

log(f"šŸ“Œ PYTHON VERSION       : {sys.version.replace(chr(10), ' ')}")
log(f"šŸ“Œ PLATFORM             : {platform.platform()}")
log(f"šŸ“Œ TORCH VERSION        : {torch.__version__}")
log(f"šŸ“Œ TRANSFORMERS VERSION : {transformers.__version__}")
log(f"šŸ“Œ DIFFUSERS VERSION    : {diffusers.__version__}")
log(f"šŸ“Œ CUDA AVAILABLE       : {torch.cuda.is_available()}")

if torch.cuda.is_available():
    log(f"šŸ“Œ GPU NAME             : {torch.cuda.get_device_name(0)}")
    log(f"šŸ“Œ GPU CAPABILITY       : {torch.cuda.get_device_capability(0)}")
    log(f"šŸ“Œ GPU MEMORY (TOTAL)   : {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    log(f"šŸ“Œ FLASH ATTENTION      : {torch.backends.cuda.flash_sdp_enabled()}")
else:
    raise RuntimeError("āŒ CUDA is REQUIRED but not available.")

device = "cuda"
gpu_id = 0

# ============================================================
# MODEL SETTINGS
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
USE_CPU_OFFLOAD = False

log("\n===================================================")
log("🧠 MODEL CONFIGURATION")
log("===================================================")
log(f"Model ID              : {model_id}")
log(f"Model Cache Directory : {model_cache}")
log(f"torch_dtype           : {torch_dtype}")
log(f"USE_CPU_OFFLOAD       : {USE_CPU_OFFLOAD}")


# ============================================================
# ROBUST TRANSFORMER INSPECTION FUNCTION
# ============================================================
def inspect_transformer(model, model_name="Transformer"):
    """Log architecture details without assuming a particular block layout."""
    log(f"\nšŸ” {model_name} Architecture Details:")
    try:
        # Different model families expose their blocks under different attribute names.
        block_attrs = ["transformer_blocks", "blocks", "layers", "encoder_blocks", "model"]
        blocks = None
        for attr in block_attrs:
            blocks = getattr(model, attr, None)
            if blocks is not None:
                break

        if blocks is None:
            log(f"āš ļø Could not find transformer blocks in {model_name}, skipping detailed block info")
        else:
            try:
                log(f"Number of Transformer Modules : {len(blocks)}")
                for i, block in enumerate(blocks):
                    log(f"  Block {i}: {block.__class__.__name__}")
                    attn_module = getattr(block, "attn", None)
                    if attn_module is not None:
                        log(f"    Attention: {attn_module.__class__.__name__}")
                        flash_enabled = getattr(attn_module, "flash", None)
                        log(f"    FlashAttention Enabled? : {flash_enabled}")
            except Exception as e:
                log(f"āš ļø Error inspecting blocks: {e}")

        config = getattr(model, "config", None)
        if config is not None:
            log(f"Hidden size: {getattr(config, 'hidden_size', 'N/A')}")
            log(f"Number of attention heads: {getattr(config, 'num_attention_heads', 'N/A')}")
            log(f"Number of layers: {getattr(config, 'num_hidden_layers', 'N/A')}")
            log(f"Intermediate size: {getattr(config, 'intermediate_size', 'N/A')}")
        else:
            log(f"āš ļø No config attribute found in {model_name}")
    except Exception as e:
        log(f"āš ļø Failed to inspect {model_name}: {e}")


# ============================================================
# LOAD TRANSFORMER BLOCK
# ============================================================
log("\n===================================================")
log("šŸ”§ LOADING TRANSFORMER BLOCK")
log("===================================================")

quantization_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
    # Modules kept in higher precision; despite the "int8" name, this option
    # also applies to 4-bit loading.
    llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
log("4-bit Quantization Config (Transformer):")
log(str(quantization_config))

transformer = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
log("āœ… Transformer block loaded successfully.")
inspect_transformer(transformer, "Transformer")

if USE_CPU_OFFLOAD:
    transformer = transformer.to("cpu")

# ============================================================
# LOAD TEXT ENCODER
# ============================================================
log("\n===================================================")
log("šŸ”§ LOADING TEXT ENCODER")
log("===================================================")

quantization_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)
log("4-bit Quantization Config (Text Encoder):")
log(str(quantization_config))

text_encoder = AutoModel.from_pretrained(
    model_id,
    cache_dir=model_cache,
    subfolder="text_encoder",
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=device,
)
log("āœ… Text encoder loaded successfully.")
inspect_transformer(text_encoder, "Text Encoder")

if USE_CPU_OFFLOAD:
    text_encoder = text_encoder.to("cpu")

# ============================================================
# BUILD PIPELINE
# ============================================================
log("\n===================================================")
log("šŸ”§ BUILDING Z-IMAGE-TURBO PIPELINE")
log("===================================================")

pipe = ZImagePipeline.from_pretrained(
    model_id,
    cache_dir=model_cache,  # keep the remaining components in the same local cache
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=torch_dtype,
)

if USE_CPU_OFFLOAD:
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)
    log("āš™ CPU OFFLOAD ENABLED")
else:
    pipe.to(device)
    log("āš™ Pipeline moved to GPU")

log("āœ… Pipeline ready.")


# ============================================================
# INFERENCE FUNCTION
# ============================================================
@spaces.GPU
def generate_image(prompt, height, width, steps, seed):
    global LOGS
    LOGS = ""  # reset logs for this run

    log("===================================================")
    log("šŸŽØ RUNNING INFERENCE")
    log("===================================================")
    log(f"Prompt     : {prompt}")
    log(f"Resolution : {width} x {height}")
    log(f"Steps      : {steps}")
    log(f"Seed       : {seed}")

    # Gradio sliders may deliver floats, so cast before use;
    # manual_seed() in particular requires an int.
    generator = torch.Generator(device).manual_seed(int(seed))

    out = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=0.0,  # Turbo is distilled to run without classifier-free guidance
        generator=generator,
    )
    log("āœ… Inference Finished")
    return out.images[0], LOGS


# ============================================================
# GRADIO UI
# ============================================================
with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
    gr.Markdown("# **Z-Image-Turbo — 4-bit Quantized Image Generator**")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", value="Realistic portrait of a middle-aged man")
            height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
            steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
            seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Output Image")
            logs_panel = gr.Textbox(label="šŸ“œ Transformer Logs", lines=25, interactive=False)

    btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[output_image, logs_panel],
    )

demo.launch()
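
# ------------------------------------------------------------
# Usage note: a minimal sketch of calling the generator without
# the UI, e.g. for a quick smoke test. The prompt and "fox.png"
# filename are illustrative, not part of the original app.
#
#   image, logs = generate_image(
#       prompt="A watercolor fox in a snowy forest",
#       height=1024, width=1024, steps=9, seed=42,
#   )
#   image.save("fox.png")
#
# Run such a call before demo.launch() (launch() blocks), or
# from a separate script that imports this module.
# ------------------------------------------------------------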