import numpy as np
import onnxruntime
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import whisper
from typing import Dict

from stepvocoder.cosyvoice2.matcha.audio import mel_spectrogram


class CosyVoiceFrontEnd(object):

    def __init__(self,
                 mel_conf: Dict,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 onnx_provider: str = 'CUDAExecutionProvider'):
        super().__init__()
        assert onnx_provider in ['CUDAExecutionProvider', 'CPUExecutionProvider'], 'invalid onnx provider'
        self.mel_conf = mel_conf
        self.sample_rate = mel_conf['sampling_rate']

        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        # The CAM++ speaker-embedding model is lightweight, so it always runs on CPU.
        self.campplus_session = onnxruntime.InferenceSession(
            campplus_model, sess_options=option, providers=['CPUExecutionProvider']
        )
        # The speech tokenizer honors the requested provider, falling back to
        # CPU when CUDA is unavailable.
        self.speech_tokenizer_session = onnxruntime.InferenceSession(
            speech_tokenizer_model,
            sess_options=option,
            providers=[onnx_provider if torch.cuda.is_available() else 'CPUExecutionProvider'],
        )

    def extract_speech_feat(self, audio: torch.Tensor, audio_sr: int):
        """Extract a mel spectrogram at the vocoder sample rate."""
        if audio_sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=self.sample_rate)
        speech_feat = mel_spectrogram(y=audio, **self.mel_conf).transpose(1, 2)  # (b=1, t, num_mels)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.long)
        return speech_feat, speech_feat_len

    def extract_spk_embedding(self, audio: torch.Tensor, audio_sr: int):
        """Extract a speaker embedding with the CAM++ ONNX model (expects 16 kHz input)."""
        if audio_sr != 16000:
            audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=16000)
        feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)  # cepstral mean normalization
        onnx_in = {
            self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()
        }
        embedding = self.campplus_session.run(None, onnx_in)[0].flatten().tolist()
        embedding = torch.tensor([embedding])
        return embedding

    def extract_speech_token(self, audio: torch.Tensor, audio_sr: int):
        """Extract discrete speech tokens with the ONNX tokenizer (expects 16 kHz input, <= 30 s)."""
        if audio_sr != 16000:
            audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=16000)
        assert (
            audio.shape[1] / 16000 <= 30
        ), 'do not support extracting speech tokens for audio longer than 30s'
        feat = whisper.log_mel_spectrogram(audio, n_mels=128)
        onnx_in = {
            self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
            self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32),
        }
        speech_token = self.speech_tokenizer_session.run(None, onnx_in)[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32)
        return speech_token, speech_token_len
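

if __name__ == '__main__':
    # Minimal usage sketch. The mel_conf keys mirror matcha's mel_spectrogram
    # signature; the concrete values below and all file paths are assumptions
    # (placeholders), not settings taken from this repo's configs.
    mel_conf = dict(n_fft=1920, num_mels=80, sampling_rate=24000,
                    hop_size=480, win_size=1920, fmin=0, fmax=8000, center=False)
    frontend = CosyVoiceFrontEnd(
        mel_conf=mel_conf,
        campplus_model='campplus.onnx',                  # placeholder path
        speech_tokenizer_model='speech_tokenizer.onnx',  # placeholder path
        onnx_provider='CPUExecutionProvider',
    )
    audio, sr = torchaudio.load('prompt.wav')            # placeholder file
    audio = audio.mean(dim=0, keepdim=True)              # downmix to mono, shape (1, t)
    speech_feat, speech_feat_len = frontend.extract_speech_feat(audio, sr)
    embedding = frontend.extract_spk_embedding(audio, sr)
    speech_token, speech_token_len = frontend.extract_speech_token(audio, sr)
    print(speech_feat.shape, embedding.shape, speech_token.shape)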