EuuIia commited on
Commit
1bbb7db
·
verified ·
1 Parent(s): 66bcb74

Update app_seedvr.py

Browse files
Files changed (1) hide show
  1. app_seedvr.py +15 -60
app_seedvr.py CHANGED
@@ -8,18 +8,13 @@ import gradio as gr
8
  import cv2
9
 
10
  # --- SERVER LOGIC INTEGRATION ---
11
- # This section ensures we can import and use the SeedVR engine directly.
12
  try:
13
- # We need the SeedVRServer class which handles the inference logic.
14
  from api.seedvr_server import SeedVRServer
15
  except ImportError as e:
16
  print(f"FATAL ERROR: Could not import SeedVRServer. Details: {e}")
17
- # The application cannot run without the server logic.
18
  raise
19
 
20
  # --- INITIALIZATION ---
21
- # Create a single, persistent instance of the server.
22
- # This clones the repo and downloads models only once at startup.
23
  server = SeedVRServer()
24
 
25
  # --- HELPER FUNCTIONS ---
@@ -40,7 +35,6 @@ def _extract_first_frame(video_path: str) -> Optional[str]:
40
  success, image = vid_cap.read()
41
  vid_cap.release()
42
  if not success: return None
43
-
44
  image_path = Path(video_path).with_suffix(".jpg")
45
  cv2.imwrite(str(image_path), image)
46
  return str(image_path)
@@ -49,18 +43,12 @@ def _extract_first_frame(video_path: str) -> Optional[str]:
49
  return None
50
 
51
  def on_file_upload(file_obj):
52
- """
53
- Callback triggered when a user uploads a file.
54
- It checks if the file is a video and suggests an appropriate `sp_size`.
55
- """
56
  if file_obj is None:
57
- return 1 # Default to 1 if file is cleared
58
-
59
  if _is_video(file_obj.name):
60
- # For videos, suggest a default value suitable for multi-GPU
61
  return gr.update(value=4, interactive=True)
62
  else:
63
- # For images, lock the value to 1
64
  return gr.update(value=1, interactive=False)
65
 
66
  # --- CORE INFERENCE FUNCTION ---
@@ -73,11 +61,10 @@ def run_inference_ui(
73
  progress=gr.Progress(track_tqdm=True)
74
  ):
75
  """
76
- The main callback function for Gradio. This is a generator (`yield`)
77
- to allow for real-time UI updates during the long-running task.
78
  """
79
  # 1. Initial State & Validation
80
- # On start, disable the button, clear previous results, and make the log visible.
81
  yield (
82
  gr.update(interactive=False, value="Processing... 🚀"),
83
  gr.update(value=None, visible=False),
@@ -88,23 +75,21 @@ def run_inference_ui(
88
 
89
  if not input_file_path:
90
  gr.Warning("Please upload a media file first.")
91
- # Re-enable button and hide outputs
92
  yield (
93
  gr.update(interactive=True, value="Restore Media"),
94
  None, None, None, gr.update(visible=False)
95
  )
96
  return
97
 
98
- # Use a simple list to act as a log buffer that can be updated by a callback
99
  log_buffer = ["▶ Starting inference process...\n"]
100
  yield gr.update(), None, None, None, ''.join(log_buffer)
101
 
 
102
  def progress_callback(step: float, desc: str):
103
  """A simple callback to append messages to our log buffer."""
104
- # This function can be passed to the backend if it supports it.
105
- # For now, we'll call it manually from this UI function.
106
  log_buffer.append(f"⏳ [{int(step*100)}%] {desc}\n")
107
- progress.update(amount=step, desc=desc)
 
108
 
109
  was_input_video = _is_video(input_file_path)
110
 
@@ -113,15 +98,14 @@ def run_inference_ui(
113
  progress_callback(0.1, "Calling backend engine...")
114
  yield gr.update(), None, None, None, ''.join(log_buffer)
115
 
116
- # Call the server's direct inference method. This is a blocking call.
117
  video_result_path = server.run_inference_direct(
118
  file_path=input_file_path,
119
- seed=42, # Using a fixed seed as requested
120
  res_h=int(resolution),
121
- res_w=int(resolution), # Set width equal to height
122
  sp_size=int(sp_size),
123
  fps=float(fps) if fps and fps > 0 else None,
124
- progress=progress, # Pass the Gradio progress object
125
  )
126
 
127
  progress_callback(1.0, "Inference complete! Processing final output...")
@@ -132,12 +116,11 @@ def run_inference_ui(
132
  if was_input_video:
133
  final_video = video_result_path
134
  log_buffer.append(f"✅ Video result is ready.\n")
135
- else: # If input was an image
136
  final_image = _extract_first_frame(video_result_path)
137
- final_video = video_result_path # Also provide the 1-frame video
138
  log_buffer.append(f"✅ Image result extracted from video.\n")
139
 
140
- # Final yield to show the results and re-enable the button
141
  yield (
142
  gr.update(interactive=True, value="Restore Media"),
143
  gr.update(value=final_image, visible=final_image is not None),
@@ -153,7 +136,6 @@ def run_inference_ui(
153
  import traceback
154
  traceback.print_exc()
155
 
156
- # Yield an error state and re-enable the button
157
  yield (
158
  gr.update(interactive=True, value="Restore Media"),
159
  None, None, None,
@@ -162,7 +144,6 @@ def run_inference_ui(
162
 
163
 
164
  # --- GRADIO UI LAYOUT ---
165
-
166
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
167
  # Header
168
  gr.Markdown(
@@ -173,13 +154,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
173
  </div>
174
  """
175
  )
176
-
177
  with gr.Row():
178
- # --- Left Column: Inputs & Controls ---
179
  with gr.Column(scale=1):
180
  gr.Markdown("### 1. Upload Media")
181
  input_media = gr.File(label="Input File (Video or Image)", type="filepath")
182
-
183
  gr.Markdown("### 2. Configure Settings")
184
  with gr.Accordion("Generation Parameters", open=True):
185
  resolution_select = gr.Dropdown(
@@ -188,37 +166,22 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
188
  value="480",
189
  info="The output height and width will be set to this value."
190
  )
191
-
192
  sp_size_slider = gr.Slider(
193
  label="Sequence Parallelism (sp_size)",
194
  minimum=1, maximum=16, step=1, value=4,
195
  info="For multi-GPU videos. This will be set to 1 for images."
196
  )
197
-
198
  fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
199
-
200
  run_button = gr.Button("Restore Media", variant="primary", icon="✨")
201
-
202
- # --- Right Column: Outputs ---
203
  with gr.Column(scale=2):
204
  gr.Markdown("### 3. Results")
205
-
206
- # Log window
207
  log_window = gr.Textbox(
208
- label="Inference Log 📝",
209
- lines=8,
210
- max_lines=15,
211
- interactive=False,
212
- visible=False, # Starts hidden
213
- autoscroll=True,
214
  )
215
-
216
- # Output components start hidden and are made visible upon completion
217
  output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
218
  output_video = gr.Video(label="Video Result", visible=False)
219
  output_download = gr.File(label="Download Full Result (Video)", visible=False)
220
-
221
- # --- Footer ---
222
  gr.Markdown(
223
  """
224
  ---
@@ -227,16 +190,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Res
227
  """
228
  )
229
 
230
- # --- Event Handlers ---
231
-
232
- # When a file is uploaded, automatically adjust the sp_size slider
233
- input_media.upload(
234
- fn=on_file_upload,
235
- inputs=[input_media],
236
- outputs=[sp_size_slider]
237
- )
238
 
239
- # When the "Restore Media" button is clicked, run the main inference function
240
  run_button.click(
241
  fn=run_inference_ui,
242
  inputs=[input_media, resolution_select, sp_size_slider, fps_out],
 
8
  import cv2
9
 
10
  # --- SERVER LOGIC INTEGRATION ---
 
11
  try:
 
12
  from api.seedvr_server import SeedVRServer
13
  except ImportError as e:
14
  print(f"FATAL ERROR: Could not import SeedVRServer. Details: {e}")
 
15
  raise
16
 
17
  # --- INITIALIZATION ---
 
 
18
  server = SeedVRServer()
19
 
20
  # --- HELPER FUNCTIONS ---
 
35
  success, image = vid_cap.read()
36
  vid_cap.release()
37
  if not success: return None
 
38
  image_path = Path(video_path).with_suffix(".jpg")
39
  cv2.imwrite(str(image_path), image)
40
  return str(image_path)
 
43
  return None
44
 
45
  def on_file_upload(file_obj):
46
+ """Callback triggered when a user uploads a file."""
 
 
 
47
  if file_obj is None:
48
+ return 1
 
49
  if _is_video(file_obj.name):
 
50
  return gr.update(value=4, interactive=True)
51
  else:
 
52
  return gr.update(value=1, interactive=False)
53
 
54
  # --- CORE INFERENCE FUNCTION ---
 
61
  progress=gr.Progress(track_tqdm=True)
62
  ):
63
  """
64
+ The main callback function for Gradio, using generators (`yield`)
65
+ for real-time UI updates.
66
  """
67
  # 1. Initial State & Validation
 
68
  yield (
69
  gr.update(interactive=False, value="Processing... 🚀"),
70
  gr.update(value=None, visible=False),
 
75
 
76
  if not input_file_path:
77
  gr.Warning("Please upload a media file first.")
 
78
  yield (
79
  gr.update(interactive=True, value="Restore Media"),
80
  None, None, None, gr.update(visible=False)
81
  )
82
  return
83
 
 
84
  log_buffer = ["▶ Starting inference process...\n"]
85
  yield gr.update(), None, None, None, ''.join(log_buffer)
86
 
87
+ # CORREÇÃO APLICADA AQUI
88
  def progress_callback(step: float, desc: str):
89
  """A simple callback to append messages to our log buffer."""
 
 
90
  log_buffer.append(f"⏳ [{int(step*100)}%] {desc}\n")
91
+ # A chamada correta para a API de progresso do Gradio
92
+ progress(step, desc=desc)
93
 
94
  was_input_video = _is_video(input_file_path)
95
 
 
98
  progress_callback(0.1, "Calling backend engine...")
99
  yield gr.update(), None, None, None, ''.join(log_buffer)
100
 
 
101
  video_result_path = server.run_inference_direct(
102
  file_path=input_file_path,
103
+ seed=42,
104
  res_h=int(resolution),
105
+ res_w=int(resolution),
106
  sp_size=int(sp_size),
107
  fps=float(fps) if fps and fps > 0 else None,
108
+ progress=progress,
109
  )
110
 
111
  progress_callback(1.0, "Inference complete! Processing final output...")
 
116
  if was_input_video:
117
  final_video = video_result_path
118
  log_buffer.append(f"✅ Video result is ready.\n")
119
+ else:
120
  final_image = _extract_first_frame(video_result_path)
121
+ final_video = video_result_path
122
  log_buffer.append(f"✅ Image result extracted from video.\n")
123
 
 
124
  yield (
125
  gr.update(interactive=True, value="Restore Media"),
126
  gr.update(value=final_image, visible=final_image is not None),
 
136
  import traceback
137
  traceback.print_exc()
138
 
 
139
  yield (
140
  gr.update(interactive=True, value="Restore Media"),
141
  None, None, None,
 
144
 
145
 
146
  # --- GRADIO UI LAYOUT ---
 
147
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
148
  # Header
149
  gr.Markdown(
 
154
  </div>
155
  """
156
  )
 
157
  with gr.Row():
 
158
  with gr.Column(scale=1):
159
  gr.Markdown("### 1. Upload Media")
160
  input_media = gr.File(label="Input File (Video or Image)", type="filepath")
 
161
  gr.Markdown("### 2. Configure Settings")
162
  with gr.Accordion("Generation Parameters", open=True):
163
  resolution_select = gr.Dropdown(
 
166
  value="480",
167
  info="The output height and width will be set to this value."
168
  )
 
169
  sp_size_slider = gr.Slider(
170
  label="Sequence Parallelism (sp_size)",
171
  minimum=1, maximum=16, step=1, value=4,
172
  info="For multi-GPU videos. This will be set to 1 for images."
173
  )
 
174
  fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
 
175
  run_button = gr.Button("Restore Media", variant="primary", icon="✨")
 
 
176
  with gr.Column(scale=2):
177
  gr.Markdown("### 3. Results")
 
 
178
  log_window = gr.Textbox(
179
+ label="Inference Log 📝", lines=8, max_lines=15,
180
+ interactive=False, visible=False, autoscroll=True,
 
 
 
 
181
  )
 
 
182
  output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
183
  output_video = gr.Video(label="Video Result", visible=False)
184
  output_download = gr.File(label="Download Full Result (Video)", visible=False)
 
 
185
  gr.Markdown(
186
  """
187
  ---
 
190
  """
191
  )
192
 
193
+ input_media.upload(fn=on_file_upload, inputs=[input_media], outputs=[sp_size_slider])
 
 
 
 
 
 
 
194
 
 
195
  run_button.click(
196
  fn=run_inference_ui,
197
  inputs=[input_media, resolution_select, sp_size_slider, fps_out],