Spaces: Restarting on Zero
import torch
import spaces
import gradio as gr
import sys
import platform
import diffusers
import transformers
import psutil
import os
import time
import torchvision.transforms as T  # needed by latent_to_image below
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import ZImagePipeline, AutoModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

latent_history = []
# ============================================================
# LOGGING BUFFER
# ============================================================
LOGS = ""

def log(msg):
    global LOGS
    print(msg)
    LOGS += msg + "\n"
    return msg
# ============================================================
# SYSTEM METRICS - LIVE GPU + CPU MONITORING
# ============================================================
def log_system_stats(tag=""):
    try:
        log(f"\n===== SYSTEM STATS {tag} =====")
        # ============= GPU STATS =============
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1e9
            reserved = torch.cuda.memory_reserved(0) / 1e9
            total = torch.cuda.get_device_properties(0).total_memory / 1e9
            free = total - allocated
            log(f"GPU Total     : {total:.2f} GB")
            log(f"GPU Allocated : {allocated:.2f} GB")
            log(f"GPU Reserved  : {reserved:.2f} GB")
            log(f"GPU Free      : {free:.2f} GB")
        # ============= CPU STATS =============
        cpu = psutil.cpu_percent()
        ram_used = psutil.virtual_memory().used / 1e9
        ram_total = psutil.virtual_memory().total / 1e9
        log(f"CPU Usage : {cpu}%")
        log(f"RAM Used  : {ram_used:.2f} GB / {ram_total:.2f} GB")
    except Exception as e:
        log(f"[WARN] Failed to log system stats: {e}")
# ============================================================
# ENVIRONMENT INFO
# ============================================================
log("===================================================")
log("Z-IMAGE-TURBO DEBUGGING + LIVE METRIC LOGGER")
log("===================================================\n")
log(f"PYTHON VERSION       : {sys.version.replace(chr(10), ' ')}")
log(f"PLATFORM             : {platform.platform()}")
log(f"TORCH VERSION        : {torch.__version__}")
log(f"TRANSFORMERS VERSION : {transformers.__version__}")
log(f"DIFFUSERS VERSION    : {diffusers.__version__}")
log(f"CUDA AVAILABLE       : {torch.cuda.is_available()}")

log_system_stats("AT STARTUP")

if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required")

device = "cuda"
gpu_id = 0
# ============================================================
# MODEL SETTINGS
# ============================================================
model_cache = "./weights/"
model_id = "Tongyi-MAI/Z-Image-Turbo"
torch_dtype = torch.bfloat16
USE_CPU_OFFLOAD = False

log("\n===================================================")
log("MODEL CONFIGURATION")
log("===================================================")
log(f"Model ID              : {model_id}")
log(f"Model Cache Directory : {model_cache}")
log(f"torch_dtype           : {torch_dtype}")
log(f"USE_CPU_OFFLOAD       : {USE_CPU_OFFLOAD}")

log_system_stats("BEFORE TRANSFORMER LOAD")
# ============================================================
# FUNCTION TO CONVERT LATENTS TO IMAGE
# ============================================================
def latent_to_image(latent):
    try:
        # vae.decode returns a DecoderOutput; take the image tensor itself
        img_tensor = pipe.vae.decode(latent, return_dict=False)[0]
        img_tensor = (img_tensor / 2 + 0.5).clamp(0, 1)
        pil_img = T.ToPILImage()(img_tensor[0])
        return pil_img
    except Exception as e:
        log(f"[WARN] Failed to decode latent: {e}")
        return None
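# A minimal usage sketch (assumption): intermediate latents from the denoising
# loop are still in the scheduler's scaled space, so they would be unscaled with
# the VAE's scaling_factor / shift_factor (exactly as the final decode in
# generate_image below does) before being handed to latent_to_image. The names
# some_latent and preview are illustrative only.
#
#   raw = some_latent.to(pipe.vae.dtype)
#   raw = (raw / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
#   preview = latent_to_image(raw)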
# ============================================================
# SAFE TRANSFORMER INSPECTION
# ============================================================
def inspect_transformer(model, name):
    log(f"\nInspecting {name}")
    try:
        candidates = ["transformer_blocks", "blocks", "layers", "encoder", "model"]
        blocks = None
        for attr in candidates:
            if hasattr(model, attr):
                blocks = getattr(model, attr)
                break
        if blocks is None:
            log(f"[WARN] No block structure found in {name}")
            return
        if hasattr(blocks, "__len__"):
            log(f"Total Blocks = {len(blocks)}")
        else:
            log("[WARN] Blocks exist but are not iterable")
        for i in range(min(10, len(blocks) if hasattr(blocks, "__len__") else 0)):
            log(f"Block {i} = {blocks[i].__class__.__name__}")
    except Exception as e:
        log(f"[WARN] Transformer inspect error: {e}")
# ============================================================
# LOAD TRANSFORMER - WITH LIVE STATS
# ============================================================
log("\n===================================================")
log("LOADING TRANSFORMER BLOCK")
log("===================================================")
log("Logging memory before load:")
log_system_stats("START TRANSFORMER LOAD")

try:
    quant_cfg = DiffusersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_use_double_quant=True,
    )
    transformer = AutoModel.from_pretrained(
        model_id,
        cache_dir=model_cache,
        subfolder="transformer",
        quantization_config=quant_cfg,
        torch_dtype=torch_dtype,
        device_map=device,
    )
    log("[OK] Transformer loaded successfully.")
except Exception as e:
    log(f"[FAIL] Transformer load failed: {e}")
    transformer = None

log_system_stats("AFTER TRANSFORMER LOAD")

if transformer:
    inspect_transformer(transformer, "Transformer")
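# Optional check (a sketch, not part of the original flow): with the NF4
# BitsAndBytesConfig above, the transformer's Linear layers should have been
# swapped for bitsandbytes Linear4bit modules, so counting them is one way to
# confirm the 4-bit quantization actually took effect. Assumes bitsandbytes is
# installed, which the quantization config already requires.
try:
    import bitsandbytes as bnb
    if transformer is not None:
        n_4bit = sum(isinstance(m, bnb.nn.Linear4bit) for m in transformer.modules())
        log(f"Linear4bit modules in transformer: {n_4bit}")
except Exception as e:
    log(f"[WARN] Could not inspect quantized modules: {e}")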
# ============================================================
# LOAD TEXT ENCODER
# ============================================================
log("\n===================================================")
log("LOADING TEXT ENCODER")
log("===================================================")
log_system_stats("START TEXT ENCODER LOAD")

try:
    quant_cfg2 = TransformersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_use_double_quant=True,
    )
    text_encoder = AutoModel.from_pretrained(
        model_id,
        cache_dir=model_cache,
        subfolder="text_encoder",
        quantization_config=quant_cfg2,
        torch_dtype=torch_dtype,
        device_map=device,
    )
    log("[OK] Text encoder loaded successfully.")
except Exception as e:
    log(f"[FAIL] Text encoder load failed: {e}")
    text_encoder = None

log_system_stats("AFTER TEXT ENCODER LOAD")

if text_encoder:
    inspect_transformer(text_encoder, "Text Encoder")
# ============================================================
# BUILD PIPELINE
# ============================================================
log("\n===================================================")
log("BUILDING PIPELINE")
log("===================================================")
log_system_stats("START PIPELINE BUILD")

try:
    pipe = ZImagePipeline.from_pretrained(
        model_id,
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch_dtype,
    )
    pipe.to(device)
    log("[OK] Pipeline built successfully.")
except Exception as e:
    log(f"[FAIL] Pipeline build failed: {e}")
    pipe = None

log_system_stats("AFTER PIPELINE BUILD")
import torch
from PIL import Image
import io

logs = []
latent_gallery = []
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
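# Worked example (illustrative, assuming the usual 8x VAE downsampling plus the
# 2x2 patching used in generate_image below): a 1024x1024 image gives 128x128
# latents, so image_seq_len = (128 // 2) * (128 // 2) = 4096, which equals
# max_seq_len, so mu = max_shift = 1.15. A 512x512 image gives
# image_seq_len = 32 * 32 = 1024 and
# mu = 0.5 + (1024 - 256) * (1.15 - 0.5) / (4096 - 256) ~= 0.63.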
def retrieve_timesteps(
    scheduler,
    num_inference_steps: int = None,
    device: str = None,
    timesteps: list = None,
    sigmas: list = None,
    **kwargs,
):
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of timesteps or sigmas can be passed")
    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
@spaces.GPU  # required on ZeroGPU Spaces so this call runs with a GPU attached
def generate_image(prompt, height, width, steps, seed):
    generator = torch.Generator(device).manual_seed(int(seed))

    # Encode prompt
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)
    batch_size = len(prompt_embeds)
    num_images_per_prompt = 1
    actual_batch_size = batch_size * num_images_per_prompt
    num_channels_latents = pipe.transformer.in_channels

    # Prepare latents
    latents = pipe.prepare_latents(
        actual_batch_size, num_channels_latents, height, width, torch.float32, device, generator
    )

    # Repeat embeddings for multiple images per prompt
    if num_images_per_prompt > 1:
        prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
    if pipe.do_classifier_free_guidance and negative_prompt_embeds:
        negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]

    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
    mu = calculate_shift(image_seq_len)
    pipe.scheduler.sigma_min = 0.0
    scheduler_kwargs = {"mu": mu}
    timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, steps, device, **scheduler_kwargs)

    # Denoising loop
    for i, t in enumerate(timesteps):
        timestep = t.expand(latents.shape[0])
        timestep = (1000 - timestep) / 1000
        t_norm = timestep[0].item()
        apply_cfg = pipe.do_classifier_free_guidance and pipe.guidance_scale > 0
        if apply_cfg:
            latent_model_input = latents.to(pipe.transformer.dtype).repeat(2, 1, 1, 1).unsqueeze(2)
            prompt_input = prompt_embeds + negative_prompt_embeds
            timestep_input = timestep.repeat(2)
        else:
            latent_model_input = latents.to(pipe.transformer.dtype).unsqueeze(2)
            prompt_input = prompt_embeds
            timestep_input = timestep
        latent_list = list(latent_model_input.unbind(0))
        model_out_list = pipe.transformer(latent_list, timestep_input, prompt_input, return_dict=False)[0]
        if apply_cfg:
            pos_out = model_out_list[:actual_batch_size]
            neg_out = model_out_list[actual_batch_size:]
            noise_pred = torch.stack([p + pipe.guidance_scale * (p - n) for p, n in zip(pos_out, neg_out)])
        else:
            noise_pred = torch.stack([out.float() for out in model_out_list], 0)
        noise_pred = noise_pred.squeeze(2)
        noise_pred = -noise_pred
        latents = pipe.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]

    # Decode final image
    latents = latents.to(pipe.vae.dtype)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")
    # postprocess returns a list of PIL images; the gr.Image output expects a single image
    return image[0], None, None
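# The UI below has a "Latent Steps" gallery, and latent_history / latent_to_image
# are defined above, but generate_image currently returns None for that output.
# A minimal sketch of how the pieces could be connected (illustrative only):
# inside the denoising loop, after the scheduler step, something like
#
#   preview_latent = latents.to(pipe.vae.dtype)
#   preview_latent = (preview_latent / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
#   latent_history.append(latent_to_image(preview_latent))
#
# and then `return image[0], latent_history, LOGS` would populate the gallery
# and the logs box instead of the two None placeholders.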
# ============================================================
# UI
# ============================================================
with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
    gr.Markdown("# **Z-Image-Turbo: Final Image & Latents**")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", value="Realistic middle-aged male image")
            height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
            width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
            steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            seed = gr.Number(value=42, label="Seed")
            run_btn = gr.Button("Generate Image")
        with gr.Column(scale=1):
            final_image = gr.Image(label="Final Image")
            latent_gallery = gr.Gallery(
                label="Latent Steps",
                columns=4,
                height=256,
                preview=True
            )
            logs_box = gr.Textbox(label="Logs", lines=15)
    run_btn.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[final_image, latent_gallery, logs_box]
    )

demo.launch()