rahul7star commited on
Commit
6d29b78
Β·
verified Β·
1 Parent(s): 3179e10

Create app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +304 -0
app_quant_latent.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import spaces
3
+ import gradio as gr
4
+ import sys
5
+ import platform
6
+ import diffusers
7
+ import transformers
8
+ import os
9
+ import torchvision.transforms as T
10
+
11
+ from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
12
+ from diffusers import ZImagePipeline, AutoModel
13
+ from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
14
+
15
+ # ============================================================
16
+
17
+ # LOGGING BUFFER
18
+
19
+ # ============================================================
20
+
21
+ LOGS = ""
22
+ def log(msg):
23
+ global LOGS
24
+ print(msg)
25
+ LOGS += msg + "\n"
26
+ return msg
27
+
28
+ # ============================================================
29
+
30
+ # ENVIRONMENT INFO
31
+
32
+ # ============================================================
33
+
34
+ log("===================================================")
35
+ log("πŸ” Z-IMAGE-TURBO DEBUGGING + ROBUST TRANSFORMER INSPECTION")
36
+ log("===================================================\n")
37
+
38
+ log(f"πŸ“Œ PYTHON VERSION : {sys.version.replace(chr(10), ' ')}")
39
+ log(f"πŸ“Œ PLATFORM : {platform.platform()}")
40
+ log(f"πŸ“Œ TORCH VERSION : {torch.**version**}")
41
+ log(f"πŸ“Œ TRANSFORMERS VERSION : {transformers.**version**}")
42
+ log(f"πŸ“Œ DIFFUSERS VERSION : {diffusers.**version**}")
43
+ log(f"πŸ“Œ CUDA AVAILABLE : {torch.cuda.is_available()}")
44
+
45
+ if torch.cuda.is_available():
46
+ log(f"πŸ“Œ GPU NAME : {torch.cuda.get_device_name(0)}")
47
+ log(f"πŸ“Œ GPU CAPABILITY : {torch.cuda.get_device_capability(0)}")
48
+ log(f"πŸ“Œ GPU MEMORY (TOTAL) : {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB")
49
+ log(f"πŸ“Œ FLASH ATTENTION : {torch.backends.cuda.flash_sdp_enabled()}")
50
+ else:
51
+ raise RuntimeError("❌ CUDA is REQUIRED but not available.")
52
+
53
+ device = "cuda"
54
+ gpu_id = 0
55
+
56
+ # ============================================================
57
+
58
+ # MODEL SETTINGS
59
+
60
+ # ============================================================
61
+
62
+ model_cache = "./weights/"
63
+ model_id = "Tongyi-MAI/Z-Image-Turbo"
64
+ torch_dtype = torch.bfloat16
65
+ USE_CPU_OFFLOAD = False
66
+
67
+ log("\n===================================================")
68
+ log("🧠 MODEL CONFIGURATION")
69
+ log("===================================================")
70
+ log(f"Model ID : {model_id}")
71
+ log(f"Model Cache Directory : {model_cache}")
72
+ log(f"torch_dtype : {torch_dtype}")
73
+ log(f"USE_CPU_OFFLOAD : {USE_CPU_OFFLOAD}")
74
+
75
+ # ============================================================
76
+
77
+ # ROBUST TRANSFORMER INSPECTION FUNCTION
78
+
79
+ # ============================================================
80
+
81
+ def inspect_transformer(model, model_name="Transformer"):
82
+ log(f"\nπŸ” {model_name} Architecture Details:")
83
+ try:
84
+ block_attrs = ["transformer_blocks", "blocks", "layers", "encoder_blocks", "model"]
85
+ blocks = None
86
+ for attr in block_attrs:
87
+ blocks = getattr(model, attr, None)
88
+ if blocks is not None:
89
+ break
90
+
91
+ ```
92
+ if blocks is None:
93
+ log(f"⚠️ Could not find transformer blocks in {model_name}, skipping detailed block info")
94
+ else:
95
+ try:
96
+ log(f"Number of Transformer Modules : {len(blocks)}")
97
+ for i, block in enumerate(blocks):
98
+ log(f" Block {i}: {block.__class__.__name__}")
99
+ attn_type = getattr(block, "attn", None)
100
+ if attn_type:
101
+ log(f" Attention: {attn_type.__class__.__name__}")
102
+ flash_enabled = getattr(attn_type, "flash", None)
103
+ log(f" FlashAttention Enabled? : {flash_enabled}")
104
+ except Exception as e:
105
+ log(f"⚠️ Error inspecting blocks: {e}")
106
+
107
+ config = getattr(model, "config", None)
108
+ if config:
109
+ log(f"Hidden size: {getattr(config, 'hidden_size', 'N/A')}")
110
+ log(f"Number of attention heads: {getattr(config, 'num_attention_heads', 'N/A')}")
111
+ log(f"Number of layers: {getattr(config, 'num_hidden_layers', 'N/A')}")
112
+ log(f"Intermediate size: {getattr(config, 'intermediate_size', 'N/A')}")
113
+ else:
114
+ log(f"⚠️ No config attribute found in {model_name}")
115
+ except Exception as e:
116
+ log(f"⚠️ Failed to inspect {model_name}: {e}")
117
+ ```
118
+
119
+ # ============================================================
120
+
121
+ # LOAD TRANSFORMER BLOCK
122
+
123
+ # ============================================================
124
+
125
+ log("\n===================================================")
126
+ log("πŸ”§ LOADING TRANSFORMER BLOCK")
127
+ log("===================================================")
128
+
129
+ quantization_config = DiffusersBitsAndBytesConfig(
130
+ load_in_4bit=True,
131
+ bnb_4bit_quant_type="nf4",
132
+ bnb_4bit_compute_dtype=torch_dtype,
133
+ bnb_4bit_use_double_quant=True,
134
+ llm_int8_skip_modules=["transformer_blocks.0.img_mod"],
135
+ )
136
+ log("4-bit Quantization Config (Transformer):")
137
+ log(str(quantization_config))
138
+
139
+ transformer = AutoModel.from_pretrained(
140
+ model_id,
141
+ cache_dir=model_cache,
142
+ subfolder="transformer",
143
+ quantization_config=quantization_config,
144
+ torch_dtype=torch_dtype,
145
+ device_map=device,
146
+ )
147
+ log("βœ… Transformer block loaded successfully.")
148
+ inspect_transformer(transformer, "Transformer")
149
+
150
+ if USE_CPU_OFFLOAD:
151
+ transformer = transformer.to("cpu")
152
+
153
+ # ============================================================
154
+
155
+ # LOAD TEXT ENCODER
156
+
157
+ # ============================================================
158
+
159
+ log("\n===================================================")
160
+ log("πŸ”§ LOADING TEXT ENCODER")
161
+ log("===================================================")
162
+
163
+ quantization_config = TransformersBitsAndBytesConfig(
164
+ load_in_4bit=True,
165
+ bnb_4bit_quant_type="nf4",
166
+ bnb_4bit_compute_dtype=torch_dtype,
167
+ bnb_4bit_use_double_quant=True,
168
+ )
169
+ log("4-bit Quantization Config (Text Encoder):")
170
+ log(str(quantization_config))
171
+
172
+ text_encoder = AutoModel.from_pretrained(
173
+ model_id,
174
+ cache_dir=model_cache,
175
+ subfolder="text_encoder",
176
+ quantization_config=quantization_config,
177
+ torch_dtype=torch_dtype,
178
+ device_map=device,
179
+ )
180
+ log("βœ… Text encoder loaded successfully.")
181
+ inspect_transformer(text_encoder, "Text Encoder")
182
+
183
+ if USE_CPU_OFFLOAD:
184
+ text_encoder = text_encoder.to("cpu")
185
+
186
+ # ============================================================
187
+
188
+ # BUILD PIPELINE
189
+
190
+ # ============================================================
191
+
192
+ log("\n===================================================")
193
+ log("πŸ”§ BUILDING Z-IMAGE-TURBO PIPELINE")
194
+ log("===================================================")
195
+
196
+ pipe = ZImagePipeline.from_pretrained(
197
+ model_id,
198
+ transformer=transformer,
199
+ text_encoder=text_encoder,
200
+ torch_dtype=torch_dtype,
201
+ )
202
+
203
+ if USE_CPU_OFFLOAD:
204
+ pipe.enable_model_cpu_offload(gpu_id=gpu_id)
205
+ log("βš™ CPU OFFLOAD ENABLED")
206
+ else:
207
+ pipe.to(device)
208
+ log("βš™ Pipeline moved to GPU")
209
+
210
+ log("βœ… Pipeline ready.")
211
+
212
+ # ============================================================
213
+
214
+ # FUNCTION TO CONVERT LATENTS TO IMAGE
215
+
216
+ # ============================================================
217
+
218
+ def latent_to_image(latent):
219
+ try:
220
+ img_tensor = pipe.vae.decode(latent)
221
+ img_tensor = (img_tensor / 2 + 0.5).clamp(0, 1)
222
+ pil_img = T.ToPILImage()(img_tensor[0])
223
+ return pil_img
224
+ except Exception as e:
225
+ log(f"⚠️ Failed to decode latent: {e}")
226
+ return None
227
+
228
+ # ============================================================
229
+
230
+ # REAL-TIME INFERENCE FUNCTION
231
+
232
+ # ============================================================
233
+
234
+ @spaces.GPU
235
+ def generate_image_realtime(prompt, height, width, steps, seed):
236
+ global LOGS
237
+ LOGS = ""
238
+ log("===================================================")
239
+ log("🎨 RUNNING REAL-TIME INFERENCE")
240
+ log("===================================================")
241
+ log(f"Prompt : {prompt}")
242
+ log(f"Resolution : {width} x {height}")
243
+ log(f"Steps : {steps}")
244
+ log(f"Seed : {seed}")
245
+
246
+ ```
247
+ generator = torch.Generator(device).manual_seed(seed)
248
+ latent_history = []
249
+
250
+ # Define callback to save latents and GPU info
251
+ def save_latents(step, timestep, latents):
252
+ latent_history.append(latents.detach().clone())
253
+ gpu_mem = torch.cuda.memory_allocated(0)/1e9
254
+ log(f"Step {step} - GPU Memory Used: {gpu_mem:.2f} GB")
255
+
256
+ # Yield images step-by-step
257
+ for step, img in pipe(
258
+ prompt=prompt,
259
+ height=height,
260
+ width=width,
261
+ num_inference_steps=steps,
262
+ guidance_scale=0.0,
263
+ generator=generator,
264
+ callback=save_latents,
265
+ callback_steps=1
266
+ ).iter():
267
+ # Decode current latent for live preview
268
+ current_latent = latent_history[-1] if latent_history else None
269
+ latent_images = [latent_to_image(l) for l in latent_history if l is not None]
270
+ yield img, latent_images, LOGS
271
+ ```
272
+
273
+ # ============================================================
274
+
275
+ # GRADIO UI
276
+
277
+ # ============================================================
278
+
279
+ with gr.Blocks(title="Z-Image-Turbo Generator") as demo:
280
+ gr.Markdown("# **πŸš€ Z-Image-Turbo β€”4bit Quant + Real-Time Latent & Transformer Logs**")
281
+
282
+ ```
283
+ with gr.Row():
284
+ with gr.Column(scale=1):
285
+ prompt = gr.Textbox(label="Prompt", value="Realistic mid-aged male image")
286
+ height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
287
+ width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
288
+ steps = gr.Slider(1, 16, value=9, step=1, label="Inference Steps")
289
+ seed = gr.Slider(0, 999999, value=42, step=1, label="Seed")
290
+ btn = gr.Button("Generate", variant="primary")
291
+
292
+ with gr.Column(scale=1):
293
+ output_image = gr.Image(label="Final Output Image")
294
+ latent_gallery = gr.Gallery(label="Latent Evolution", elem_id="latent_gallery").style(grid=[2], height="auto")
295
+ logs_panel = gr.Textbox(label="πŸ“œ Transformer & GPU Logs", lines=25, interactive=False)
296
+
297
+ btn.click(
298
+ generate_image_realtime,
299
+ inputs=[prompt, height, width, steps, seed],
300
+ outputs=[output_image, latent_gallery, logs_panel],
301
+ )
302
+ ```
303
+
304
+ demo.launch()