Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import spaces | |
| import gradio as gr | |
| import sys | |
| import platform | |
| import diffusers | |
| import transformers | |
| import os | |
| from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig | |
| from diffusers import ZImagePipeline, AutoModel | |
| from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig | |
| # ============================================================ | |
| # LOGGING BUFFER | |
| # ============================================================ | |
| LOGS = "" | |
| def log(msg): | |
| global LOGS | |
| print(msg) | |
| LOGS += msg + "\n" | |
| return msg | |
| # ============================================================ | |
| # ENVIRONMENT INFO | |
| # ============================================================ | |
| log("===================================================") | |
| log("π Z-IMAGE-TURBO DEBUGGING + ROBUST TRANSFORMER INSPECTION") | |
| log("===================================================\n") | |
| log(f"π PYTHON VERSION : {sys.version.replace(chr(10), ' ')}") | |
| log(f"π PLATFORM : {platform.platform()}") | |
| log(f"π TORCH VERSION : {torch.__version__}") | |
| log(f"π TRANSFORMERS VERSION : {transformers.__version__}") | |
| log(f"π DIFFUSERS VERSION : {diffusers.__version__}") | |
| log(f"π CUDA AVAILABLE : {torch.cuda.is_available()}") | |
| if torch.cuda.is_available(): | |
| log(f"π GPU NAME : {torch.cuda.get_device_name(0)}") | |
| log(f"π GPU CAPABILITY : {torch.cuda.get_device_capability(0)}") | |
| log(f"π GPU MEMORY (TOTAL) : {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB") | |
| log(f"π FLASH ATTENTION : {torch.backends.cuda.flash_sdp_enabled()}") | |
| else: | |
| raise RuntimeError("β CUDA is REQUIRED but not available.") | |
| device = "cuda" | |
| gpu_id = 0 | |
| # ============================================================ | |
| # MODEL SETTINGS | |
| # ============================================================ | |
| model_cache = "./weights/" | |
| model_id = "Tongyi-MAI/Z-Image-Turbo" | |
| torch_dtype = torch.bfloat16 | |
| USE_CPU_OFFLOAD = False | |
| log("\n===================================================") | |
| log("π§ MODEL CONFIGURATION") | |
| log("===================================================") | |
| log(f"Model ID : {model_id}") | |
| log(f"Model Cache Directory : {model_cache}") | |
| log(f"torch_dtype : {torch_dtype}") | |
| log(f"USE_CPU_OFFLOAD : {USE_CPU_OFFLOAD}") | |
| # ============================================================ | |
| # ROBUST TRANSFORMER INSPECTION FUNCTION | |
| # ============================================================ | |
| def inspect_transformer(model, model_name="Transformer"): | |
| log(f"\nπ {model_name} Architecture Details:") | |
| try: | |
| block_attrs = ["transformer_blocks", "blocks", "layers", "encoder_blocks", "model"] | |
| blocks = None | |
| for attr in block_attrs: | |
| blocks = getattr(model, attr, None) | |
| if blocks is not None: | |
| break | |
| if blocks is None: | |
| log(f"β οΈ Could not find transformer blocks in {model_name}, skipping detailed block info") | |
| else: | |
| try: | |
| log(f"Number of Transformer Modules : {len(blocks)}") | |
| for i, block in enumerate(blocks): | |
| log(f" Block {i}: {block.__class__.__name__}") | |
| attn_type = getattr(block, "attn", None) | |
| if attn_type: | |
| log(f" Attention: {attn_type.__class__.__name__}") | |
| flash_enabled = getattr(attn_type, "flash", None) | |
| log(f" FlashAttention Enabled? : {flash_enabled}") | |
| except Exception as e: | |
| log(f"β οΈ Error inspecting blocks: {e}") | |
| config = getattr(model, "config", None) | |
| if config: | |
| log(f"Hidden size: {getattr(config, 'hidden_size', 'N/A')}") | |
| log(f"Number of attention heads: {getattr(config, 'num_attention_heads', 'N/A')}") | |
| log(f"Number of layers: {getattr(config, 'num_hidden_layers', 'N/A')}") | |
| log(f"Intermediate size: {getattr(config, 'intermediate_size', 'N/A')}") | |
| else: | |
| log(f"β οΈ No config attribute found in {model_name}") | |
| except Exception as e: | |
| log(f"β οΈ Failed to inspect {model_name}: {e}") | |
| # ============================================================ | |
| # LOAD TRANSFORMER BLOCK | |
| # ============================================================ | |
| log("\n===================================================") | |
| log("π§ LOADING TRANSFORMER BLOCK") | |
| log("===================================================") | |
| quantization_config = DiffusersBitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_quant_type="nf4", | |
| bnb_4bit_compute_dtype=torch_dtype, | |
| bnb_4bit_use_double_quant=True, | |
| llm_int8_skip_modules=["transformer_blocks.0.img_mod"], | |
| ) | |
| log("4-bit Quantization Config (Transformer):") | |
| log(str(quantization_config)) | |
| transformer = AutoModel.from_pretrained( | |
| model_id, | |
| cache_dir=model_cache, | |
| subfolder="transformer", | |
| quantization_config=quantization_config, | |
| torch_dtype=torch_dtype, | |
| device_map=device, | |
| ) | |
| log("β Transformer block loaded successfully.") | |
| inspect_transformer(transformer, "Transformer") | |
| if USE_CPU_OFFLOAD: | |
| transformer = transformer.to("cpu") | |
| # ============================================================ | |
| # LOAD TEXT ENCODER | |
| # ============================================================ | |
| log("\n===================================================") | |
| log("π§ LOADING TEXT ENCODER") | |
| log("===================================================") | |
| quantization_config = TransformersBitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_quant_type="nf4", | |
| bnb_4bit_compute_dtype=torch_dtype, | |
| bnb_4bit_use_double_quant=True, | |
| ) | |
| log("4-bit Quantization Config (Text Encoder):") | |
| log(str(quantization_config)) | |
| text_encoder = AutoModel.from_pretrained( | |
| model_id, | |
| cache_dir=model_cache, | |
| subfolder="text_encoder", | |
| quantization_config=quantization_config, | |
| torch_dtype=torch_dtype, | |
| device_map=device, | |
| ) | |
| log("β Text encoder loaded successfully.") | |
| inspect_transformer(text_encoder, "Text Encoder") | |
| if USE_CPU_OFFLOAD: | |
| text_encoder = text_encoder.to("cpu") | |
| # ============================================================ | |
| # BUILD PIPELINE | |
| # ============================================================ | |
| log("\n===================================================") | |
| log("π§ BUILDING Z-IMAGE-TURBO PIPELINE") | |
| log("===================================================") | |
| pipe = ZImagePipeline.from_pretrained( | |
| model_id, | |
| transformer=transformer, | |
| text_encoder=text_encoder, | |
| torch_dtype=torch_dtype, | |
| ) | |
| if USE_CPU_OFFLOAD: | |
| pipe.enable_model_cpu_offload(gpu_id=gpu_id) | |
| log("β CPU OFFLOAD ENABLED") | |
| else: | |
| pipe.to(device) | |
| log("β Pipeline moved to GPU") | |
| log("β Pipeline ready.") | |
| # ============================================================ | |
| # INFERENCE FUNCTION | |
| # ============================================================ | |
| def generate_image(prompt, height, width, steps, seed): | |
| global LOGS | |
| LOGS = "" # reset logs | |
| log("===================================================") | |
| log("π¨ RUNNING INFERENCE") | |
| log("===================================================") | |
| log(f"Prompt : {prompt}") | |
| log(f"Resolution : {width} x {height}") | |
| log(f"Steps : {steps}") | |
| log(f"Seed : {seed}") | |
| generator = torch.Generator(device).manual_seed(seed) | |
| out = pipe( | |
| prompt=prompt, | |
| height=height, | |
| width=width, | |
| num_inference_steps=steps, | |
| guidance_scale=0.0, | |
| generator=generator, | |
| ) | |
| log("β Inference Finished") | |
| return out.images[0], LOGS | |
| # ============================================================ | |
| # GRADIO UI | |
| # ============================================================ | |
| with gr.Blocks(title="Z-Image-Turbo Generator") as demo: | |
| gr.Markdown("# **Z-Image-Turbo β 4bit Quantized Image Generator**") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| prompt = gr.Textbox(label="Prompt", value="Realistic mid-aged male image") | |
| height = gr.Slider(256, 2048, value=1024, step=8, label="Height") | |
| width = gr.Slider(256, 2048, value=1024, step=8, label="Width") | |
| steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps") | |
| seed = gr.Slider(0, 999999, value=42, step=1, label="Seed") | |
| btn = gr.Button("Generate", variant="primary") | |
| with gr.Column(scale=1): | |
| output_image = gr.Image(label="Output Image") | |
| logs_panel = gr.Textbox(label="π Transformer Logs", lines=25, interactive=False) | |
| btn.click( | |
| generate_image, | |
| inputs=[prompt, height, width, steps, seed], | |
| outputs=[output_image, logs_panel], | |
| ) | |
| demo.launch() | |