# Image2Video / app_quant_latent.py
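"""Z-Image-Turbo demo Space.

Loads the Tongyi-MAI/Z-Image-Turbo transformer and text encoder with 4-bit
(NF4) quantization, logs environment and architecture details, and serves a
Gradio UI that shows the final image alongside the decoded latent from every
denoising step.
"""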
import torch
import spaces
import gradio as gr
import sys
import platform
import diffusers
import transformers
import os
import torchvision.transforms as T
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
# ============================================================
# LOGGING BUFFER
# ============================================================
LOGS = ""
def log(msg):
    global LOGS
    print(msg)
    LOGS += msg + "\n"
    return msg
# ============================================================
# ENVIRONMENT INFO
# ============================================================
log("===================================================")
log("πŸ” Z-IMAGE-TURBO DEBUGGING + ROBUST TRANSFORMER INSPECTION")
log("===================================================\n")
log(f"πŸ“Œ PYTHON VERSION : {sys.version.replace(chr(10), ' ')}")
log(f"πŸ“Œ PLATFORM : {platform.platform()}")
log(f"πŸ“Œ TORCH VERSION : {torch.**version**}")
log(f"πŸ“Œ TRANSFORMERS VERSION : {transformers.**version**}")
log(f"πŸ“Œ DIFFUSERS VERSION : {diffusers.**version**}")
log(f"πŸ“Œ CUDA AVAILABLE : {torch.cuda.is_available()}")
if torch.cuda.is_available():
log(f"πŸ“Œ GPU NAME : {torch.cuda.get_device_name(0)}")
log(f"πŸ“Œ GPU CAPABILITY : {torch.cuda.get_device_capability(0)}")
log(f"πŸ“Œ GPU MEMORY (TOTAL) : {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB")
log(f"πŸ“Œ FLASH ATTENTION : {torch.backends.cuda.flash_sdp_enabled()}")
else:
raise RuntimeError("❌ CUDA is REQUIRED but not available.")
device = "cuda"
gpu_id = 0
# ============================================================
# MODEL SETTINGS
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
USE_CPU_OFFLOAD = False
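# USE_CPU_OFFLOAD switches the pipeline to diffusers' model CPU offload
# (enable_model_cpu_offload) instead of keeping every component resident on
# the GPU; leave it False when the quantized model fits in VRAM.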
log("\n===================================================")
log("🧠 MODEL CONFIGURATION")
log("===================================================")
log(f"Model ID : {model_id}")
log(f"Model Cache Directory : {model_cache}")
log(f"torch_dtype : {torch_dtype}")
log(f"USE_CPU_OFFLOAD : {USE_CPU_OFFLOAD}")
# ============================================================
# ROBUST TRANSFORMER INSPECTION FUNCTION
# ============================================================
def inspect_transformer(model, model_name="Transformer"):
    log(f"\n🔍 {model_name} Architecture Details:")
    try:
        block_attrs = ["transformer_blocks", "blocks", "layers", "encoder_blocks", "model"]
        blocks = None
        for attr in block_attrs:
            blocks = getattr(model, attr, None)
            if blocks is not None:
                break

        if blocks is None:
            log(f"⚠️ Could not find transformer blocks in {model_name}, skipping detailed block info")
        else:
            try:
                log(f"Number of Transformer Modules : {len(blocks)}")
                for i, block in enumerate(blocks):
                    log(f"  Block {i}: {block.__class__.__name__}")
                    attn_type = getattr(block, "attn", None)
                    if attn_type:
                        log(f"    Attention: {attn_type.__class__.__name__}")
                        flash_enabled = getattr(attn_type, "flash", None)
                        log(f"    FlashAttention Enabled? : {flash_enabled}")
            except Exception as e:
                log(f"⚠️ Error inspecting blocks: {e}")

        config = getattr(model, "config", None)
        if config:
            log(f"Hidden size: {getattr(config, 'hidden_size', 'N/A')}")
            log(f"Number of attention heads: {getattr(config, 'num_attention_heads', 'N/A')}")
            log(f"Number of layers: {getattr(config, 'num_hidden_layers', 'N/A')}")
            log(f"Intermediate size: {getattr(config, 'intermediate_size', 'N/A')}")
        else:
            log(f"⚠️ No config attribute found in {model_name}")
    except Exception as e:
        log(f"⚠️ Failed to inspect {model_name}: {e}")
# ============================================================
# LOAD TRANSFORMER BLOCK
# ============================================================
log("\n===================================================")
log("πŸ”§ LOADING TRANSFORMER BLOCK")
log("===================================================")
quantization_config = DiffusersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
)
log("4-bit Quantization Config (Transformer):")
log(str(quantization_config))
transformer = AutoModel.from_pretrained(
model_id,
cache_dir=model_cache,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch_dtype,
device_map=device,
)
log("βœ… Transformer block loaded successfully.")
inspect_transformer(transformer, "Transformer")
if USE_CPU_OFFLOAD:
    # Caution: .to("cpu") can raise for 4-bit bitsandbytes models; the
    # supported path is enable_model_cpu_offload() at pipeline build below.
    transformer = transformer.to("cpu")
# ============================================================
# LOAD TEXT ENCODER
# ============================================================
log("\n===================================================")
log("πŸ”§ LOADING TEXT ENCODER")
log("===================================================")
quantization_config = TransformersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
)
log("4-bit Quantization Config (Text Encoder):")
log(str(quantization_config))
text_encoder = AutoModel.from_pretrained(
model_id,
cache_dir=model_cache,
subfolder="text_encoder",
quantization_config=quantization_config,
torch_dtype=torch_dtype,
device_map=device,
)
log("βœ… Text encoder loaded successfully.")
inspect_transformer(text_encoder, "Text Encoder")
if USE_CPU_OFFLOAD:
    # Same caveat as for the transformer: prefer pipeline-level CPU offload.
    text_encoder = text_encoder.to("cpu")
# ============================================================
# BUILD PIPELINE
# ============================================================
log("\n===================================================")
log("πŸ”§ BUILDING Z-IMAGE-TURBO PIPELINE")
log("===================================================")
pipe = ZImagePipeline.from_pretrained(
model_id,
transformer=transformer,
text_encoder=text_encoder,
torch_dtype=torch_dtype,
)
if USE_CPU_OFFLOAD:
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)
    log("⚙ CPU OFFLOAD ENABLED")
else:
    pipe.to(device)
    log("⚙ Pipeline moved to GPU")
log("✅ Pipeline ready.")
# ============================================================
# FUNCTION TO CONVERT LATENTS TO IMAGE
# ============================================================
def latent_to_image(latent):
    try:
        # Un-scale the latent if the VAE defines a scaling factor.
        scaling = getattr(pipe.vae.config, "scaling_factor", None)
        if scaling:
            latent = latent / scaling
        # decode() returns a DecoderOutput; the image tensor is in .sample.
        img_tensor = pipe.vae.decode(latent).sample
        img_tensor = (img_tensor / 2 + 0.5).clamp(0, 1)
        # ToPILImage needs a float32 CPU tensor.
        pil_img = T.ToPILImage()(img_tensor[0].float().cpu())
        return pil_img
    except Exception as e:
        log(f"⚠️ Failed to decode latent: {e}")
        return None
# ============================================================
# REAL-TIME INFERENCE FUNCTION
# ============================================================
@spaces.GPU
def generate_image_realtime(prompt, height, width, steps, seed):
    global LOGS
    LOGS = ""
    log("===================================================")
    log("🎨 RUNNING REAL-TIME INFERENCE")
    log("===================================================")
    log(f"Prompt     : {prompt}")
    log(f"Resolution : {width} x {height}")
    log(f"Steps      : {steps}")
    log(f"Seed       : {seed}")

    generator = torch.Generator(device).manual_seed(int(seed))
    latent_history = []

    # Save the latents and GPU usage after each denoising step. This uses the
    # standard diffusers callback_on_step_end hook (the old callback/
    # callback_steps arguments are deprecated); it is assumed here that
    # ZImagePipeline supports it.
    def save_latents(pipeline, step, timestep, callback_kwargs):
        latent_history.append(callback_kwargs["latents"].detach().clone())
        gpu_mem = torch.cuda.memory_allocated(0) / 1e9
        log(f"Step {step} - GPU Memory Used: {gpu_mem:.2f} GB")
        return callback_kwargs

    # The pipeline call blocks until sampling finishes, so the per-step
    # previews are decoded from the saved latents afterwards.
    image = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=0.0,
        generator=generator,
        callback_on_step_end=save_latents,
    ).images[0]

    latent_images = [img for img in (latent_to_image(l) for l in latent_history) if img is not None]
    return image, latent_images, LOGS
# ============================================================
# GRADIO UI
# ============================================================
with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
    gr.Markdown("# **🚀 Z-Image-Turbo: 4-bit Quant + Real-Time Latent & Transformer Logs**")

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", value="Realistic middle-aged male portrait")
            height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
            steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
            seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Final Output Image")
            # Gradio 4 removed Gallery.style(); grid layout is set via
            # constructor kwargs instead.
            latent_gallery = gr.Gallery(label="Latent Evolution", elem_id="latent_gallery", columns=2, height="auto")
            logs_panel = gr.Textbox(label="📜 Transformer & GPU Logs", lines=25, interactive=False)

    btn.click(
        generate_image_realtime,
        inputs=[prompt, height, width, steps, seed],
        outputs=[output_image, latent_gallery, logs_panel],
    )

demo.launch()