congcuong-cse commited on
Commit
c4568cc
·
1 Parent(s): 7a6836d

Add logging for infer_batch_process

Browse files
Files changed (1) hide show
  1. src/f5_tts/infer/utils_infer.py +14 -0
src/f5_tts/infer/utils_infer.py CHANGED
@@ -456,20 +456,34 @@ def infer_batch_process(
456
  fix_duration=None,
457
  device=None,
458
  ):
 
459
  audio, sr = ref_audio
 
 
460
  if audio.shape[0] > 1:
 
461
  audio = torch.mean(audio, dim=0, keepdim=True)
 
462
 
463
  rms = torch.sqrt(torch.mean(torch.square(audio)))
 
 
464
  if rms < target_rms:
 
465
  audio = audio * target_rms / rms
 
466
  if sr != target_sample_rate:
 
467
  resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
468
  audio = resampler(audio)
 
 
469
  audio = audio.to(device)
 
470
 
471
  generated_waves = []
472
  spectrograms = []
 
473
 
474
  if len(ref_text[-1].encode("utf-8")) == 1:
475
  ref_text = ref_text + " "
 
456
  fix_duration=None,
457
  device=None,
458
  ):
459
+ print("Starting audio preprocessing...")
460
  audio, sr = ref_audio
461
+ print(f"Original audio shape: {audio.shape}, sample rate: {sr}")
462
+
463
  if audio.shape[0] > 1:
464
+ print("Converting multi-channel audio to mono...")
465
  audio = torch.mean(audio, dim=0, keepdim=True)
466
+ print(f"Converted audio shape: {audio.shape}")
467
 
468
  rms = torch.sqrt(torch.mean(torch.square(audio)))
469
+ print(f"Calculated RMS: {rms}")
470
+
471
  if rms < target_rms:
472
+ print(f"Normalizing audio RMS to target RMS: {target_rms}")
473
  audio = audio * target_rms / rms
474
+
475
  if sr != target_sample_rate:
476
+ print(f"Resampling audio from {sr} Hz to {target_sample_rate} Hz...")
477
  resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
478
  audio = resampler(audio)
479
+ print("Resampling complete.")
480
+
481
  audio = audio.to(device)
482
+ print(f"Audio moved to device: {device}")
483
 
484
  generated_waves = []
485
  spectrograms = []
486
+ print("Initialized containers for generated waves and spectrograms.")
487
 
488
  if len(ref_text[-1].encode("utf-8")) == 1:
489
  ref_text = ref_text + " "