# Test2 / app_seedvr.py
# Commit 1bbb7db "Update app_seedvr.py" by EuuIia (verified) — 7.55 kB
# app_seedvr.py
import os
import sys
from pathlib import Path
from typing import Optional
import gradio as gr
import cv2
# --- SERVER LOGIC INTEGRATION ---
try:
    from api.seedvr_server import SeedVRServer
except ImportError as import_err:
    # The UI is useless without the backend: report the cause, then abort start-up.
    print(f"FATAL ERROR: Could not import SeedVRServer. Details: {import_err}")
    raise

# --- INITIALIZATION ---
# One backend instance, created at import time; the Gradio callbacks below
# reference this module-level `server` for every request.
server = SeedVRServer()
# --- HELPER FUNCTIONS ---
def _is_video(path: str) -> bool:
"""Checks if a file path corresponds to a video type."""
if not path: return False
import mimetypes
mime, _ = mimetypes.guess_type(path)
return (mime or "").startswith("video")
def _extract_first_frame(video_path: str) -> Optional[str]:
"""Extracts the first frame from a video and saves it as a JPG image."""
if not video_path or not os.path.exists(video_path): return None
try:
vid_cap = cv2.VideoCapture(video_path)
if not vid_cap.isOpened(): return None
success, image = vid_cap.read()
vid_cap.release()
if not success: return None
image_path = Path(video_path).with_suffix(".jpg")
cv2.imwrite(str(image_path), image)
return str(image_path)
except Exception as e:
print(f"Error extracting first frame: {e}")
return None
def on_file_upload(file_obj):
    """Adjust the ``sp_size`` slider after the user uploads a file.

    ``input_media`` is declared with ``type="filepath"``, so the upload may
    arrive as a path string (or ``os.PathLike``); objects exposing the path via
    a ``.name`` attribute are accepted as well.

    Returns:
        The new value/update for ``sp_size_slider``: 4 and editable for videos,
        1 and locked for any other file, plain 1 when nothing was uploaded.
    """
    if file_obj is None:
        return 1

    # Plain paths have no `.name`; only fall back to the attribute for
    # file-like wrappers.
    if isinstance(file_obj, (str, os.PathLike)):
        path = os.fspath(file_obj)
    else:
        path = getattr(file_obj, "name", None)

    if _is_video(path):
        return gr.update(value=4, interactive=True)
    return gr.update(value=1, interactive=False)
# --- CORE INFERENCE FUNCTION ---
def run_inference_ui(
    input_file_path: Optional[str],
    resolution: str,
    sp_size: int,
    fps: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Run SeedVR restoration on the uploaded file, streaming UI updates.

    This is a Gradio generator callback: every ``yield`` is a 5-tuple matching
    the ``outputs`` wired to ``run_button.click`` — (run button, image result,
    video result, download file, log textbox).

    Args:
        input_file_path: Path of the uploaded image or video, or ``None``.
        resolution: Target size as a string; used for both height and width.
        sp_size: Sequence-parallelism size forwarded to the backend.
        fps: Output FPS; ``0``/falsy values forward ``None`` to the backend
            (the UI labels this "use the original FPS").
        progress: Gradio progress tracker, also forwarded to the backend.

    Yields:
        Tuples of Gradio updates for the five output components. Backend
        failures are reported in the log and via a warning toast; the run
        button is re-enabled on every exit path.
    """
    # 1. Initial state: lock the button, hide previous results, show the log.
    yield (
        gr.update(interactive=False, value="Processing... 🚀"),
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
        gr.update(value="Waiting for logs...", visible=True),
    )

    if not input_file_path:
        gr.Warning("Please upload a media file first.")
        yield (
            gr.update(interactive=True, value="Restore Media"),
            None, None, None, gr.update(visible=False),
        )
        return

    log_buffer = ["▶ Starting inference process...\n"]
    yield gr.update(), None, None, None, "".join(log_buffer)

    def progress_callback(step: float, desc: str):
        """Append a progress line to the log and forward it to Gradio's tracker."""
        log_buffer.append(f"⏳ [{int(step*100)}%] {desc}\n")
        progress(step, desc=desc)

    was_input_video = _is_video(input_file_path)
    try:
        # 2. Execute inference (a single blocking backend call).
        progress_callback(0.1, "Calling backend engine...")
        yield gr.update(), None, None, None, "".join(log_buffer)

        video_result_path = server.run_inference_direct(
            file_path=input_file_path,
            seed=42,
            res_h=int(resolution),
            res_w=int(resolution),
            sp_size=int(sp_size),
            fps=float(fps) if fps and fps > 0 else None,
            progress=progress,
        )
        progress_callback(1.0, "Inference complete! Processing final output...")
        yield gr.update(), None, None, None, "".join(log_buffer)

        # 3. The result path is treated as a video for both input kinds; for
        # image inputs its first frame is additionally shown as the image.
        final_video = video_result_path
        if was_input_video:
            final_image = None
            log_buffer.append("✅ Video result is ready.\n")
        else:
            final_image = _extract_first_frame(video_result_path)
            log_buffer.append("✅ Image result extracted from video.\n")

        yield (
            gr.update(interactive=True, value="Restore Media"),
            gr.update(value=final_image, visible=final_image is not None),
            gr.update(value=final_video, visible=final_video is not None),
            gr.update(value=video_result_path, visible=video_result_path is not None),
            "".join(log_buffer),
        )
    except Exception as e:
        error_message = f"❌ Inference failed: {e}"
        # gr.Error is an exception class: merely constructing it (as before)
        # displays nothing. Show a warning toast instead so this generator can
        # still yield the final state that re-enables the button.
        gr.Warning(error_message)
        print(error_message)
        import traceback
        traceback.print_exc()
        yield (
            gr.update(interactive=True, value="Restore Media"),
            None, None, None,
            gr.update(value=f"{''.join(log_buffer)}\n{error_message}", visible=True),
        )
# --- GRADIO UI LAYOUT ---
# Static markdown shown above and below the app (kept verbatim).
_HEADER_MD = """
<div style='text-align: center; margin-bottom: 20px;'>
<h1>📸 SeedVR - Image & Video Restoration 🚀</h1>
<p>High-quality media upscaling powered by SeedVR-3B. Upload your file and see the magic.</p>
</div>
"""

_FOOTER_MD = """
---
*Space and Docker were developed by Carlex.*
*Contact: Email: Carlex22@gmail.com | GitHub: [carlex22](https://github.com/carlex22)*
"""

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: input upload and generation settings.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Upload Media")
            input_media = gr.File(label="Input File (Video or Image)", type="filepath")

            gr.Markdown("### 2. Configure Settings")
            with gr.Accordion("Generation Parameters", open=True):
                resolution_select = gr.Dropdown(
                    label="Resolution (Short Edge)",
                    choices=["480", "560", "720", "960", "1024"],
                    value="480",
                    info="The output height and width will be set to this value.",
                )
                sp_size_slider = gr.Slider(
                    label="Sequence Parallelism (sp_size)",
                    minimum=1,
                    maximum=16,
                    step=1,
                    value=4,
                    info="For multi-GPU videos. This will be set to 1 for images.",
                )
                fps_out = gr.Number(
                    label="Output FPS (for Videos)",
                    value=24,
                    precision=0,
                    info="Set to 0 to use the original FPS.",
                )
            # NOTE(review): Gradio documents `icon` as a URL/path to an image
            # file; confirm an emoji string renders as intended.
            run_button = gr.Button("Restore Media", variant="primary", icon="✨")

        # Right column: streamed log plus the (initially hidden) results.
        with gr.Column(scale=2):
            gr.Markdown("### 3. Results")
            log_window = gr.Textbox(
                label="Inference Log 📝",
                lines=8,
                max_lines=15,
                interactive=False,
                visible=False,
                autoscroll=True,
            )
            output_image = gr.Image(
                label="Image Result",
                show_download_button=True,
                type="filepath",
                visible=False,
            )
            output_video = gr.Video(label="Video Result", visible=False)
            output_download = gr.File(label="Download Full Result (Video)", visible=False)

    gr.Markdown(_FOOTER_MD)

    # Event wiring: the upload tunes sp_size; the button streams inference updates
    # into exactly these five outputs, in this order.
    input_media.upload(fn=on_file_upload, inputs=[input_media], outputs=[sp_size_slider])
    run_button.click(
        fn=run_inference_ui,
        inputs=[input_media, resolution_select, sp_size_slider, fps_out],
        outputs=[run_button, output_image, output_video, output_download, log_window],
    )
if __name__ == "__main__":
    # Bind address and port may be overridden via the environment;
    # defaults are all interfaces on port 7860.
    host = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
    port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name=host, server_port=port, show_error=True)