rahul7star committed on
Commit
6228c8b
·
verified ·
1 Parent(s): 02db00d

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +151 -66
app_quant_latent.py CHANGED
@@ -8,6 +8,7 @@ import transformers
8
  import psutil
9
  import os
10
  import time
 
11
 
12
  from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
13
  from diffusers import ZImagePipeline, AutoModel
@@ -446,53 +447,63 @@ try:
446
  torch_dtype=torch_dtype,
447
 
448
  )
449
- pipe.transformer.set_attention_backend("_flash_3")
450
- # pipe.load_lora_weights("bdsqlsz/qinglong_DetailedEyes_Z-Image", weight_name="qinglong_detailedeye_z-imageV2(comfy).safetensors", adapter_name="lora")
451
- pipe.load_lora_weights("rahul7star/ZImageLora",
452
- weight_name="NSFW/doggystyle_pov.safetensors", adapter_name="lora")
453
-
454
- pipe.set_adapters(["lora",], adapter_weights=[1.])
455
- pipe.fuse_lora(adapter_names=["lora"], lora_scale=0.75)
 
 
 
 
 
 
 
 
 
 
456
  debug_pipeline(pipe)
457
- # pipe.unload_lora_weights()
458
  pipe.to("cuda")
459
  log("โœ… Pipeline built successfully.")
460
- LOGS.append(log)
461
  except Exception as e:
462
  log(f"โŒ Pipeline build failed: {e}")
 
463
  pipe = None
464
 
465
  log_system_stats("AFTER PIPELINE BUILD")
466
 
467
 
468
  # -----------------------------
469
- # Monkey-patch prepare_latents
470
- # -----------------------------
471
  # -----------------------------
472
- # Monkey-patch prepare_latents
473
- # -----------------------------
474
- if pipe is not None:
475
  original_prepare_latents = pipe.prepare_latents
476
 
477
  def logged_prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
478
- result_latents = original_prepare_latents(
479
- batch_size, num_channels_latents, height, width, dtype, device, generator, latents
480
- )
481
- log_msg = f"๐Ÿ”น prepare_latents called | shape={result_latents.shape}, dtype={result_latents.dtype}, device={result_latents.device}"
482
- if hasattr(self, "_latents_log"):
483
- self._latents_log.append(log_msg)
484
- else:
485
- self._latents_log = [log_msg]
486
- return result_latents
487
-
488
- pipe.prepare_latents = logged_prepare_latents.__get__(pipe)
489
- else:
490
- log("โŒ WARNING: Pipe not initialized; skipping prepare_latents patch")
491
-
492
-
493
- # Apply patch
494
- pipe.prepare_latents = logged_prepare_latents.__get__(pipe)
495
 
 
 
 
 
 
 
 
 
496
 
497
 
498
  from PIL import Image
@@ -520,7 +531,7 @@ def safe_get_latents(pipe, height, width, generator, device, LOGS):
520
 
521
 
522
  # --------------------------
523
- # Main generation function
524
  # --------------------------
525
  @spaces.GPU
526
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
@@ -584,6 +595,7 @@ def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
584
  except Exception as e:
585
  LOGS.append(f"โŒ Total failure: {e}")
586
  return placeholder, [placeholder], LOGS
 
587
  @spaces.GPU
588
  def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
589
  """
@@ -676,35 +688,35 @@ def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0
676
  # UI
677
  # ============================================================
678
 
679
- # with gr.Blocks(title="Z-Image- experiment - dont run")as demo:
680
- # gr.Markdown("# **๐Ÿš€ do not run Z-Image-Turbo โ€” Final Image & Latents**")
681
-
682
-
683
- # with gr.Row():
684
- # with gr.Column(scale=1):
685
- # prompt = gr.Textbox(label="Prompt", value="boat in Ocean")
686
- # height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
687
- # width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
688
- # steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
689
- # seed = gr.Number(value=42, label="Seed")
690
- # run_btn = gr.Button("Generate Image")
691
-
692
- # with gr.Column(scale=1):
693
- # final_image = gr.Image(label="Final Image")
694
- # latent_gallery = gr.Gallery(
695
- # label="Latent Steps",
696
- # columns=4,
697
- # height=256,
698
- # preview=True
699
- # )
700
-
701
- # logs_box = gr.Textbox(label="Logs", lines=15)
 
 
 
 
 
702
 
703
- # run_btn.click(
704
- # generate_image,
705
- # inputs=[prompt, height, width, steps, seed],
706
- # outputs=[final_image, latent_gallery, logs_box]
707
- # )
708
 
709
  with gr.Blocks(title="Z-Image-Turbo") as demo:
710
  with gr.Tabs():
@@ -727,6 +739,84 @@ with gr.Blocks(title="Z-Image-Turbo") as demo:
727
  with gr.TabItem("Logs"):
728
  logs_box = gr.Textbox(label="All Logs", lines=25)
729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
  # Wire the button AFTER all components exist
731
  run_btn.click(
732
  generate_image,
@@ -734,9 +824,4 @@ with gr.Blocks(title="Z-Image-Turbo") as demo:
734
  outputs=[final_image, latent_gallery, logs_box]
735
  )
736
 
737
-
738
-
739
-
740
-
741
-
742
- demo.launch()
 
8
  import psutil
9
  import os
10
  import time
11
+ import traceback
12
 
13
  from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
14
  from diffusers import ZImagePipeline, AutoModel
 
447
  torch_dtype=torch_dtype,
448
 
449
  )
450
+ # If transformer supports setting backend, prefer flash-3
451
+ try:
452
+ if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "set_attention_backend"):
453
+ pipe.transformer.set_attention_backend("_flash_3")
454
+ log("โœ… transformer.set_attention_backend('_flash_3') called")
455
+ except Exception as _e:
456
+ log(f"โš ๏ธ set_attention_backend failed: {_e}")
457
+
458
+ # default LoRA load (keeps your existing behaviour)
459
+ try:
460
+ pipe.load_lora_weights("rahul7star/ZImageLora",
461
+ weight_name="NSFW/doggystyle_pov.safetensors", adapter_name="lora")
462
+ pipe.set_adapters(["lora",], adapter_weights=[1.])
463
+ pipe.fuse_lora(adapter_names=["lora"], lora_scale=0.75)
464
+ except Exception as _e:
465
+ log(f"โš ๏ธ Default LoRA load failed: {_e}")
466
+
467
  debug_pipeline(pipe)
468
+ # pipe.unload_lora_weights()
469
  pipe.to("cuda")
470
  log("โœ… Pipeline built successfully.")
471
+ LOGS += log("Pipeline build completed.") + "\n"
472
  except Exception as e:
473
  log(f"โŒ Pipeline build failed: {e}")
474
+ log(traceback.format_exc())
475
  pipe = None
476
 
477
  log_system_stats("AFTER PIPELINE BUILD")
478
 
479
 
480
  # -----------------------------
481
+ # Monkey-patch prepare_latents (safe)
 
482
  # -----------------------------
483
+ if pipe is not None and hasattr(pipe, "prepare_latents"):
 
 
484
  original_prepare_latents = pipe.prepare_latents
485
 
486
  def logged_prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
487
+ try:
488
+ result_latents = original_prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator, latents)
489
+ log_msg = f"๐Ÿ”น prepare_latents called | shape={result_latents.shape}, dtype={result_latents.dtype}, device={result_latents.device}"
490
+ if hasattr(self, "_latents_log"):
491
+ self._latents_log.append(log_msg)
492
+ else:
493
+ self._latents_log = [log_msg]
494
+ return result_latents
495
+ except Exception as e:
496
+ log(f"โš ๏ธ prepare_latents wrapper failed: {e}")
497
+ raise
 
 
 
 
 
 
498
 
499
+ # apply patch safely
500
+ try:
501
+ pipe.prepare_latents = logged_prepare_latents.__get__(pipe)
502
+ log("โœ… prepare_latents monkey-patched")
503
+ except Exception as e:
504
+ log(f"โš ๏ธ Failed to attach prepare_latents patch: {e}")
505
+ else:
506
+ log("โŒ WARNING: Pipe not initialized or prepare_latents missing; skipping prepare_latents patch")
507
 
508
 
509
  from PIL import Image
 
531
 
532
 
533
  # --------------------------
534
+ # Main generation function (kept exactly as your logic)
535
  # --------------------------
536
  @spaces.GPU
537
  def generate_image(prompt, height, width, steps, seed, guidance_scale=0.0):
 
595
  except Exception as e:
596
  LOGS.append(f"โŒ Total failure: {e}")
597
  return placeholder, [placeholder], LOGS
598
+
599
  @spaces.GPU
600
  def generate_image_backup(prompt, height, width, steps, seed, guidance_scale=0.0, return_latents=False):
601
  """
 
688
  # UI
689
  # ============================================================
690
 
691
+ # Utility: scan local HF cache for safetensors in a repo folder name
692
+ def list_loras_from_repo(repo_id):
693
+ """
694
+ Attempts to find safetensors inside HF cache directory for repo_id.
695
+ This only scans local cache; it does NOT download anything.
696
+ """
697
+ if not repo_id:
698
+ return []
699
+ # Map a repo id to local cache folder name heuristic (works for many cases)
700
+ safe_list = []
701
+ # Common Hugging Face cache root:
702
+ hf_cache = os.path.expanduser("~/.cache/huggingface/hub")
703
+ # Also check /home/user/.cache/huggingface/hub (Spaces environments)
704
+ alt_cache = "/home/user/.cache/huggingface/hub"
705
+
706
+ candidates = [hf_cache, alt_cache]
707
+ needle = repo_id.replace("/", "_")
708
+ for root_cache in candidates:
709
+ if not os.path.exists(root_cache):
710
+ continue
711
+ for root, dirs, files in os.walk(root_cache):
712
+ if needle in root:
713
+ for f in files:
714
+ if f.endswith(".safetensors"):
715
+ safe_list.append(os.path.join(root, f))
716
+ # de-duplicate and sort
717
+ safe_list = sorted(list(dict.fromkeys(safe_list)))
718
+ return safe_list
719
 
 
 
 
 
 
720
 
721
  with gr.Blocks(title="Z-Image-Turbo") as demo:
722
  with gr.Tabs():
 
739
  with gr.TabItem("Logs"):
740
  logs_box = gr.Textbox(label="All Logs", lines=25)
741
 
742
+ # New UI: LoRA repo textbox, dropdown, refresh & rebuild
743
+ with gr.Row():
744
+ lora_repo = gr.Textbox(label="LoRA Repo (HF id)", value="rahul7star/ZImageLora", placeholder="e.g. rahul7star/ZImageLora")
745
+ lora_dropdown = gr.Dropdown(choices=[], label="LoRA files (from local cache)")
746
+ refresh_lora_btn = gr.Button("Refresh LoRA List")
747
+ rebuild_pipe_btn = gr.Button("Rebuild pipeline (use selected LoRA)")
748
+
749
+ # Refresh callback: repopulate dropdown from repo text
750
+ def refresh_lora_list(repo_name):
751
+ try:
752
+ files = list_loras_from_repo(repo_name)
753
+ if not files:
754
+ return gr.update(choices=[], value=None)
755
+ return gr.update(choices=files, value=files[0])
756
+ except Exception as e:
757
+ log(f"โš ๏ธ refresh_lora_list failed: {e}")
758
+ return gr.update(choices=[], value=None)
759
+
760
+ refresh_lora_btn.click(refresh_lora_list, inputs=[lora_repo], outputs=[lora_dropdown])
761
+
762
+ # Rebuild callback: build pipeline with selected lora file path (if any)
763
+ def rebuild_pipeline_with_lora(lora_path, repo_name):
764
+ global pipe, LOGS
765
+ try:
766
+ log(f"๐Ÿ”„ Rebuilding pipeline using LoRA repo={repo_name} file={lora_path}")
767
+ # call existing logic to rebuild: attempt to create new pipeline then load lora file
768
+ pipe = ZImagePipeline.from_pretrained(
769
+ model_id,
770
+ transformer=transformer,
771
+ text_encoder=text_encoder,
772
+ torch_dtype=torch_dtype,
773
+ )
774
+ # try set backend
775
+ try:
776
+ if hasattr(pipe, "transformer") and hasattr(pipe.transformer, "set_attention_backend"):
777
+ pipe.transformer.set_attention_backend("_flash_3")
778
+ except Exception as _e:
779
+ log(f"โš ๏ธ set_attention_backend failed during rebuild: {_e}")
780
+
781
+ # load selected lora if provided
782
+ if lora_path:
783
+ try:
784
+ # repo_name must be HF repo id where load_lora_weights expects it; if user provided repo id use that
785
+ pipe.load_lora_weights(repo_name or "rahul7star/ZImageLora",
786
+ weight_name=os.path.basename(lora_path),
787
+ adapter_name="lora")
788
+ pipe.set_adapters(["lora"], adapter_weights=[1.])
789
+ pipe.fuse_lora(adapter_names=["lora"], lora_scale=0.75)
790
+ except Exception as _e:
791
+ log(f"โš ๏ธ Failed to load selected LoRA during rebuild: {_e}")
792
+
793
+ # finalize
794
+ debug_pipeline(pipe)
795
+ pipe.to("cuda")
796
+ # re-attach monkey patch safely
797
+ if pipe is not None and hasattr(pipe, "prepare_latents"):
798
+ try:
799
+ original_prepare = pipe.prepare_latents
800
+ def logged_prepare(self, *args, **kwargs):
801
+ lat = original_prepare(*args, **kwargs)
802
+ msg = f"๐Ÿ”น prepare_latents called | shape={lat.shape}, dtype={lat.dtype}"
803
+ if hasattr(self, "_latents_log"):
804
+ self._latents_log.append(msg)
805
+ else:
806
+ self._latents_log = [msg]
807
+ return lat
808
+ pipe.prepare_latents = logged_prepare.__get__(pipe)
809
+ log("โœ… Re-applied prepare_latents monkey patch after rebuild")
810
+ except Exception as _e:
811
+ log(f"โš ๏ธ Could not re-apply prepare_latents patch: {_e}")
812
+ return "\n".join([LOGS, "Rebuild complete."])
813
+ except Exception as e:
814
+ log(f"โŒ Rebuild pipeline failed: {e}")
815
+ log(traceback.format_exc())
816
+ return "\n".join([LOGS, f"Rebuild failed: {e}"])
817
+
818
+ rebuild_pipe_btn.click(rebuild_pipeline_with_lora, inputs=[lora_dropdown, lora_repo], outputs=[logs_box])
819
+
820
  # Wire the button AFTER all components exist
821
  run_btn.click(
822
  generate_image,
 
824
  outputs=[final_image, latent_gallery, logs_box]
825
  )
826
 
827
+ demo.launch()