EuuIia commited on
Commit
8e5d88b
·
verified ·
1 Parent(s): 9500d02

Update api/seedvr_server.py

Browse files
Files changed (1) hide show
  1. api/seedvr_server.py +231 -125
api/seedvr_server.py CHANGED
@@ -1,142 +1,248 @@
1
- # api/seedvr_server.py
2
 
3
  import os
4
  import sys
5
- import shutil
6
- import mimetypes
7
- import time
8
- import subprocess # Necessário para clonar o repositório na configuração inicial
9
  from pathlib import Path
10
- from typing import Optional, Callable
11
- from types import SimpleNamespace
 
12
 
13
- from huggingface_hub import hf_hub_download
14
-
15
- # Adiciona dinamicamente o caminho do repositório clonado ao sys.path.
16
- # Isso é crucial para que a importação do 'inference_cli' funcione.
17
- SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
18
- if str(SEEDVR_REPO_PATH) not in sys.path:
19
- # Insere no início da lista para garantir prioridade de importação.
20
- sys.path.insert(0, str(SEEDVR_REPO_PATH))
21
-
22
- # Tenta importar as funções necessárias APÓS a modificação do path.
23
- # Se falhar, a aplicação não pode continuar.
24
  try:
25
- from inference_cli import run_inference_logic, save_frames_to_video
 
26
  except ImportError as e:
27
- print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'.")
28
- print(f"Verifique se o repositório em '{SEEDVR_REPO_PATH}' está correto e completo.")
29
- raise e
30
 
31
- class SeedVRServer:
32
- def __init__(self, **kwargs):
33
- """
34
- Inicializa o servidor, define os caminhos e prepara o ambiente.
35
- """
36
- self.SEEDVR_ROOT = SEEDVR_REPO_PATH
37
- self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
38
- self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
39
- self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
40
- self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
41
- self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
42
- self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- print("🚀 SeedVRServer (Modo de Chamada Direta) inicializando...")
45
- for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
46
- p.mkdir(parents=True, exist_ok=True)
47
 
48
- self.setup_dependencies()
49
- print("✅ SeedVRServer (Modo de Chamada Direta) pronto.")
50
-
51
- def setup_dependencies(self):
52
- """ Garante que o repositório e os modelos estão presentes. """
53
- self._ensure_repo()
54
- self._ensure_model()
55
-
56
- def _ensure_repo(self) -> None:
57
- """ Clona o repositório do SeedVR se ele não existir. """
58
- if not (self.SEEDVR_ROOT / ".git").exists():
59
- print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
60
- # Usamos subprocess.run aqui porque é uma tarefa de inicialização única.
61
- subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
62
- else:
63
- print("[SeedVRServer] Repositório SeedVR existe.")
64
-
65
- def _ensure_model(self) -> None:
66
- """ Baixa os checkpoints do Hugging Face se não existirem localmente. """
67
- print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
68
- model_files = {
69
- "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
70
- "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
71
- "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B",
72
- "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B"
73
- }
74
- for filename, repo_id in model_files.items():
75
- if not (self.CKPTS_ROOT / filename).exists():
76
- print(f"Baixando {filename} de {repo_id}...")
77
- hf_hub_download(
78
- repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
79
- cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
80
- )
81
- print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
82
 
83
- def run_inference_direct(
84
- self,
85
- file_path: str, *,
86
- seed: int, res_h: int, res_w: int, sp_size: int,
87
- fps: Optional[float] = None, progress: Optional[Callable] = None
88
- ) -> str:
 
 
 
 
 
 
89
  """
90
- Executa a inferência diretamente no mesmo processo e retorna o caminho do arquivo de saída.
 
 
 
91
  """
92
- # Cria um diretório de saída único para salvar o resultado.
93
- out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
94
- out_dir.mkdir(parents=True, exist_ok=True)
95
- output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
96
-
97
- # Simula o objeto 'args' que a função de lógica do inference_cli espera.
98
- # Usamos SimpleNamespace para criar um objeto simples com atributos.
99
- args = SimpleNamespace(
100
- video_path=file_path,
101
- output=str(output_filepath),
102
- model_dir=str(self.CKPTS_ROOT),
103
- seed=seed,
104
- resolution=res_h, # O script do SeedVR usa a altura (lado menor) como referência.
105
- batch_size=sp_size,
106
- model="seedvr2_ema_3b_fp16.safetensors",
107
- preserve_vram=True,
108
- debug=True, # Mantém o debug ativo para logs detalhados.
109
- cuda_device=",".join(map(str, range(self.NUM_GPUS_TOTAL))),
110
- skip_first_frames=0,
111
- load_cap=0,
112
- output_format='video' # Garante que sempre gere vídeo
113
- )
114
 
115
- try:
116
- # Informa a UI que o processo começou.
117
- if progress:
118
- progress(0.01, "Initializing...")
119
-
120
- # Chama a função importada do script original, passando o callback de progresso.
121
- # Este callback será chamado de dentro da lógica de multi-processamento.
122
- result_tensor, original_fps, _, _ = run_inference_logic(args, progress_callback=progress)
123
 
124
- # Informa a UI que a inferência terminou e o salvamento vai começar.
125
- if progress:
126
- progress(0.95, "Saving the final video...")
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- # Define o FPS final: usa o valor da UI ou o original do vídeo de entrada.
129
- final_fps = fps if fps and fps > 0 else original_fps
130
- save_frames_to_video(result_tensor, str(output_filepath), final_fps, args.debug)
131
-
132
- print(f"✅ Video saved successfully to: {output_filepath}")
133
-
134
- # Retorna o caminho do arquivo gerado para a UI.
135
- return str(output_filepath)
136
 
137
- except Exception as e:
138
- print(f"❌ Error during direct inference execution: {e}")
139
- import traceback
140
- traceback.print_exc()
141
- # Propaga o erro para a UI do Gradio, que o exibirá de forma amigável.
142
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app_seedvr.py
2
 
3
  import os
4
  import sys
 
 
 
 
5
  from pathlib import Path
6
+ from typing import Optional
7
+ import gradio as gr
8
+ import cv2
9
 
10
+ # --- INTEGRAÇÃO COM A LÓGICA DO SERVIDOR ---
 
 
 
 
 
 
 
 
 
 
11
  try:
12
+ # Importa a classe SeedVRServer que agora atua como nossa biblioteca de inferência.
13
+ from api.seedvr_server import SeedVRServer
14
  except ImportError as e:
15
+ print(f"ERRO FATAL: Não foi possível importar o SeedVRServer. Detalhes: {e}")
16
+ # A aplicação não pode rodar sem a lógica do servidor.
17
+ raise
18
 
19
+ # --- INICIALIZAÇÃO ---
20
+ # Cria uma instância única e persistente do servidor.
21
+ # A inicialização (clonar repo, baixar modelos) acontece apenas uma vez, no início.
22
+ server = SeedVRServer()
23
+
24
+ # --- FUNÇÕES AUXILIARES ---
25
+
26
+ def _is_video(path: str) -> bool:
27
+ """Verifica se um caminho de arquivo corresponde a um tipo de vídeo."""
28
+ if not path: return False
29
+ import mimetypes
30
+ mime, _ = mimetypes.guess_type(path)
31
+ return (mime or "").startswith("video")
32
+
33
+ def _extract_first_frame(video_path: str) -> Optional[str]:
34
+ """Extrai o primeiro frame de um vídeo e o salva como uma imagem JPG."""
35
+ if not video_path or not os.path.exists(video_path): return None
36
+ try:
37
+ vid_cap = cv2.VideoCapture(video_path)
38
+ if not vid_cap.isOpened():
39
+ print(f"Erro: Não foi possível abrir o vídeo em {video_path}")
40
+ return None
41
+ success, image = vid_cap.read()
42
+ vid_cap.release()
43
+ if not success:
44
+ print(f"Erro: Não foi possível ler o primeiro frame de {video_path}")
45
+ return None
46
+
47
+ # Salva o frame no mesmo diretório do vídeo, com extensão .jpg
48
+ image_path = Path(video_path).with_suffix(".jpg")
49
+ cv2.imwrite(str(image_path), image)
50
+ return str(image_path)
51
+ except Exception as e:
52
+ print(f"Erro ao extrair o primeiro frame: {e}")
53
+ return None
54
+
55
+ def on_file_upload(file_obj):
56
+ """
57
+ Callback acionado quando o usuário faz o upload de um arquivo.
58
+ Verifica se o arquivo é um vídeo e sugere um `sp_size` apropriado.
59
+ """
60
+ if file_obj is None:
61
+ # Limpa os resultados e o log se o arquivo for removido
62
+ return gr.update(value=1), None, None, None, gr.update(value=None, visible=False)
63
+
64
+ if _is_video(file_obj.name):
65
+ # Para vídeos, sugere um valor padrão para multi-GPU e torna o slider interativo
66
+ return gr.update(value=4, interactive=True), None, None, None, gr.update(value=None, visible=False)
67
+ else:
68
+ # Para imagens, trava o valor em 1
69
+ return gr.update(value=1, interactive=False), None, None, None, gr.update(value=None, visible=False)
70
+
71
+ # --- FUNÇÃO PRINCIPAL DE INFERÊNCIA DA UI ---
72
+
73
+ def run_inference_ui(
74
+ input_file_path: Optional[str],
75
+ resolution: str,
76
+ sp_size: int,
77
+ fps: float,
78
+ progress=gr.Progress(track_tqdm=True)
79
+ ):
80
+ """
81
+ A função de callback principal do Gradio. Usa geradores (`yield`)
82
+ para permitir atualizações da UI em tempo real durante a tarefa de longa duração.
83
+ """
84
+ # 1. Estado Inicial e Validação
85
+ # No início, desabilita o botão, limpa resultados anteriores e mostra a janela de log.
86
+ yield (
87
+ gr.update(interactive=False, value="Processing... 🚀"),
88
+ gr.update(value=None, visible=False),
89
+ gr.update(value=None, visible=False),
90
+ gr.update(value=None, visible=False),
91
+ gr.update(value="▶ Starting inference process...\n", visible=True)
92
+ )
93
+
94
+ if not input_file_path:
95
+ gr.Warning("Please upload a media file first.")
96
+ # Reabilita o botão e esconde os componentes de saída
97
+ yield (gr.update(interactive=True, value="Restore Media"), None, None, None, gr.update(visible=False))
98
+ return
99
+
100
+ log_buffer = ["▶ Starting inference process...\n"]
101
+ last_log_message = ""
102
+ was_input_video = _is_video(input_file_path)
103
+
104
+ try:
105
+ # Define um callback que será chamado pelo backend para atualizar o progresso e o log
106
+ def progress_callback_wrapper(step: float, desc: str):
107
+ """ Wrapper para formatar logs e atualizar o progresso. """
108
+ nonlocal last_log_message
109
+ # Só adiciona ao log se a mensagem for nova, para evitar poluição visual
110
+ if desc != last_log_message:
111
+ log_buffer.append(f"⏳ {desc}\n")
112
+ last_log_message = desc
113
+ # Atualiza o objeto de progresso do Gradio
114
+ progress(step, desc=desc)
115
+
116
+ # 2. Executa a Inferência
117
+ # Chama o método direto do servidor, passando o nosso callback.
118
+ video_result_path = server.run_inference_direct(
119
+ file_path=input_file_path,
120
+ seed=42, # Semente fixa conforme solicitado
121
+ res_h=int(resolution),
122
+ res_w=int(resolution), # Largura igual à altura
123
+ sp_size=int(sp_size),
124
+ fps=float(fps) if fps and fps > 0 else None,
125
+ progress=progress_callback_wrapper, # Passa nossa função de callback
126
+ )
127
 
128
+ progress(1.0, desc="Complete!")
129
+ log_buffer.append("✅ Inference complete! Processing final output...\n")
 
130
 
131
+ # 3. Processa e Exibe os Resultados
132
+ final_image, final_video = None, None
133
+ if was_input_video:
134
+ final_video = video_result_path
135
+ log_buffer.append(" Video result is ready.\n")
136
+ else: # Se a entrada foi uma imagem
137
+ final_image = _extract_first_frame(video_result_path)
138
+ final_video = video_result_path # Também disponibiliza o vídeo de 1 frame
139
+ log_buffer.append("✅ Image result extracted from video.\n")
140
+
141
+ # Yield final para mostrar os resultados e reabilitar o botão
142
+ yield (
143
+ gr.update(interactive=True, value="Restore Media"),
144
+ gr.update(value=final_image, visible=final_image is not None),
145
+ gr.update(value=final_video, visible=final_video is not None),
146
+ gr.update(value=video_result_path, visible=video_result_path is not None),
147
+ ''.join(log_buffer)
148
+ )
149
+
150
+ except Exception as e:
151
+ error_message = f"❌ Inference failed: {e}"
152
+ gr.Error(error_message)
153
+ log_buffer.append(f"\n{error_message}")
154
+ import traceback
155
+ traceback.print_exc()
 
 
 
 
 
 
 
 
 
156
 
157
+ # Yield para estado de erro: reabilita o botão e mostra o log com o erro
158
+ yield (
159
+ gr.update(interactive=True, value="Restore Media"),
160
+ None, None, None,
161
+ gr.update(value=''.join(log_buffer), visible=True)
162
+ )
163
+
164
+ # --- LAYOUT DA INTERFACE GRÁFICA (GRADIO) ---
165
+
166
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="SeedVR Media Restoration") as demo:
167
+ # Cabeçalho
168
+ gr.Markdown(
169
  """
170
+ <div style='text-align: center; margin-bottom: 20px;'>
171
+ <h1>📸 SeedVR - Image & Video Restoration 🚀</h1>
172
+ <p>High-quality media upscaling powered by SeedVR-3B. Upload your file and see the magic.</p>
173
+ </div>
174
  """
175
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
+ with gr.Row():
178
+ # --- Coluna da Esquerda: Entradas e Controles ---
179
+ with gr.Column(scale=1):
180
+ gr.Markdown("### 1. Upload Media")
181
+ input_media = gr.File(label="Input File (Video or Image)", type="filepath", interactive=True)
 
 
 
182
 
183
+ gr.Markdown("### 2. Configure Settings")
184
+ with gr.Accordion("Generation Parameters", open=True):
185
+ resolution_select = gr.Dropdown(
186
+ label="Resolution",
187
+ choices=["480", "560", "720", "960", "1024"],
188
+ value="480",
189
+ info="Sets the output height and width to this value."
190
+ )
191
+
192
+ sp_size_slider = gr.Slider(
193
+ label="Frames per Batch (sp_size)",
194
+ minimum=1, maximum=16, step=1, value=4,
195
+ info="For multi-GPU videos. Automatically set to 1 for images."
196
+ )
197
 
198
+ fps_out = gr.Number(label="Output FPS (for Videos)", value=24, precision=0, info="Set to 0 to use the original FPS.")
199
+
200
+ run_button = gr.Button("Restore Media", variant="primary", icon="✨")
201
+
202
+ # --- Coluna da Direita: Resultados ---
203
+ with gr.Column(scale=2):
204
+ gr.Markdown("### 3. Results")
 
205
 
206
+ # Janela de Log
207
+ log_window = gr.Textbox(
208
+ label="Inference Log 📝",
209
+ lines=8, max_lines=15,
210
+ interactive=False, visible=False, autoscroll=True
211
+ )
212
+
213
+ # Componentes de saída (começam invisíveis)
214
+ output_image = gr.Image(label="Image Result", show_download_button=True, type="filepath", visible=False)
215
+ output_video = gr.Video(label="Video Result", visible=False)
216
+ output_download = gr.File(label="Download Full Result (Video)", visible=False)
217
+
218
+ # --- Rodapé ---
219
+ gr.Markdown(
220
+ """
221
+ ---
222
+ *Space and Docker were developed by Carlex.*
223
+ *Contact: Email: Carlex22@gmail.com | GitHub: [carlex22](https://github.com/carlex22)*
224
+ """
225
+ )
226
+
227
+ # --- Lógica de Eventos da UI ---
228
+
229
+ # Ao fazer upload de um arquivo, ajusta o slider `sp_size` e limpa saídas antigas.
230
+ input_media.upload(
231
+ fn=on_file_upload,
232
+ inputs=[input_media],
233
+ outputs=[sp_size_slider, output_image, output_video, output_download, log_window]
234
+ )
235
+
236
+ # Ao clicar no botão, executa a função de inferência principal.
237
+ run_button.click(
238
+ fn=run_inference_ui,
239
+ inputs=[input_media, resolution_select, sp_size_slider, fps_out],
240
+ outputs=[run_button, output_image, output_video, output_download, log_window],
241
+ )
242
+
243
+ if __name__ == "__main__":
244
+ demo.launch(
245
+ server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
246
+ server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
247
+ show_error=True
248
+ )