thankfulcarp committed
Commit 1d1a8c3 · 1 Parent(s): e4ae032

Enhancer fix

Files changed (1):
  app.py  +27 -36
app.py CHANGED
@@ -12,7 +12,6 @@ import json
 import random
 import tempfile
 import traceback
-from functools import partial

 import gradio as gr
 import numpy as np
@@ -163,7 +162,6 @@ def handle_lora_selection_change(preset_name: str, current_prompt: str):
 def load_pipelines():
     """Loads and configures the T2V and LLM pipelines."""
     t2v_pipe = None
-    enhancer_pipe = None

     print("\n🚀 Loading T2V pipeline with base LoRA...")
     try:
@@ -192,35 +190,35 @@ def load_pipelines():
         traceback.print_exc()
         t2v_pipe = None

-    print("\n🤖 Loading LLM for Prompt Enhancement...")
-    try:
-        # In a ZeroGPU environment, we must load models on the CPU at startup.
-        # The model will be moved to the GPU inside the decorated function.
-        enhancer_pipe = pipeline("text-generation", model=ENHANCER_MODEL_ID, torch_dtype=torch.bfloat16, device="cpu")
-        print("✅ LLM Prompt Enhancer loaded successfully (on CPU).")
-    except Exception as e:
-        print(f"⚠️ WARNING: Could not load the LLM prompt enhancer. The feature will be disabled. Error: {e}")
-        enhancer_pipe = None
-
-    return t2v_pipe, enhancer_pipe
+    # The enhancer pipeline is now loaded on-demand inside its decorated function.
+    return t2v_pipe


 # --- 5. Core Generation & UI Logic ---
-@spaces.GPU()
-def enhance_prompt_with_llm(prompt: str, enhancer_pipeline):
-    """Uses the loaded LLM to enhance a given prompt."""
-    if enhancer_pipeline is None:
-        print("LLM enhancer not available, returning original prompt.")
-        gr.Warning("LLM enhancer is not available.")
-        return prompt

-    # Move the model to the GPU now that we are in a decorated function
-    enhancer_pipeline.model.to("cuda")
+ENHANCER_PIPE_CACHE = None  # Global cache for the LLM pipeline
+
+@spaces.GPU()
+def enhance_prompt_with_llm(prompt: str):
+    """
+    Uses a cached LLM to enhance a given prompt.
+    In a ZeroGPU environment, the model is loaded on the first call.
+    """
+    global ENHANCER_PIPE_CACHE
+    if ENHANCER_PIPE_CACHE is None:
+        print("\n🤖 Loading LLM for Prompt Enhancement (first run)...")
+        try:
+            # This happens inside the GPU session, so device_map="auto" is correct.
+            ENHANCER_PIPE_CACHE = pipeline("text-generation", model=ENHANCER_MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto")
+            print("✅ LLM Prompt Enhancer loaded successfully.")
+        except Exception as e:
+            print(f"❌ Error loading LLM enhancer: {e}")
+            raise gr.Error("Could not load the AI prompt enhancer. Please check the logs.")

     messages = [{"role": "system", "content": T2V_CINEMATIC_PROMPT_SYSTEM}, {"role": "user", "content": prompt}]
     print(f"Enhancing prompt: '{prompt}'")
     try:
-        outputs = enhancer_pipeline(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.95)
+        outputs = ENHANCER_PIPE_CACHE(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.95)
         final_answer = outputs[0]['generated_text'][-1]['content']
         print(f"Enhanced prompt: '{final_answer.strip()}'")
         return final_answer.strip()
@@ -340,13 +338,9 @@ def generate_t2v_video(

 # --- 6. Gradio UI Layout ---

-def build_ui(t2v_pipe, enhancer_pipe, available_loras):
+def build_ui(t2v_pipe, available_loras):
     """Creates and configures the Gradio UI."""
     with gr.Blocks(theme=gr.themes.Soft(), css=".main-container { max-width: 1080px; margin: auto; }") as demo:
-        # We don't use gr.State for the pipeline object because it's not serializable
-        # and causes a deepcopy error with tensors on multiple devices (CPU/GPU).
-        # Instead, we use functools.partial to bind the pipeline to its handler function.
-
         gr.Markdown("# ✨ Wan 2.1 Text-to-Video Suite with Dynamic LoRAs")
         gr.Markdown("Generate videos from text, enhanced by the base `FusionX` LoRA and your choice of dynamic style LoRA.")

@@ -363,8 +357,8 @@ def build_ui(t2v_pipe, enhancer_pipe, available_loras):
         )
         t2v_enhance_btn = gr.Button(
             "🤖 Enhance Prompt with AI",
-            # The button is disabled if the enhancer pipeline failed to load
-            interactive=enhancer_pipe is not None
+            # This is now always interactive. Errors are handled inside the click handler.
+            interactive=True
         )

         with gr.Group():
@@ -400,11 +394,8 @@ def build_ui(t2v_pipe, enhancer_pipe, available_loras):
         t2v_download = gr.File(label="📥 Download Video", visible=False)

         if t2v_pipe is not None:
-            # Create a partial function that has the enhancer_pipe "baked in".
-            # This avoids the need to pass the complex object through Gradio's state.
-            enhance_fn = partial(enhance_prompt_with_llm, enhancer_pipeline=enhancer_pipe)
             t2v_enhance_btn.click(
-                fn=enhance_fn,
+                fn=enhance_prompt_with_llm,
                 inputs=[t2v_prompt],
                 outputs=[t2v_prompt]
             )
@@ -427,12 +418,12 @@ def build_ui(t2v_pipe, enhancer_pipe, available_loras):

 # --- 7. Main Execution ---
 if __name__ == "__main__":
-    t2v_pipe, enhancer_pipe = load_pipelines()
+    t2v_pipe = load_pipelines()

     # Fetch LoRAs only if the main pipeline loaded successfully
     available_loras = []
     if t2v_pipe:
         available_loras = get_available_presets(DYNAMIC_LORA_REPO_ID, DYNAMIC_LORA_SUBFOLDER)

-    app_ui = build_ui(t2v_pipe, enhancer_pipe, available_loras)
+    app_ui = build_ui(t2v_pipe, available_loras)
     app_ui.queue(max_size=10).launch()
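
The pattern this commit adopts, loading the transformers pipeline lazily inside the ZeroGPU-decorated handler and caching it in a module-level global, can be sketched in isolation as follows. This is a minimal sketch, not the Space's actual code: MODEL_ID and SYSTEM_PROMPT are hypothetical placeholders, and only spaces.GPU, transformers.pipeline, and gr.Error are taken from the commit itself.

# Standalone sketch of the on-demand, cached enhancer pattern used in this commit.
import gradio as gr
import spaces
import torch
from transformers import pipeline

MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"  # hypothetical model id, for illustration only
SYSTEM_PROMPT = "Rewrite the user's prompt as a detailed cinematic video description."

_ENHANCER_CACHE = None  # loaded on the first call, reused afterwards


@spaces.GPU()
def enhance(prompt: str) -> str:
    """Load the LLM on first use (inside the GPU session) and enhance the prompt."""
    global _ENHANCER_CACHE
    if _ENHANCER_CACHE is None:
        try:
            # device_map="auto" is safe here because the GPU is attached
            # for the duration of this decorated call.
            _ENHANCER_CACHE = pipeline(
                "text-generation",
                model=MODEL_ID,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
        except Exception as e:
            # Surface the failure in the UI instead of disabling the button up front.
            raise gr.Error(f"Could not load the prompt enhancer: {e}")

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    outputs = _ENHANCER_CACHE(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.95)
    # The chat-style pipeline returns the full conversation; the last message is the reply.
    return outputs[0]["generated_text"][-1]["content"].strip()


with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    gr.Button("Enhance").click(fn=enhance, inputs=[box], outputs=[box])

if __name__ == "__main__":
    demo.launch()

Because the handler is a plain module-level function with no pipeline argument, it can be wired directly to the click event, which is why the functools.partial indirection and the enhancer_pipe plumbing are removed in the diff above.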