import io

import torch
import spaces
import gradio as gr
import diffusers
from diffusers import DiffusionPipeline
from transformers import pipeline

# ------------------------
# GLOBAL LOG BUFFER
# ------------------------
log_buffer = io.StringIO()


def log(msg):
    """Print to stdout and mirror into the in-memory buffer shown in the UI."""
    print(msg)
    log_buffer.write(msg + "\n")


# Enable diffusers debug logs
diffusers.utils.logging.set_verbosity_info()

# ------------------------
# LOAD PIPELINES
# ------------------------
log("Loading Z-Image-Turbo pipeline...")
pipe = DiffusionPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
    attn_implementation="kernels-community/vllm-flash-attn3",
)
pipe.to("cuda")

log("Loading FP8 text encoder: Qwen/Qwen3-4B...")
fp8_encoder = pipeline("text-generation", model="Qwen/Qwen3-4B", device=0)  # device=0 → CUDA


# ------------------------
# PIPELINE DEBUG INFO
# ------------------------
def pipeline_debug_info(pipe):
    """Collect transformer and VAE config details for the debug log."""
    info = ["=== PIPELINE DEBUG INFO ==="]
    try:
        tr = pipe.transformer.config
        info.append(f"Transformer Class: {pipe.transformer.__class__.__name__}")
        info.append(f"Hidden dim: {tr.get('hidden_dim')}")
        info.append(f"Attention heads: {tr.get('num_heads')}")
        info.append(f"Depth (layers): {tr.get('depth')}")
        info.append(f"Patch size: {tr.get('patch_size')}")
        info.append(f"MLP ratio: {tr.get('mlp_ratio')}")
        info.append(f"Attention backend: {tr.get('attn_implementation')}")
    except Exception as e:
        info.append(f"Transformer diagnostics failed: {e}")

    try:
        vae = pipe.vae.config
        info.append(f"VAE latent channels: {vae.latent_channels}")
        info.append(f"VAE scaling factor: {vae.scaling_factor}")
    except Exception as e:
        info.append(f"VAE diagnostics failed: {e}")

    return "\n".join(info)


def latent_shape_info(h, w, pipe):
    """Estimate the latent tensor shape for an (h, w) output image."""
    try:
        c = pipe.vae.config.latent_channels
        # `scaling_factor` is the latent-magnitude multiplier, not a spatial
        # ratio, so it cannot be used to size the latent grid. Derive the
        # spatial downsampling factor from the encoder depth instead, per the
        # usual diffusers convention (assumes a standard AutoencoderKL config).
        down = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
        return f"Latent shape → ({c}, {h // down}, {w // down})"
    except Exception as e:
        return f"Latent shape calc failed: {e}"


# ------------------------
# IMAGE GENERATION
# ------------------------
@spaces.GPU
def generate_image(prompt, height, width, num_inference_steps, seed, randomize_seed, num_images):
    # Reset the shared log buffer so each request starts with a fresh log.
    log_buffer.truncate(0)
    log_buffer.seek(0)

    log("=== NEW GENERATION REQUEST ===")
    log(f"Prompt: {prompt}")
    log(f"Height: {height}, Width: {width}")
    log(f"Inference Steps: {num_inference_steps}")
    log(f"Num Images: {num_images}")

    if randomize_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
        log(f"Randomized Seed → {seed}")
    else:
        log(f"Seed: {seed}")

    # Clamp images to 1–3
    num_images = min(max(1, int(num_images)), 3)

    # Run the FP8 text encoder first. Its output is only logged for
    # diagnostics; the diffusion pipeline below encodes the raw prompt itself.
    log("Encoding prompt with FP8 text encoder...")
    encoded_prompt = fp8_encoder([{"role": "user", "content": prompt}])
    log(f"FP8 encoding output: {encoded_prompt}")

    # Debug pipeline info
    log(pipeline_debug_info(pipe))

    generator = torch.Generator("cuda").manual_seed(int(seed))

    log("Running Z-Image-Turbo pipeline forward()...")
    result = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=0.0,
        generator=generator,
        max_sequence_length=1024,
        num_images_per_prompt=num_images,
        output_type="pil",
    )

    # Latent diagnostics
    try:
        log(f"VAE latent channels: {pipe.vae.config.latent_channels}")
        log(f"VAE scaling factor: {pipe.vae.config.scaling_factor}")
        log(latent_shape_info(height, width, pipe))
    except Exception as e:
        log(f"Latent diagnostics error: {e}")

    log("Pipeline finished.")
    log("Returning images...")
    return result.images, seed, log_buffer.getvalue()
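
# ------------------------
# OPTIONAL SMOKE TEST
# ------------------------
# A minimal sketch for exercising generate_image() without the UI. It is not
# part of the original demo: the helper, prompt, and output filename are
# illustrative, and it assumes a CUDA device (or a ZeroGPU Space) with both
# models already downloaded. Nothing calls it here; invoke it manually.
def _smoke_test():
    images, used_seed, logs = generate_image(
        prompt="A lighthouse on a cliff at dawn",
        height=1024,
        width=1024,
        num_inference_steps=9,
        seed=42,
        randomize_seed=False,
        num_images=1,
    )
    images[0].save("smoke_test.png")  # hypothetical output path
    print(f"Seed used: {used_seed}\n{logs}")
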
# ------------------------
# GRADIO UI
# ------------------------
examples = [
    ["Young Chinese woman in red Hanfu, intricate embroidery..."],
    ["A majestic dragon soaring through clouds at sunset..."],
    ["Cozy coffee shop interior, warm lighting, rain on windows..."],
    ["Astronaut riding a horse on Mars, cinematic lighting..."],
    ["Portrait of a wise old wizard..."],
]

with gr.Blocks(title="Z-Image-Turbo Multi Image Demo") as demo:
    gr.Markdown("# 🎨 Z-Image-Turbo — Multi Image (FP8 Text Encoder)")

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", lines=4)

            with gr.Row():
                height = gr.Slider(512, 2048, 1024, step=64, label="Height")
                width = gr.Slider(512, 2048, 1024, step=64, label="Width")

            num_images = gr.Slider(1, 3, 2, step=1, label="Number of Images")

            num_inference_steps = gr.Slider(
                1, 20, 9, step=1,
                label="Inference Steps",
                info="9 steps = 8 DiT forward passes",
            )

            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)

            generate_btn = gr.Button("🚀 Generate", variant="primary")

        with gr.Column(scale=1):
            output_images = gr.Gallery(label="Generated Images", type="pil")
            used_seed = gr.Number(label="Seed Used", interactive=False)
            debug_log = gr.Textbox(label="Debug Log Output", lines=25, interactive=False)

    gr.Examples(examples=examples, inputs=[prompt], cache_examples=False)

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, height, width, num_inference_steps, seed, randomize_seed, num_images],
        outputs=[output_images, used_seed, debug_log],
    )

if __name__ == "__main__":
    demo.launch()
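
# If the Space sees concurrent users, Gradio's built-in request queue
# serializes GPU work instead of running generations in parallel. A possible
# variant of the launch block above (an alternative sketch, not the original
# configuration):
#
#   if __name__ == "__main__":
#       demo.queue(max_size=20).launch()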