Spaces: Runtime error
Commit 1d1a8c3
Parent(s): e4ae032
Enhancer fix
app.py CHANGED

@@ -12,7 +12,6 @@ import json
 import random
 import tempfile
 import traceback
-from functools import partial
 
 import gradio as gr
 import numpy as np
@@ -163,7 +162,6 @@ def handle_lora_selection_change(preset_name: str, current_prompt: str):
 def load_pipelines():
     """Loads and configures the T2V and LLM pipelines."""
     t2v_pipe = None
-    enhancer_pipe = None
 
     print("\n🚀 Loading T2V pipeline with base LoRA...")
     try:
@@ -192,35 +190,35 @@ def load_pipelines():
         traceback.print_exc()
         t2v_pipe = None
 
-    print("\n🤖 Loading LLM for Prompt Enhancement...")
-    try:
-        # In a ZeroGPU environment, we must load models on the CPU at startup.
-        # The model will be moved to the GPU inside the decorated function.
-        enhancer_pipe = pipeline("text-generation", model=ENHANCER_MODEL_ID, torch_dtype=torch.bfloat16, device="cpu")
-        print("✅ LLM Prompt Enhancer loaded successfully (on CPU).")
-    except Exception as e:
-        print(f"⚠️ WARNING: Could not load the LLM prompt enhancer. The feature will be disabled. Error: {e}")
-        enhancer_pipe = None
-
-    return t2v_pipe, enhancer_pipe
+    # The enhancer pipeline is now loaded on-demand inside its decorated function.
+    return t2v_pipe
 
 
 # --- 5. Core Generation & UI Logic ---
-@spaces.GPU()
-def enhance_prompt_with_llm(prompt: str, enhancer_pipeline):
-    """Uses the loaded LLM to enhance a given prompt."""
-    if enhancer_pipeline is None:
-        print("LLM enhancer not available, returning original prompt.")
-        gr.Warning("LLM enhancer is not available.")
-        return prompt
 
-
-
+ENHANCER_PIPE_CACHE = None  # Global cache for the LLM pipeline
+
+@spaces.GPU()
+def enhance_prompt_with_llm(prompt: str):
+    """
+    Uses a cached LLM to enhance a given prompt.
+    In a ZeroGPU environment, the model is loaded on the first call.
+    """
+    global ENHANCER_PIPE_CACHE
+    if ENHANCER_PIPE_CACHE is None:
+        print("\n🤖 Loading LLM for Prompt Enhancement (first run)...")
+        try:
+            # This happens inside the GPU session, so device_map="auto" is correct.
+            ENHANCER_PIPE_CACHE = pipeline("text-generation", model=ENHANCER_MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto")
+            print("✅ LLM Prompt Enhancer loaded successfully.")
+        except Exception as e:
+            print(f"❌ Error loading LLM enhancer: {e}")
+            raise gr.Error("Could not load the AI prompt enhancer. Please check the logs.")
 
     messages = [{"role": "system", "content": T2V_CINEMATIC_PROMPT_SYSTEM}, {"role": "user", "content": prompt}]
     print(f"Enhancing prompt: '{prompt}'")
     try:
-        outputs = enhancer_pipeline(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.95)
+        outputs = ENHANCER_PIPE_CACHE(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.95)
         final_answer = outputs[0]['generated_text'][-1]['content']
         print(f"Enhanced prompt: '{final_answer.strip()}'")
         return final_answer.strip()
@@ -340,13 +338,9 @@ def generate_t2v_video(
 
 # --- 6. Gradio UI Layout ---
 
-def build_ui(t2v_pipe, enhancer_pipe, available_loras):
+def build_ui(t2v_pipe, available_loras):
     """Creates and configures the Gradio UI."""
     with gr.Blocks(theme=gr.themes.Soft(), css=".main-container { max-width: 1080px; margin: auto; }") as demo:
-        # We don't use gr.State for the pipeline object because it's not serializable
-        # and causes a deepcopy error with tensors on multiple devices (CPU/GPU).
-        # Instead, we use functools.partial to bind the pipeline to its handler function.
-
         gr.Markdown("# ✨ Wan 2.1 Text-to-Video Suite with Dynamic LoRAs")
         gr.Markdown("Generate videos from text, enhanced by the base `FusionX` LoRA and your choice of dynamic style LoRA.")
 
@@ -363,8 +357,8 @@ def build_ui(t2v_pipe, enhancer_pipe, available_loras):
            )
            t2v_enhance_btn = gr.Button(
                "🤖 Enhance Prompt with AI",
-                #
-                interactive=enhancer_pipe is not None
+                # This is now always interactive. Errors are handled inside the click handler.
+                interactive=True
            )
 
            with gr.Group():
@@ -400,11 +394,8 @@ def build_ui(t2v_pipe, enhancer_pipe, available_loras):
        t2v_download = gr.File(label="📥 Download Video", visible=False)
 
        if t2v_pipe is not None:
-            # Create a partial function that has the enhancer_pipe "baked in".
-            # This avoids the need to pass the complex object through Gradio's state.
-            enhance_fn = partial(enhance_prompt_with_llm, enhancer_pipeline=enhancer_pipe)
            t2v_enhance_btn.click(
-                fn=enhance_fn,
+                fn=enhance_prompt_with_llm,
                inputs=[t2v_prompt],
                outputs=[t2v_prompt]
            )
@@ -427,12 +418,12 @@
 
 # --- 7. Main Execution ---
 if __name__ == "__main__":
-    t2v_pipe, enhancer_pipe = load_pipelines()
+    t2v_pipe = load_pipelines()
 
     # Fetch LoRAs only if the main pipeline loaded successfully
     available_loras = []
     if t2v_pipe:
        available_loras = get_available_presets(DYNAMIC_LORA_REPO_ID, DYNAMIC_LORA_SUBFOLDER)
 
-    app_ui = build_ui(t2v_pipe, enhancer_pipe, available_loras)
+    app_ui = build_ui(t2v_pipe, available_loras)
     app_ui.queue(max_size=10).launch()
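Note: the fix above boils down to a lazy-load-and-cache pattern: keep a module-level handle, populate it on the first call that runs inside the GPU session, and reuse it on every later call. Below is a minimal self-contained sketch of that pattern; the model ID and system prompt are placeholders, not the Space's actual ENHANCER_MODEL_ID or T2V_CINEMATIC_PROMPT_SYSTEM.

import spaces
import torch
from transformers import pipeline

MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"  # placeholder checkpoint, not the Space's real one
SYSTEM_PROMPT = "Rewrite the user's prompt as a cinematic video description."  # placeholder
_PIPE_CACHE = None  # module-level cache shared across calls

@spaces.GPU()
def enhance(prompt: str) -> str:
    """Enhance a prompt, loading the LLM on the first call only."""
    global _PIPE_CACHE
    if _PIPE_CACHE is None:
        # The first call pays the load cost inside the GPU session;
        # subsequent calls reuse the cached pipeline.
        _PIPE_CACHE = pipeline(
            "text-generation",
            model=MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    outputs = _PIPE_CACHE(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.95)
    # With chat-style input, the pipeline appends the assistant turn to the message list.
    return outputs[0]["generated_text"][-1]["content"].strip()

This also explains why the functools.partial binding was dropped: once the pipeline lives in a module-level cache instead of being bound into the click handler, Gradio never has to copy or serialize the model object.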