Spaces:
Running
Running
| import argparse | |
| import os | |
| import warnings | |
| from pathlib import Path | |
| from time import perf_counter | |
| import numpy as np | |
| import onnxruntime as ort | |
| import soundfile as sf | |
| import torch | |
| from matcha.cli import plot_spectrogram_to_numpy, process_text | |
def validate_args(args):
    """Check the parsed command-line arguments and return them unchanged.

    Args:
        args: Namespace exposing ``text``, ``file``, ``temperature`` and
            ``speaking_rate`` attributes.

    Returns:
        The same ``args`` object that was passed in.

    Raises:
        AssertionError: If neither ``text`` nor ``file`` is set, if
            ``temperature`` is negative, or if ``speaking_rate`` is not
            strictly positive.
    """
    # Raised explicitly rather than through `assert` statements so the checks
    # still run under `python -O`; AssertionError is kept so the exception
    # type seen by callers does not change.
    if not (args.text or args.file):
        raise AssertionError(
            "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
        )
    if not args.temperature >= 0:
        raise AssertionError("Sampling temperature cannot be negative")
    # Strictly positive, as the message states: a rate of 0 used to pass.
    if not args.speaking_rate > 0:
        raise AssertionError("Speaking rate must be greater than 0")
    return args
def write_wavs(model, inputs, output_dir, external_vocoder=None):
    """Run inference and write one WAV file per batch item into ``output_dir``.

    When ``external_vocoder`` is None the model's two outputs are taken as
    waveforms and their lengths. Otherwise the model's outputs are taken as
    mels and mel lengths, and the mels are passed through ``external_vocoder``
    to obtain the waveforms.

    Args:
        model: Object exposing ``run(None, inputs)`` that returns two outputs.
        inputs: Input mapping forwarded unchanged to ``model.run``.
        output_dir: Directory for the ``output_<n>.wav`` files; created if missing.
        external_vocoder: Optional object exposing ``get_inputs()`` and
            ``run(None, inputs)`` used to turn mels into waveforms.

    Inference time, generated audio duration and real-time factors are
    printed to stdout.
    """
    sample_rate = 22050  # passed to soundfile and used to convert samples to seconds
    samples_per_frame = 256  # presumably the vocoder hop length -- TODO confirm

    mel_infer_secs = None
    vocoder_infer_secs = None

    if external_vocoder is not None:
        print("[🍵] Generating mel using Matcha")
        started = perf_counter()
        mels, mel_lengths = model.run(None, inputs)
        mel_infer_secs = perf_counter() - started

        print("Generating waveform from mel using external vocoder")
        first_input_name = external_vocoder.get_inputs()[0].name
        vocoder_inputs = {first_input_name: mels}
        started = perf_counter()
        vocoder_outputs = external_vocoder.run(None, vocoder_inputs)
        wavs = vocoder_outputs[0]
        vocoder_infer_secs = perf_counter() - started

        # Drop axis 1 of the vocoder output (outside the timed region).
        wavs = wavs.squeeze(1)
        wav_lengths = mel_lengths * samples_per_frame
        infer_secs = mel_infer_secs + vocoder_infer_secs
    else:
        print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
        started = perf_counter()
        wavs, wav_lengths = model.run(None, inputs)
        infer_secs = perf_counter() - started

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    for index, (wav, wav_length) in enumerate(zip(wavs, wav_lengths), start=1):
        output_filename = target_dir / f"output_{index}.wav"
        # Keep only the first `wav_length` samples of the batch-padded waveform.
        trimmed = wav[:wav_length]
        print(f"Writing audio to {output_filename}")
        sf.write(output_filename, trimmed, sample_rate, "PCM_24")

    wav_secs = wav_lengths.sum() / sample_rate
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    rtf = infer_secs / wav_secs
    if mel_infer_secs is not None:
        print(f"Matcha RTF: {mel_infer_secs / wav_secs}")
    if vocoder_infer_secs is not None:
        print(f"Vocoder RTF: {vocoder_infer_secs / wav_secs}")
    print(f"Overall RTF: {rtf}")
def write_mels(model, inputs, output_dir):
    """Run the model and save each output mel as a PNG plot and a ``.npy`` array.

    Args:
        model: Object exposing ``run(None, inputs)`` that returns mels and
            mel lengths.
        inputs: Input mapping forwarded unchanged to ``model.run``.
        output_dir: Directory for ``output_<n>.png`` and ``output_<n>.npy``;
            created if missing.

    Inference time, the equivalent audio duration and the real-time factor
    are printed to stdout.
    """
    t0 = perf_counter()
    mels, mel_lengths = model.run(None, inputs)
    infer_secs = perf_counter() - t0

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for i, mel in enumerate(mels, start=1):
        output_stem = output_dir.joinpath(f"output_{i}")
        plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png"))
        # np.save appends ".npy" to any name that lacks it, so the previous
        # ".numpy" suffix produced files named "output_<n>.numpy.npy".
        np.save(output_stem.with_suffix(".npy"), mel)

    # Same frames-to-samples factor (256) and sample rate (22050) that
    # write_wavs uses, so the reported duration matches the audio these
    # mels would produce.
    wav_secs = (mel_lengths * 256).sum() / 22050
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    rtf = infer_secs / wav_secs
    print(f"RTF: {rtf}")
def main():
    """Command-line entry point: synthesize text with an ONNX Matcha-TTS model.

    Parses and validates the arguments, loads the ONNX model, builds a padded
    batch from the input lines, and then writes either waveforms (vocoder
    embedded in the graph or supplied through ``--vocoder``) or mels (no
    vocoder available) into the output directory.
    """
    parser = argparse.ArgumentParser(
        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
    )
    parser.add_argument(
        "model",
        type=str,
        help="ONNX model to use",
    )
    parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.667,
        help="Variance of the x0 noise (default: 0.667)",
    )
    parser.add_argument(
        "--speaking-rate",
        type=float,
        default=1.0,
        help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
    )
    # The flag switches GPU inference ON; the previous help text said the opposite.
    parser.add_argument("--gpu", action="store_true", help="Use GPU for inference (default: use CPU)")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=os.getcwd(),
        help="Output folder to save results (default: current dir)",
    )

    args = parser.parse_args()
    args = validate_args(args)

    if args.gpu:
        # ONNX Runtime has no "GPUExecutionProvider"; the CUDA provider is the
        # GPU backend. CPU is listed after it as a fallback.
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    model = ort.InferenceSession(args.model, providers=providers)

    model_inputs = model.get_inputs()
    model_outputs = list(model.get_outputs())

    # One batch item per input line, from --text or from the --file contents.
    if args.text:
        text_lines = args.text.splitlines()
    else:
        with open(args.file, encoding="utf-8") as file:
            text_lines = file.read().splitlines()

    processed_lines = [process_text(0, line, "cpu") for line in text_lines]
    x = [line["x"].squeeze() for line in processed_lines]
    # Pad every sequence to the longest one so they stack into a single batch.
    x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
    x = x.detach().cpu().numpy()
    x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
    inputs = {
        "x": x,
        "x_lengths": x_lengths,
        "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
    }

    # A graph with four inputs is treated as multi-speaker and also needs "spks".
    is_multi_speaker = len(model_inputs) == 4
    if is_multi_speaker:
        if args.spk is None:
            args.spk = 0
            warn = "[!] Speaker ID not provided! Using speaker ID 0"
            warnings.warn(warn, UserWarning)
        # Same speaker ID for every item in the batch.
        inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)

    # A first output named "wav" means the vocoder is embedded in the graph.
    has_vocoder_embedded = model_outputs[0].name == "wav"
    if has_vocoder_embedded:
        write_wavs(model, inputs, args.output_dir)
    elif args.vocoder:
        external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
        write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
    else:
        warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
        warnings.warn(warn, UserWarning)
        write_mels(model, inputs, args.output_dir)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()