Sync from GitHub repo
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there rather than to this Space.
src/f5_tts/eval/eval_infer_batch.py
CHANGED
@@ -189,13 +189,13 @@ def main():
         gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
         gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
         if mel_spec_type == "vocos":
-            generated_wave = vocoder.decode(gen_mel_spec)
+            generated_wave = vocoder.decode(gen_mel_spec).cpu()
         elif mel_spec_type == "bigvgan":
-            generated_wave = vocoder(gen_mel_spec)
+            generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

         if ref_rms_list[i] < target_rms:
             generated_wave = generated_wave * ref_rms_list[i] / target_rms
-        torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave
+        torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)

     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
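The post-processing that used to happen at save time now happens where each vocoder runs, because the two vocoders hand back differently shaped tensors. A minimal sketch of the distinction (mel_to_wave is an illustrative name, not a function in the repo; the shape comments reflect Vocos returning (batch, time) from decode() and BigVGAN returning (batch, 1, time) from its forward pass):

import torch

def mel_to_wave(gen_mel_spec: torch.Tensor, vocoder, mel_spec_type: str) -> torch.Tensor:
    if mel_spec_type == "vocos":
        # decode() yields (batch, time) == (1, time) here, already the
        # (channels, frames) layout that torchaudio.save expects.
        return vocoder.decode(gen_mel_spec).cpu()
    elif mel_spec_type == "bigvgan":
        # The forward pass yields (batch, 1, time); dropping the batch dim
        # leaves (1, time). Both branches move to CPU because
        # torchaudio.save cannot write a CUDA tensor.
        return vocoder(gen_mel_spec).squeeze(0).cpu()
    raise ValueError(f"unsupported mel_spec_type: {mel_spec_type}")

The other half of the fix is that torchaudio.save takes the sample rate as a required positional argument, which is why target_sample_rate now appears in every save call in these diffs.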
src/f5_tts/infer/speech_edit.py
CHANGED
@@ -181,13 +181,13 @@ with torch.inference_mode():
    generated = generated[:, ref_audio_len:, :]
    gen_mel_spec = generated.permute(0, 2, 1)
    if mel_spec_type == "vocos":
-        generated_wave = vocoder.decode(gen_mel_spec)
+        generated_wave = vocoder.decode(gen_mel_spec).cpu()
    elif mel_spec_type == "bigvgan":
-        generated_wave = vocoder(gen_mel_spec)
+        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

    if rms < target_rms:
        generated_wave = generated_wave * rms / target_rms

    save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
-    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave
+    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
    print(f"Generated wav: {generated_wave.shape}")
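speech_edit.py receives the same vocoder and save fix. The RMS scaling it feeds into mirrors the preprocessing step: a reference clip quieter than target_rms is boosted before inference, so the generated wave is scaled back by the same ratio afterwards. A sketch of that round trip, assuming the repo's convention of a target_rms around 0.1 (normalize_to_target and restore_loudness are hypothetical names):

import torch

target_rms = 0.1  # assumed target loudness, matching the inference scripts

def normalize_to_target(audio: torch.Tensor):
    # Boost quiet reference audio up to target_rms before inference,
    # remembering the original RMS so it can be restored later.
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    return audio, rms

def restore_loudness(generated_wave: torch.Tensor, rms: torch.Tensor) -> torch.Tensor:
    # Undo the boost so the output matches the reference's original loudness.
    if rms < target_rms:
        generated_wave = generated_wave * rms / target_rms
    return generated_wave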
src/f5_tts/model/trainer.py
CHANGED
@@ -324,26 +324,31 @@ class Trainer:
                    self.save_checkpoint(global_step)

                if self.log_samples and self.accelerator.is_local_main_process:
-                    (four removed lines, not preserved in the synced view)
+                    ref_audio_len = mel_lengths[0]
+                    infer_text = [
+                        text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
+                    ]
                    with torch.inference_mode():
                        generated, _ = self.accelerator.unwrap_model(self.model).sample(
                            cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
-                            text=
+                            text=infer_text,
                            duration=ref_audio_len * 2,
                            steps=nfe_step,
                            cfg_strength=cfg_strength,
                            sway_sampling_coef=sway_sampling_coef,
                        )
-                    (seven removed lines, not preserved in the synced view)
+                    generated = generated.to(torch.float32)
+                    gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                    ref_mel_spec = batch["mel"][0].unsqueeze(0)
+                    if self.vocoder_name == "vocos":
+                        gen_audio = vocoder.decode(gen_mel_spec).cpu()
+                        ref_audio = vocoder.decode(ref_mel_spec).cpu()
+                    elif self.vocoder_name == "bigvgan":
+                        gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
+                        ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
+
+                    torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+                    torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)

                if global_step % self.last_per_steps == 0:
                    self.save_checkpoint(global_step, last=True)
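The trainer change rebuilds sample logging from the pieces visible on the new side of the hunk: the first transcript in the batch is repeated after itself so the model continues the reference audio, the duration budget is doubled to make room for the continuation, and both the generated and reference mels are vocoded and saved. A sketch of just the prompt construction (build_sample_prompt is an illustrative name, not a repo function):

def build_sample_prompt(text_input, ref_audio_len):
    # text_input may be a raw string or a pre-tokenized list of symbols,
    # so the separator must match: a [" "] token for lists, a plain space
    # for strings. Repeating the transcript with duration doubled asks the
    # model to speak the same text again after the reference segment.
    sep = [" "] if isinstance(text_input, list) else " "
    infer_text = [text_input + sep + text_input]
    duration = ref_audio_len * 2
    return infer_text, duration

The generated mel is also cast to float32 before vocoding, since training may run in mixed precision while both vocoders expect float32 input.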