thankfulcarp committed
Commit ae27ee5 · 1 Parent(s): ed32595

Changed to loading LLM at start again.

Files changed (1):
  1. app.py (+32, -30)
app.py CHANGED

@@ -12,6 +12,7 @@ import json
 import random
 import tempfile
 import traceback
+from functools import partial

 import gradio as gr
 import numpy as np
@@ -219,7 +220,7 @@ def _manage_lora_state(pipe, selected_lora: str, lora_weight: float) -> bool:

 def load_pipelines():
     """Loads and configures the T2V and LLM pipelines."""
-    t2v_pipe = None
+    t2v_pipe, enhancer_pipe = None, None

     print("\n🚀 Loading T2V pipeline with base LoRA...")
     try:
@@ -248,40 +249,38 @@ def load_pipelines():
         traceback.print_exc()
         t2v_pipe = None

-    # The enhancer pipeline is now loaded on-demand inside its decorated function.
-    return t2v_pipe
-
-
-# --- 5. Core Generation & UI Logic ---
-
-ENHANCER_PIPE_CACHE = None  # Global cache for the LLM pipeline
+    print("\n🤖 Loading LLM for Prompt Enhancement...")
+    try:
+        # In a ZeroGPU environment, we must load models on the CPU at startup.
+        # The model will be moved to the GPU inside the decorated function.
+        enhancer_pipe = pipeline("text-generation", model=ENHANCER_MODEL_ID, torch_dtype=torch.bfloat16, device="cpu")
+        print("✅ LLM Prompt Enhancer loaded successfully (on CPU).")
+    except Exception as e:
+        print(f"⚠️ WARNING: Could not load the LLM prompt enhancer. The feature will be disabled. Error: {e}")
+        enhancer_pipe = None
+
+    return t2v_pipe, enhancer_pipe
+
+
+# --- 5. Core Generation & UI Logic ---

 @spaces.GPU()
-def enhance_prompt_with_llm(prompt: str):
+def enhance_prompt_with_llm(prompt: str, enhancer_pipeline):
     """
-    Uses a cached LLM to enhance a given prompt.
-    In a ZeroGPU environment, the model is loaded on the first call.
+    Uses the loaded LLM to enhance a given prompt.
     """
-    global ENHANCER_PIPE_CACHE
-    if ENHANCER_PIPE_CACHE is None:
-        print("\n🤖 Loading LLM for Prompt Enhancement (first run)...")
-        try:
-            # This happens inside the GPU session, so device_map="auto" is correct.
-            ENHANCER_PIPE_CACHE = pipeline(
-                "text-generation",
-                model=ENHANCER_MODEL_ID,
-                torch_dtype=torch.bfloat16,
-                device_map="auto"
-            )
-            print("✅ LLM Prompt Enhancer loaded successfully.")
-        except Exception as e:
-            print(f"❌ Error loading LLM enhancer: {e}")
-            raise gr.Error("Could not load the AI prompt enhancer. Please check the logs.")
+    if enhancer_pipeline is None:
+        print("LLM enhancer not available, returning original prompt.")
+        gr.Warning("LLM enhancer is not available.")
+        return prompt
+
+    # Move the entire pipeline to the GPU. This handles the model, tokenizer, and device settings.
+    enhancer_pipeline.to("cuda")

     messages = [{"role": "system", "content": T2V_CINEMATIC_PROMPT_SYSTEM}, {"role": "user", "content": prompt}]
     print(f"Enhancing prompt: '{prompt}'")
     try:
-        outputs = ENHANCER_PIPE_CACHE(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.95)
+        outputs = enhancer_pipeline(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.95)
         final_answer = outputs[0]['generated_text'][-1]['content']
         print(f"Enhanced prompt: '{final_answer.strip()}'")
         return final_answer.strip()
@@ -370,7 +369,7 @@ def generate_t2v_video(

 # --- 6. Gradio UI Layout ---

-def build_ui(t2v_pipe, available_loras):
+def build_ui(t2v_pipe, enhancer_pipe, available_loras):
     """Creates and configures the Gradio UI."""
     with gr.Blocks(theme=gr.themes.Soft(), css=".main-container { max-width: 1080px; margin: auto; }") as demo:
         gr.Markdown("# ✨ Wan 2.1 Text-to-Video Suite with Dynamic LoRAs")
@@ -389,8 +388,8 @@ def build_ui(t2v_pipe, available_loras):
             )
             t2v_enhance_btn = gr.Button(
                 "🤖 Enhance Prompt with AI",
-                # This is now always interactive. Errors are handled inside the click handler.
-                interactive=True
+                # The button is disabled if the enhancer pipeline failed to load
+                interactive=enhancer_pipe is not None
             )

         with gr.Group():
@@ -426,8 +425,11 @@ def build_ui(t2v_pipe, available_loras):
         t2v_download = gr.File(label="📥 Download Video", visible=False)

         if t2v_pipe is not None:
+            # Create a partial function that has the enhancer_pipe "baked in".
+            # This avoids the need to pass the complex object through Gradio's state.
+            enhance_fn = partial(enhance_prompt_with_llm, enhancer_pipeline=enhancer_pipe)
             t2v_enhance_btn.click(
-                fn=enhance_prompt_with_llm,
+                fn=enhance_fn,
                 inputs=[t2v_prompt],
                 outputs=[t2v_prompt]
             )
@@ -450,12 +452,12 @@ def build_ui(t2v_pipe, available_loras):

 # --- 7. Main Execution ---
 if __name__ == "__main__":
-    t2v_pipe = load_pipelines()
+    t2v_pipe, enhancer_pipe = load_pipelines()

     # Fetch LoRAs only if the main pipeline loaded successfully
     available_loras = []
     if t2v_pipe:
         available_loras = get_available_presets(DYNAMIC_LORA_REPO_ID, DYNAMIC_LORA_SUBFOLDER)

-    app_ui = build_ui(t2v_pipe, available_loras)
+    app_ui = build_ui(t2v_pipe, enhancer_pipe, available_loras)
     app_ui.queue(max_size=10).launch()
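
For reference, a minimal sketch of the pattern this commit returns to: load the model on the CPU at startup, move it to the GPU only inside the @spaces.GPU()-decorated handler, and bind the loaded objects into the Gradio callback with functools.partial. The model id and UI below are illustrative placeholders, not this repo's actual ENHANCER_MODEL_ID or layout.

import spaces
import torch
import gradio as gr
from functools import partial
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model, for illustration only

# Startup code runs without a GPU in a ZeroGPU Space, so everything loads on the CPU here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)

@spaces.GPU()
def enhance(prompt: str, model, tokenizer) -> str:
    # A GPU is attached only while this decorated function runs, so the move happens here.
    model.to("cuda")
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    out = model.generate(**inputs, max_new_tokens=64, do_sample=True, temperature=0.7)
    return tokenizer.decode(out[0], skip_special_tokens=True)

with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    btn = gr.Button("Enhance")
    # Bake the heavyweight objects into the handler instead of routing them through
    # Gradio components, mirroring the partial() wiring used in app.py.
    btn.click(fn=partial(enhance, model=model, tokenizer=tokenizer), inputs=[box], outputs=[box])

demo.queue().launch()

Compared with the previous on-demand loading, this arrangement fails fast at startup and lets the UI disable the enhance button when the load did not succeed, which is what the interactive=enhancer_pipe is not None change does.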