rahul7star commited on
Commit
9ff5a65
·
verified ·
1 Parent(s): 81d6159

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +81 -0
app_quant_latent.py CHANGED
@@ -245,7 +245,88 @@ log_system_stats("AFTER PIPELINE BUILD")
245
 
246
 
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
 
 
 
249
 
250
  @spaces.GPU
251
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
 
245
 
246
 
247
 
248
+ @spaces.GPU
249
+ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
250
+ """
251
+ Returns 3 outputs for Gradio:
252
+ - image: PIL.Image
253
+ - gallery: optional latents or empty list
254
+ - logs: list of messages
255
+ """
256
+
257
+ LOGS = []
258
+ image = None
259
+ latents = None
260
+ gallery = []
261
+
262
+ try:
263
+ generator = torch.Generator(device).manual_seed(int(seed))
264
+
265
+ # -------------------------------
266
+ # Attempt advanced latent generation
267
+ # -------------------------------
268
+ try:
269
+ batch_size = 1
270
+ num_channels_latents = getattr(pipe.unet, "in_channels", None)
271
+ if num_channels_latents is None:
272
+ raise AttributeError("pipe.unet.in_channels not found, fallback to standard pipeline")
273
+
274
+ latents = pipe.prepare_latents(
275
+ batch_size=batch_size,
276
+ num_channels=num_channels_latents,
277
+ height=height,
278
+ width=width,
279
+ dtype=torch.float32,
280
+ device=device,
281
+ generator=generator
282
+ )
283
+ LOGS.append(f"✅ Latents prepared: {latents.shape}")
284
+
285
+ output = pipe(
286
+ prompt=prompt,
287
+ height=height,
288
+ width=width,
289
+ num_inference_steps=steps,
290
+ guidance_scale=guidance_scale,
291
+ generator=generator,
292
+ latents=latents
293
+ )
294
+ image = output.images[0]
295
+ gallery = [image] # You can also return other latent previews if needed
296
+ LOGS.append("✅ Advanced latent generation succeeded.")
297
+
298
+ # -------------------------------
299
+ # Fallback to standard pipeline
300
+ # -------------------------------
301
+ except Exception as e_latent:
302
+ LOGS.append(f"⚠️ Advanced latent generation failed: {e_latent}")
303
+ LOGS.append("🔁 Falling back to standard pipeline...")
304
+ try:
305
+ output = pipe(
306
+ prompt=prompt,
307
+ height=height,
308
+ width=width,
309
+ num_inference_steps=steps,
310
+ guidance_scale=guidance_scale,
311
+ generator=generator
312
+ )
313
+ image = output.images[0]
314
+ gallery = [image] # still return at least the final image
315
+ LOGS.append("✅ Standard pipeline generation succeeded.")
316
+ except Exception as e_standard:
317
+ LOGS.append(f"❌ Standard pipeline generation failed: {e_standard}")
318
+ # Fallback: return None image and empty gallery
319
+ image = None
320
+ gallery = []
321
+
322
+ # -------------------------------
323
+ # Return all 3 outputs
324
+ # -------------------------------
325
+ return image, gallery, LOGS
326
 
327
+ except Exception as e:
328
+ LOGS.append(f"❌ Inference failed entirely: {e}")
329
+ return None, [], LOGS
330
 
331
  @spaces.GPU
332
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):