EuuIia committed
Commit d411dcc · verified · 1 Parent(s): 1bbb7db

Update inference_cli.py

Files changed (1)
  1. inference_cli.py +83 -3
inference_cli.py CHANGED
@@ -211,7 +211,37 @@ def save_frames_to_png(frames_tensor, output_dir, base_name, debug=False):
     print(f"✅ PNG saving completed: {total} files in '{output_dir}'")
 
 
-def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue):
+def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):  # progress_queue added
+    """Worker that runs the upscaling on a dedicated GPU."""
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
+    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
+
+    import torch
+    from src.core.model_manager import configure_runner
+    from src.core.generation import generation_loop
+
+    frames_tensor = torch.from_numpy(frames_np).to(torch.float16)
+
+    # Create a local callback that forwards progress into the queue
+    local_progress_callback = None
+    if progress_queue:
+        def callback_wrapper(batch_idx, total_batches, current_frames, message):
+            progress_queue.put((batch_idx, total_batches, message))
+        local_progress_callback = callback_wrapper
+
+    runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])
+
+    result_tensor = generation_loop(
+        runner=runner, images=frames_tensor, cfg_scale=shared_args["cfg_scale"],
+        seed=shared_args["seed"], res_w=shared_args["res_w"], batch_size=shared_args["batch_size"],
+        preserve_vram=shared_args["preserve_vram"], temporal_overlap=shared_args["temporal_overlap"],
+        debug=shared_args["debug"],
+        progress_callback=local_progress_callback  # Pass the callback down to generation_loop
+    )
+    return_queue.put((proc_idx, result_tensor.cpu().numpy()))
+
+
+def _worker_process1(proc_idx, device_id, frames_np, shared_args, return_queue):
     """Worker process that performs upscaling on a slice of frames using a dedicated GPU."""
     # 1. Limit CUDA visibility to the chosen GPU BEFORE importing torch-heavy deps
     os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
@@ -251,7 +281,52 @@ def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue):
     return_queue.put((proc_idx, result_tensor.cpu().numpy()))
 
 
-def _gpu_processing(frames_tensor, device_list, args):
+def _gpu_processing(frames_tensor, device_list, args, progress_callback=None):  # progress_callback added
+    """Split the frames and process them in parallel across multiple GPUs."""
+    num_devices = len(device_list)
+    chunks = torch.chunk(frames_tensor, num_devices, dim=0)
+
+    manager = mp.Manager()
+    return_queue = manager.Queue()
+    progress_queue = manager.Queue() if progress_callback else None  # Create the progress queue
+    workers = []
+
+    shared_args = {
+        "model": args.model, "model_dir": args.model_dir or "./models/SEEDVR2",
+        "preserve_vram": args.preserve_vram, "debug": args.debug, "cfg_scale": 1.0,
+        "seed": args.seed, "res_w": args.resolution, "batch_size": args.batch_size, "temporal_overlap": 0,
+    }
+
+    for idx, (device_id, chunk_tensor) in enumerate(zip(device_list, chunks)):
+        p = mp.Process(target=_worker_process, args=(idx, device_id, chunk_tensor.cpu().numpy(), shared_args, return_queue, progress_queue))
+        p.start()
+        workers.append(p)
+
+    results_np = [None] * num_devices
+    collected = 0
+    total_batches_per_worker = -1  # Used to compute the overall progress
+    while collected < num_devices:
+        # Poll both queues (results and progress) without blocking
+        if progress_queue and not progress_queue.empty():
+            batch_idx, total_batches, message = progress_queue.get()
+            if total_batches_per_worker == -1: total_batches_per_worker = total_batches
+            total_progress = (collected + (batch_idx / total_batches_per_worker)) / num_devices
+            progress_callback(total_progress, desc=f"GPU {collected+1}/{num_devices}: {message}")
+
+        if not return_queue.empty():
+            proc_idx, res_np = return_queue.get()
+            results_np[proc_idx] = res_np
+            collected += 1
+
+        time.sleep(0.1)  # Avoid busy-waiting
+
+    for p in workers: p.join()
+
+    return torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)
+
+
+
+def _gpu_processing1(frames_tensor, device_list, args):
     """Split frames and process them in parallel on multiple GPUs."""
     num_devices = len(device_list)
     # split frames tensor along time dimension
@@ -431,6 +506,9 @@ def main():
     finally:
         print(f"🧹 Process {os.getpid()} terminating - VRAM will be automatically freed")
 
+
+
+
 def run_inference_logic(args, progress_callback=None):
     """
     Main function that runs the upscaling pipeline.
@@ -484,6 +562,8 @@ def run_inference_logic(args, progress_callback=None):
     return result_tensor, original_fps, generation_time, len(frames_tensor)
 
 
+
+
 # ORIGINAL MAIN FUNCTION (now a wrapper)
 def main():
     """Main CLI function"""
@@ -512,4 +592,4 @@ def main():
 
 # Entry point for command-line execution
 if __name__ == "__main__":
-    main()
+    main()
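
Below is a minimal usage sketch (not part of the commit) of how the progress_callback parameter added to _gpu_processing can be driven from the parent process. It assumes the script is importable as inference_cli, that args carries the attributes the function reads (model, model_dir, preserve_vram, debug, seed, resolution, batch_size), and uses placeholder values for the model name, resolution, and frame shape; the callback simply mirrors the progress_callback(fraction, desc=...) calls made inside _gpu_processing.

# Minimal usage sketch (assumptions noted in comments; not part of the commit).
import argparse
import torch
from inference_cli import _gpu_processing  # assumes the script is importable as a module

def print_progress(fraction, desc=""):
    # Mirrors the progress_callback(fraction, desc=...) calls made inside _gpu_processing.
    print(f"[{fraction:6.1%}] {desc}")

if __name__ == "__main__":
    args = argparse.Namespace(
        model="<model-file>",          # placeholder, not a project default
        model_dir=None,                # falls back to ./models/SEEDVR2 inside _gpu_processing
        preserve_vram=True, debug=False,
        seed=42, resolution=1080, batch_size=5,   # illustrative values only
    )
    # Dummy clip split along dim 0 by torch.chunk; the real layout comes from the video loader.
    frames = torch.rand(16, 720, 1280, 3, dtype=torch.float16)
    result = _gpu_processing(frames, device_list=[0, 1], args=args, progress_callback=print_progress)
    print("upscaled tensor:", tuple(result.shape))

Because the workers report through a Manager queue and the parent re-emits progress as plain calls, any callable with this signature works here: a print-based hook as above, a logger, or a Gradio gr.Progress instance, which is called the same way.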