rahul7star committed · verified
Commit 792bd64 · 1 Parent(s): cbb491e

Update app_quant_latent.py

Files changed (1):
  1. app_quant_latent.py  +107 -66
app_quant_latent.py CHANGED
@@ -245,82 +245,123 @@ log_system_stats("AFTER PIPELINE BUILD")
 @spaces.GPU
 def generate_image(prompt, height, width, steps, seed):
-
-    try:
-        generator = torch.Generator(device).manual_seed(int(seed))
-        latent_history = []
-
-        # callback signature expected by ZImagePipeline:
-        # callback_on_step_end(self_pipeline, step_index, timestep, callback_kwargs_dict)
-        def save_latents(self_pipeline, step_idx, timestep, callback_kwargs):
-            # callback_kwargs contains tensor inputs specified by
-            # callback_on_step_end_tensor_inputs (defaults to ["latents"])
-            try:
-                lat = callback_kwargs.get("latents", None)
-                if lat is not None:
-                    # store CPU copy to avoid holding GPU memory
-                    latent_history.append(lat.detach().clone().cpu())
-                # we must return a dict (may include overrides), here no overrides:
-                return {}
-            except Exception as e:
-                log(f"⚠️ save_latents error: {e}")
-                return {}
-
-        # Run pipeline once, using the pipeline's callback mechanism
-        out = pipe(
-            prompt=prompt,
-            height=height,
-            width=width,
-            num_inference_steps=steps,
-            guidance_scale=0.0,
-            generator=generator,
-            callback_on_step_end=save_latents,
-            callback_on_step_end_tensor_inputs=["latents"],  # ensure latents passed to callback
-        )
-
-        # out is a ZImagePipelineOutput; pipeline already postprocessed images
-        final_image = out.images[0] if hasattr(out, "images") and len(out.images) > 0 else out
-
-        # Convert saved latents into displayable images (use same postprocessing as pipeline)
-        latent_images = []
         try:
-            # Determine decode device and dtype
-            vae = pipe.vae
-            img_proc = pipe.image_processor
-            vae_device = vae.device if hasattr(vae, "device") else device
-
-            for i, lat_cpu in enumerate(latent_history):
-                try:
-                    # move to vae device and dtype
-                    lat = lat_cpu.to(vae_device).to(vae.dtype)
-
-                    # pipeline used this transform before decoding:
-                    lat = (lat / vae.config.scaling_factor) + getattr(vae.config, "shift_factor", 0.0)
-
-                    # decode: vae.decode returns (batch, C, H, W)
-                    img_tensor = vae.decode(lat, return_dict=False)[0]
-
-                    # postprocess with pipeline's image processor to PIL
-                    pil = img_proc.postprocess(img_tensor.unsqueeze(0), output_type="pil")[0]
-                    latent_images.append(pil)
-                except Exception as e:
-                    log(f"⚠️ Failed to decode latent step {i}: {e}")
         except Exception as e:
-            log(f"⚠️ Error while converting latents: {e}")
-
-        log("✅ Inference finished.")
-        log_system_stats("AFTER INFERENCE")
-
-        return final_image, latent_images, LOGS
-
-    except Exception as e:
-        log(f"❌ Inference error: {e}")
-        return None, [], LOGS
+import torch
+from PIL import Image
+import io
+
 
+logs = []
+latent_gallery = []
 
 
 @spaces.GPU
 def generate_image(prompt, height, width, steps, seed):
+
+    try:
+        # reset per-call state so repeated runs do not accumulate old steps/logs
+        logs.clear()
+        latent_gallery.clear()
+
+        device = pipe._execution_device
+        generator = torch.Generator(device).manual_seed(int(seed))
+
+        # 1. Encode prompt
+        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
+            prompt=prompt,
+            negative_prompt=None,
+            do_classifier_free_guidance=True,
+            device=device,
+        )
+
+        batch_size = 1
+        num_images_per_prompt = 1
+        actual_batch_size = batch_size * num_images_per_prompt
+        num_channels_latents = pipe.transformer.in_channels
+
+        # 2. Prepare latents
+        latents = pipe.prepare_latents(
+            actual_batch_size,
+            num_channels_latents,
+            height,
+            width,
+            torch.float32,
+            device,
+            generator,
+            latents=None,
+        )
+
+        # Repeat prompt embeddings for multiple images
+        prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
+        negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
+
+        # 3. Prepare timesteps
+        image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
+        mu = calculate_shift(
+            image_seq_len,
+            pipe.scheduler.config.get("base_image_seq_len", 256),
+            pipe.scheduler.config.get("max_image_seq_len", 4096),
+            pipe.scheduler.config.get("base_shift", 0.5),
+            pipe.scheduler.config.get("max_shift", 1.15),
+        )
+        pipe.scheduler.sigma_min = 0.0
+        timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, steps, device, sigmas=None, mu=mu)
+
+        # 4. Denoising loop
+        with pipe.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                timestep = t.expand(latents.shape[0])
+                timestep = (1000 - timestep) / 1000
+
+                # CFG: duplicate the latents for the positive/negative branches
+                latents_typed = latents.to(pipe.transformer.dtype)
+                latent_model_input = latents_typed.repeat(2, 1, 1, 1).unsqueeze(2)
+                prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
+                timestep_model_input = timestep.repeat(2)
+
+                latent_model_input_list = list(latent_model_input.unbind(dim=0))
+                model_out_list = pipe.transformer(
+                    latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False
+                )[0]
+
+                # Perform CFG
+                pos_out = model_out_list[:actual_batch_size]
+                neg_out = model_out_list[actual_batch_size:]
+                noise_pred = torch.stack([p + pipe.guidance_scale * (p - n) for p, n in zip(pos_out, neg_out)], dim=0)
+
+                noise_pred = noise_pred.squeeze(2)
+                noise_pred = -noise_pred
+                latents = pipe.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
+
+                # Store each latent step for the gallery
+                latent_gallery.append(latents.clone().detach().cpu())
+                progress_bar.update()
+
+        # 5. Decode final image
+        latents_dec = latents.to(pipe.vae.dtype)
+        latents_dec = (latents_dec / pipe.vae.config.scaling_factor) + getattr(pipe.vae.config, "shift_factor", 0.0)
+
+        # Squeeze extra dim if present
+        if latents_dec.dim() == 5 and latents_dec.shape[2] == 1:
+            latents_dec = latents_dec.squeeze(2)
+
+        image = pipe.vae.decode(latents_dec, return_dict=False)[0]
+        final_image = pipe.image_processor.postprocess(image, output_type="pil")
+
+        # Decode latent gallery steps to images (optional)
+        gallery_images = []
+        for idx, lat in enumerate(latent_gallery):
             try:
+                # move the stored CPU latent back to the VAE device/dtype and undo the
+                # latent scaling, mirroring the final decode above
+                lat = lat.to(device).to(pipe.vae.dtype)
+                lat = (lat / pipe.vae.config.scaling_factor) + getattr(pipe.vae.config, "shift_factor", 0.0)
+                if lat.dim() == 5 and lat.shape[2] == 1:
+                    lat = lat.squeeze(2)
+                img = pipe.vae.decode(lat, return_dict=False)[0]
+                img = pipe.image_processor.postprocess(img, output_type="pil")
+                gallery_images.append(img[0])
             except Exception as e:
+                logs.append(f"⚠️ Failed to decode latent step {idx}: {e}")
 
+        return final_image[0], gallery_images, "\n".join(logs)
 
+    except Exception as e:
+        return None, [], f"❌ Inference error: {e}"
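
The Gradio interface that consumes generate_image's three return values (final image, per-step latent gallery, log text) lies outside this hunk. Below is a minimal sketch of how those outputs might be wired up, assuming standard Gradio components; the component names, ranges, and defaults are illustrative assumptions, not taken from app_quant_latent.py.

import gradio as gr

# Hypothetical UI wiring (not part of this commit): route generate_image's outputs
# to an Image, a Gallery of per-step latents, and a log Textbox.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    with gr.Row():
        height = gr.Slider(256, 2048, value=1024, step=16, label="Height")
        width = gr.Slider(256, 2048, value=1024, step=16, label="Width")
    with gr.Row():
        steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
        seed = gr.Number(value=0, precision=0, label="Seed")
    run = gr.Button("Generate")

    out_image = gr.Image(label="Final image")
    out_gallery = gr.Gallery(label="Latents per step")
    out_logs = gr.Textbox(label="Logs", lines=8)

    run.click(
        generate_image,
        inputs=[prompt, height, width, steps, seed],
        outputs=[out_image, out_gallery, out_logs],
    )

demo.launch()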