rahul7star committed on
Commit ea621bd · verified · 1 Parent(s): 1bd9a25

Update app_quant_latent.py

Files changed (1)
  1. app_quant_latent.py +71 -97
app_quant_latent.py CHANGED
@@ -252,110 +252,84 @@ import io
 
 logs = []
 latent_gallery = []
-def calculate_shift(
-    image_seq_len,
-    base_seq_len: int = 256,
-    max_seq_len: int = 4096,
-    base_shift: float = 0.5,
-    max_shift: float = 1.15,
-):
-    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
-    b = base_shift - m * base_seq_len
-    mu = image_seq_len * m + b
-    return mu
-
-def retrieve_timesteps(
-    scheduler,
-    num_inference_steps: int = None,
-    device: str = None,
-    timesteps: list = None,
-    sigmas: list = None,
-    **kwargs,
-):
-    if timesteps is not None and sigmas is not None:
-        raise ValueError("Only one of timesteps or sigmas can be passed")
-    if timesteps is not None:
-        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-        num_inference_steps = len(timesteps)
-    elif sigmas is not None:
-        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-        num_inference_steps = len(timesteps)
-    else:
-        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-    return timesteps, num_inference_steps
 
 @spaces.GPU
-def generate_image(prompt, height, width, steps, seed):
-    generator = torch.Generator(device).manual_seed(int(seed))
-
-    # Encode prompt
-    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)
-
-    batch_size = len(prompt_embeds)
-    num_images_per_prompt = 1
-    actual_batch_size = batch_size * num_images_per_prompt
-    num_channels_latents = pipe.transformer.in_channels
-
-    # Prepare latents
-    latents = pipe.prepare_latents(
-        actual_batch_size, num_channels_latents, height, width, torch.float32, device, generator
-    )
-
-    # Repeat embeddings for multiple images per prompt
-    if num_images_per_prompt > 1:
-        prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
-        if pipe.do_classifier_free_guidance and negative_prompt_embeds:
-            negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
-
-    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
-    mu = calculate_shift(image_seq_len)
-
-    pipe.scheduler.sigma_min = 0.0
-    scheduler_kwargs = {"mu": mu}
-    timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, steps, device, **scheduler_kwargs)
-
-    # Denoising loop
-    for i, t in enumerate(timesteps):
-        timestep = t.expand(latents.shape[0])
-        timestep = (1000 - timestep) / 1000
-        t_norm = timestep[0].item()
-        apply_cfg = pipe.do_classifier_free_guidance and pipe.guidance_scale > 0
-
-        if apply_cfg:
-            latent_model_input = latents.to(pipe.transformer.dtype).repeat(2, 1, 1, 1).unsqueeze(2)
-            prompt_input = prompt_embeds + negative_prompt_embeds
-            timestep_input = timestep.repeat(2)
-        else:
-            latent_model_input = latents.to(pipe.transformer.dtype).unsqueeze(2)
-            prompt_input = prompt_embeds
-            timestep_input = timestep
-
-        latent_list = list(latent_model_input.unbind(0))
-        model_out_list = pipe.transformer(latent_list, timestep_input, prompt_input, return_dict=False)[0]
-
-        if apply_cfg:
-            pos_out = model_out_list[:actual_batch_size]
-            neg_out = model_out_list[actual_batch_size:]
-            noise_pred = torch.stack([p + pipe.guidance_scale * (p - n) for p, n in zip(pos_out, neg_out)])
-        else:
-            noise_pred = torch.stack([t.float() for t in model_out_list], 0)
-
-        noise_pred = noise_pred.squeeze(2)
-        noise_pred = -noise_pred
-        latents = pipe.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
-
-    # Decode final image
-    latents = latents.to(pipe.vae.dtype)
-    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
-    image = pipe.vae.decode(latents, return_dict=False)[0]
-    image = pipe.image_processor.postprocess(image, output_type="pil")
-
-    return image, None, None
 
 
 # ============================================================
 
 
 logs = []
 latent_gallery = []
+
+import torch
+from PIL import Image
+
+# Global log storage
+
+LOGS = []
+
+def log(msg):
+    LOGS.append(msg)
+    print(msg)
 
 @spaces.GPU
+def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
+    """
+    Generate an image from a prompt.
+    Tries advanced latent-based method; falls back to standard pipeline if anything fails.
+    """
+    try:
+        generator = torch.Generator(device).manual_seed(int(seed))
+
+        # Try advanced latent preparation
+        try:
+            batch_size = 1
+            num_channels_latents = getattr(pipe.unet, "in_channels", None)
+            if num_channels_latents is None:
+                raise AttributeError("pipe.unet.in_channels not found, fallback to standard pipeline")
+
+            latents = pipe.prepare_latents(
+                batch_size=batch_size,
+                num_channels=num_channels_latents,
+                height=height,
+                width=width,
+                dtype=torch.float32,
+                device=device,
+                generator=generator
+            )
+            log(f"✅ Latents prepared: {latents.shape}")
+
+            # Generate image using prepared latents
+            output = pipe(
+                prompt=prompt,
+                height=height,
+                width=width,
+                num_inference_steps=steps,
+                guidance_scale=guidance_scale,
+                generator=generator,
+                latents=latents
+            )
+
+        except Exception as e_inner:
+            # If advanced method fails, fallback to standard pipeline
+            log(f"⚠️ Advanced latent method failed: {e_inner}")
+            log("🔄 Falling back to standard pipeline...")
+            output = pipe(
+                prompt=prompt,
+                height=height,
+                width=width,
+                num_inference_steps=steps,
+                guidance_scale=guidance_scale,
+                generator=generator
+            )
+
+        image = output.images[0]
+        log("✅ Inference finished successfully.")
+
+        if return_latents and 'latents' in locals():
+            return image, latents, LOGS
+        else:
+            return image, LOGS
+
+    except Exception as e:
+        log(f"❌ Inference failed entirely: {e}")
+        return None, LOGS
 
 
 # ============================================================
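
For reference, this commit also changes the return shape of generate_image: it now returns the accumulated LOGS alongside the image, and optionally the prepared latents. A minimal call sketch, assuming the module-level `pipe` and `device` objects are initialised earlier in app_quant_latent.py (not part of this hunk) and that the prompt and sizes below are placeholders:

# Hypothetical smoke test for the updated generate_image; `pipe`, `device`,
# and the example prompt/size values are assumptions, not part of this commit.
image, logs = generate_image(
    prompt="a watercolor fox in a forest",   # placeholder prompt
    height=512,
    width=512,
    steps=20,
    seed=42,
    guidance_scale=0.0,     # default added in this commit
    return_latents=False,   # True returns (image, latents, LOGS) on success
)
for line in logs:
    print(line)
if image is not None:
    image.save("generated.png")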