import torch import spaces import gradio as gr import sys import platform import diffusers import transformers import os import torchvision.transforms as T from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig from diffusers import ZImagePipeline, AutoModel from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig # ============================================================ # LOGGING BUFFER # ============================================================ LOGS = "" def log(msg): global LOGS print(msg) LOGS += msg + "\n" return msg # ============================================================ # ENVIRONMENT INFO # ============================================================ log("===================================================") log("šŸ” Z-IMAGE-TURBO DEBUGGING + ROBUST TRANSFORMER INSPECTION") log("===================================================\n") log(f"šŸ“Œ PYTHON VERSION : {sys.version.replace(chr(10), ' ')}") log(f"šŸ“Œ PLATFORM : {platform.platform()}") log(f"šŸ“Œ TORCH VERSION : {torch.**version**}") log(f"šŸ“Œ TRANSFORMERS VERSION : {transformers.**version**}") log(f"šŸ“Œ DIFFUSERS VERSION : {diffusers.**version**}") log(f"šŸ“Œ CUDA AVAILABLE : {torch.cuda.is_available()}") if torch.cuda.is_available(): log(f"šŸ“Œ GPU NAME : {torch.cuda.get_device_name(0)}") log(f"šŸ“Œ GPU CAPABILITY : {torch.cuda.get_device_capability(0)}") log(f"šŸ“Œ GPU MEMORY (TOTAL) : {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB") log(f"šŸ“Œ FLASH ATTENTION : {torch.backends.cuda.flash_sdp_enabled()}") else: raise RuntimeError("āŒ CUDA is REQUIRED but not available.") device = "cuda" gpu_id = 0 # ============================================================ # MODEL SETTINGS # ============================================================ model_cache = "./weights/" model_id = "Tongyi-MAI/Z-Image-Turbo" torch_dtype = torch.bfloat16 USE_CPU_OFFLOAD = False log("\n===================================================") log("🧠 MODEL CONFIGURATION") log("===================================================") log(f"Model ID : {model_id}") log(f"Model Cache Directory : {model_cache}") log(f"torch_dtype : {torch_dtype}") log(f"USE_CPU_OFFLOAD : {USE_CPU_OFFLOAD}") # ============================================================ # ROBUST TRANSFORMER INSPECTION FUNCTION # ============================================================ def inspect_transformer(model, model_name="Transformer"): log(f"\nšŸ” {model_name} Architecture Details:") try: block_attrs = ["transformer_blocks", "blocks", "layers", "encoder_blocks", "model"] blocks = None for attr in block_attrs: blocks = getattr(model, attr, None) if blocks is not None: break ``` if blocks is None: log(f"āš ļø Could not find transformer blocks in {model_name}, skipping detailed block info") else: try: log(f"Number of Transformer Modules : {len(blocks)}") for i, block in enumerate(blocks): log(f" Block {i}: {block.__class__.__name__}") attn_type = getattr(block, "attn", None) if attn_type: log(f" Attention: {attn_type.__class__.__name__}") flash_enabled = getattr(attn_type, "flash", None) log(f" FlashAttention Enabled? : {flash_enabled}") except Exception as e: log(f"āš ļø Error inspecting blocks: {e}") config = getattr(model, "config", None) if config: log(f"Hidden size: {getattr(config, 'hidden_size', 'N/A')}") log(f"Number of attention heads: {getattr(config, 'num_attention_heads', 'N/A')}") log(f"Number of layers: {getattr(config, 'num_hidden_layers', 'N/A')}") log(f"Intermediate size: {getattr(config, 'intermediate_size', 'N/A')}") else: log(f"āš ļø No config attribute found in {model_name}") except Exception as e: log(f"āš ļø Failed to inspect {model_name}: {e}") ``` # ============================================================ # LOAD TRANSFORMER BLOCK # ============================================================ log("\n===================================================") log("šŸ”§ LOADING TRANSFORMER BLOCK") log("===================================================") quantization_config = DiffusersBitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_use_double_quant=True, llm_int8_skip_modules=["transformer_blocks.0.img_mod"], ) log("4-bit Quantization Config (Transformer):") log(str(quantization_config)) transformer = AutoModel.from_pretrained( model_id, cache_dir=model_cache, subfolder="transformer", quantization_config=quantization_config, torch_dtype=torch_dtype, device_map=device, ) log("āœ… Transformer block loaded successfully.") inspect_transformer(transformer, "Transformer") if USE_CPU_OFFLOAD: transformer = transformer.to("cpu") # ============================================================ # LOAD TEXT ENCODER # ============================================================ log("\n===================================================") log("šŸ”§ LOADING TEXT ENCODER") log("===================================================") quantization_config = TransformersBitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_use_double_quant=True, ) log("4-bit Quantization Config (Text Encoder):") log(str(quantization_config)) text_encoder = AutoModel.from_pretrained( model_id, cache_dir=model_cache, subfolder="text_encoder", quantization_config=quantization_config, torch_dtype=torch_dtype, device_map=device, ) log("āœ… Text encoder loaded successfully.") inspect_transformer(text_encoder, "Text Encoder") if USE_CPU_OFFLOAD: text_encoder = text_encoder.to("cpu") # ============================================================ # BUILD PIPELINE # ============================================================ log("\n===================================================") log("šŸ”§ BUILDING Z-IMAGE-TURBO PIPELINE") log("===================================================") pipe = ZImagePipeline.from_pretrained( model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch_dtype, ) if USE_CPU_OFFLOAD: pipe.enable_model_cpu_offload(gpu_id=gpu_id) log("āš™ CPU OFFLOAD ENABLED") else: pipe.to(device) log("āš™ Pipeline moved to GPU") log("āœ… Pipeline ready.") # ============================================================ # FUNCTION TO CONVERT LATENTS TO IMAGE # ============================================================ def latent_to_image(latent): try: img_tensor = pipe.vae.decode(latent) img_tensor = (img_tensor / 2 + 0.5).clamp(0, 1) pil_img = T.ToPILImage()(img_tensor[0]) return pil_img except Exception as e: log(f"āš ļø Failed to decode latent: {e}") return None # ============================================================ # REAL-TIME INFERENCE FUNCTION # ============================================================ @spaces.GPU def generate_image_realtime(prompt, height, width, steps, seed): global LOGS LOGS = "" log("===================================================") log("šŸŽØ RUNNING REAL-TIME INFERENCE") log("===================================================") log(f"Prompt : {prompt}") log(f"Resolution : {width} x {height}") log(f"Steps : {steps}") log(f"Seed : {seed}") ``` generator = torch.Generator(device).manual_seed(seed) latent_history = [] # Define callback to save latents and GPU info def save_latents(step, timestep, latents): latent_history.append(latents.detach().clone()) gpu_mem = torch.cuda.memory_allocated(0)/1e9 log(f"Step {step} - GPU Memory Used: {gpu_mem:.2f} GB") # Yield images step-by-step for step, img in pipe( prompt=prompt, height=height, width=width, num_inference_steps=steps, guidance_scale=0.0, generator=generator, callback=save_latents, callback_steps=1 ).iter(): # Decode current latent for live preview current_latent = latent_history[-1] if latent_history else None latent_images = [latent_to_image(l) for l in latent_history if l is not None] yield img, latent_images, LOGS ``` # ============================================================ # GRADIO UI # ============================================================ with gr.Blocks(title="Z-Image-Turbo Generator") as demo: gr.Markdown("# **šŸš€ Z-Image-Turbo —4bit Quant + Real-Time Latent & Transformer Logs**") ``` with gr.Row(): with gr.Column(scale=1): prompt = gr.Textbox(label="Prompt", value="Realistic mid-aged male image") height = gr.Slider(256, 2048, value=1024, step=8, label="Height") width = gr.Slider(256, 2048, value=1024, step=8, label="Width") steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps") seed = gr.Slider(0, 999999, value=42, step=1, label="Seed") btn = gr.Button("Generate", variant="primary") with gr.Column(scale=1): output_image = gr.Image(label="Final Output Image") latent_gallery = gr.Gallery(label="Latent Evolution", elem_id="latent_gallery").style(grid=[2], height="auto") logs_panel = gr.Textbox(label="šŸ“œ Transformer & GPU Logs", lines=25, interactive=False) btn.click( generate_image_realtime, inputs=[prompt, height, width, steps, seed], outputs=[output_image, latent_gallery, logs_panel], ) ``` demo.launch()