rahul7star committed
Commit 8ee0e91 · verified · 1 Parent(s): 0e6b7b0

Update app_quant_latent.py

Files changed (1)
  1. app_quant_latent.py +93 -111
app_quant_latent.py CHANGED
@@ -253,127 +253,109 @@ import io
 logs = []
 latent_gallery = []
 def calculate_shift(
-    image_seq_len,
-    base_seq_len: int = 256,
-    max_seq_len: int = 4096,
-    base_shift: float = 0.5,
-    max_shift: float = 1.15,
 ):
-    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
-    b = base_shift - m * base_seq_len
-    mu = image_seq_len * m + b
-    return mu
-

 @spaces.GPU
 def generate_image(prompt, height, width, steps, seed):

-    try:
-        device = pipe._execution_device
-        generator = torch.Generator(device).manual_seed(int(seed))

-        # 1. Encode prompt
-        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
-            prompt=prompt,
-            negative_prompt=None,
-            do_classifier_free_guidance=True,
-            device=device,
-        )

-        batch_size = 1
-        num_images_per_prompt = 1
-        actual_batch_size = batch_size * num_images_per_prompt
-        num_channels_latents = pipe.transformer.in_channels
-
-        # 2. Prepare latents
-        latents = pipe.prepare_latents(
-            actual_batch_size,
-            num_channels_latents,
-            height,
-            width,
-            torch.float32,
-            device,
-            generator,
-            latents=None,
-        )

-        # Repeat prompt embeddings for multiple images
         prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
-        negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
-
-        # 3. Prepare timesteps
-        image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
-        mu = calculate_shift(
-            image_seq_len,
-            pipe.scheduler.config.get("base_image_seq_len", 256),
-            pipe.scheduler.config.get("max_image_seq_len", 4096),
-            pipe.scheduler.config.get("base_shift", 0.5),
-            pipe.scheduler.config.get("max_shift", 1.15),
-        )
-        pipe.scheduler.sigma_min = 0.0
-        timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, steps, device, sigmas=None, mu=mu)
-
-        # 4. Denoising loop
-        with pipe.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                timestep = t.expand(latents.shape[0])
-                timestep = (1000 - timestep) / 1000
-
-                # CFG
-                latents_typed = latents.to(pipe.transformer.dtype)
-                latent_model_input = latents_typed.repeat(2, 1, 1, 1).unsqueeze(2)
-                prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
-                timestep_model_input = timestep.repeat(2)
-
-                latent_model_input_list = list(latent_model_input.unbind(dim=0))
-                model_out_list = pipe.transformer(
-                    latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False
-                )[0]
-
-                # Perform CFG
-                pos_out = model_out_list[:actual_batch_size]
-                neg_out = model_out_list[actual_batch_size:]
-                noise_pred = torch.stack([p + pipe.guidance_scale * (p - n) for p, n in zip(pos_out, neg_out)], dim=0)
-
-                noise_pred = noise_pred.squeeze(2)
-                noise_pred = -noise_pred
-                latents = pipe.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
-
-                # Store each latent step for gallery
-                latent_gallery.append(latents.clone().detach().cpu())
-                progress_bar.update()
-
-        # 5. Decode final image
-        latents_dec = latents.to(pipe.vae.dtype)
-        latents_dec = (latents_dec / pipe.vae.config.scaling_factor) + getattr(pipe.vae.config, "shift_factor", 0.0)
-
-        # Squeeze extra dim if present
-        if latents_dec.dim() == 5 and latents_dec.shape[2] == 1:
-            latents_dec = latents_dec.squeeze(2)
-
-        image = pipe.vae.decode(latents_dec, return_dict=False)[0]
-        final_image = pipe.image_processor.postprocess(image, output_type="pil")
-
-        # Decode latent gallery steps to images (optional)
-        gallery_images = []
-        for idx, lat in enumerate(latent_gallery):
-            try:
-                lat = lat.to(pipe.vae.dtype)
-                if lat.dim() == 5 and lat.shape[2] == 1:
-                    lat = lat.squeeze(2)
-                img = pipe.vae.decode(lat, return_dict=False)[0]
-                img = pipe.image_processor.postprocess(img, output_type="pil")
-                gallery_images.append(img[0])
-            except Exception as e:
-                logs.append(f"⚠️ Failed to decode latent step {idx}: {e}")
-
-        return final_image[0], gallery_images, "\n".join(logs)
-
-    except Exception as e:
-        return None, [], f"❌ Inference error: {e}"
-
-

 # ============================================================
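Both the removed version above and the added version below derive the timestep shift mu the same way: the latent grid is counted in 2×2 patches to get image_seq_len, which calculate_shift maps linearly from base_shift at base_seq_len up to max_shift at max_seq_len. A minimal sketch of that arithmetic, assuming a hypothetical 1024×1024 image and an 8× VAE downsample (neither value is stated in the commit):

# Shift formula as defined in both versions of app_quant_latent.py.
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# Assumed example: 1024x1024 pixels -> 128x128 latents -> (128 // 2) * (128 // 2) = 4096 patches.
image_seq_len = (128 // 2) * (128 // 2)
print(calculate_shift(image_seq_len))  # ~1.15, i.e. max_shift is reached at max_seq_len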
 
 logs = []
 latent_gallery = []
 def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
 ):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: int = None,
+    device: str = None,
+    timesteps: list = None,
+    sigmas: list = None,
+    **kwargs,
+):
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of timesteps or sigmas can be passed")
+    if timesteps is not None:
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps

 @spaces.GPU
 def generate_image(prompt, height, width, steps, seed):

+
+    generator = torch.Generator(device).manual_seed(int(seed))

+    # Encode prompt
+    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)

+    batch_size = len(prompt_embeds)
+    num_images_per_prompt = 1
+    actual_batch_size = batch_size * num_images_per_prompt
+    num_channels_latents = pipe.transformer.in_channels

+    # Prepare latents
+    latents = pipe.prepare_latents(
+        actual_batch_size, num_channels_latents, height, width, torch.float32, device, generator
+    )

+    # Repeat embeddings for multiple images per prompt
+    if num_images_per_prompt > 1:
         prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
+        if pipe.do_classifier_free_guidance and negative_prompt_embeds:
+            negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
+
+    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
+    mu = calculate_shift(image_seq_len)
+
+    pipe.scheduler.sigma_min = 0.0
+    scheduler_kwargs = {"mu": mu}
+    timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, steps, device, **scheduler_kwargs)
+
+    # Denoising loop
+    for i, t in enumerate(timesteps):
+        timestep = t.expand(latents.shape[0])
+        timestep = (1000 - timestep) / 1000
+        t_norm = timestep[0].item()
+        apply_cfg = pipe.do_classifier_free_guidance and pipe.guidance_scale > 0
+
+        if apply_cfg:
+            latent_model_input = latents.to(pipe.transformer.dtype).repeat(2, 1, 1, 1).unsqueeze(2)
+            prompt_input = prompt_embeds + negative_prompt_embeds
+            timestep_input = timestep.repeat(2)
+        else:
+            latent_model_input = latents.to(pipe.transformer.dtype).unsqueeze(2)
+            prompt_input = prompt_embeds
+            timestep_input = timestep
+
+        latent_list = list(latent_model_input.unbind(0))
+        model_out_list = pipe.transformer(latent_list, timestep_input, prompt_input, return_dict=False)[0]
+
+        if apply_cfg:
+            pos_out = model_out_list[:actual_batch_size]
+            neg_out = model_out_list[actual_batch_size:]
+            noise_pred = torch.stack([p + pipe.guidance_scale * (p - n) for p, n in zip(pos_out, neg_out)])
+        else:
+            noise_pred = torch.stack([t.float() for t in model_out_list], 0)
+
+        noise_pred = noise_pred.squeeze(2)
+        noise_pred = -noise_pred
+        latents = pipe.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
+
+    # Decode final image
+    latents = latents.to(pipe.vae.dtype)
+    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
+    image = pipe.vae.decode(latents, return_dict=False)[0]
+    image = pipe.image_processor.postprocess(image, output_type="pil")
+
+    return image, None, None


 # ============================================================
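In the new code path, retrieve_timesteps simply forwards extra keyword arguments (here mu, via scheduler_kwargs) to scheduler.set_timesteps and reads back scheduler.timesteps. A rough sketch of what that call amounts to, assuming a diffusers FlowMatchEulerDiscreteScheduler with dynamic shifting enabled; the commit does not name the scheduler class, so this is illustrative only:

from diffusers import FlowMatchEulerDiscreteScheduler

# Assumed scheduler; `mu` only affects the sigma schedule when use_dynamic_shifting=True.
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)

# Equivalent to retrieve_timesteps(scheduler, 30, "cpu", mu=1.15) in the new code:
# with neither `timesteps` nor `sigmas` supplied, the helper just calls set_timesteps
# with the forwarded kwargs and returns scheduler.timesteps.
scheduler.set_timesteps(30, device="cpu", mu=1.15)
timesteps = scheduler.timesteps
print(len(timesteps), timesteps[:3])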