import numpy as np
import onnxruntime
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import whisper
from typing import Callable, Dict, Literal

from stepvocoder.cosyvoice2.matcha.audio import mel_spectrogram
class CosyVoiceFrontEnd(object):

    def __init__(self,
                 mel_conf: Dict,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 onnx_provider: str = 'CUDAExecutionProvider',
                 ):
        super().__init__()
        # onnx_provider is validated here; the two sessions below still pick their
        # providers explicitly (CAM++ on CPU, speech tokenizer on GPU if available).
        assert onnx_provider in ['CUDAExecutionProvider', 'CPUExecutionProvider'], 'invalid onnx provider'
        self.mel_conf = mel_conf
        self.sample_rate = mel_conf['sampling_rate']
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        # CAM++ speaker-embedding model (ONNX), always run on CPU.
        self.campplus_session = onnxruntime.InferenceSession(
            campplus_model, sess_options=option,
            providers=["CPUExecutionProvider"]
        )
        # Speech tokenizer (ONNX), run on GPU when available, otherwise on CPU.
        self.speech_tokenizer_session = onnxruntime.InferenceSession(
            speech_tokenizer_model, sess_options=option,
            providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"],
        )
    def extract_speech_feat(self, audio: torch.Tensor, audio_sr: int):
        # Mel-spectrogram prompt features at the mel configuration's sampling rate.
        if audio_sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=self.sample_rate)
            audio_sr = self.sample_rate
        speech_feat = mel_spectrogram(y=audio, **self.mel_conf).transpose(1, 2)  # (b=1, t, num_mels)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.long)
        return speech_feat, speech_feat_len
    def extract_spk_embedding(self, audio: torch.Tensor, audio_sr: int):
        # CAM++ expects 16 kHz audio.
        if audio_sr != 16000:
            audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=16000)
            audio_sr = 16000
        feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
        # Per-utterance cepstral mean normalization.
        feat = feat - feat.mean(dim=0, keepdim=True)
        onnx_in = {
            self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()
        }
        embedding = self.campplus_session.run(None, onnx_in)[0].flatten().tolist()
        embedding = torch.tensor([embedding])  # (1, embedding_dim)
        return embedding
    def extract_speech_token(self, audio: torch.Tensor, audio_sr: int):
        # The speech tokenizer expects 16 kHz audio and at most 30 s per call.
        if audio_sr != 16000:
            audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=16000)
            audio_sr = 16000
        assert (
            audio.shape[1] / 16000 <= 30
        ), "speech token extraction is not supported for audio longer than 30s"
        # 128-bin Whisper log-mel features, shape (1, n_mels, num_frames).
        feat = whisper.log_mel_spectrogram(audio, n_mels=128)
        onnx_in = {
            self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
            self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32),
        }
        speech_token = self.speech_tokenizer_session.run(None, onnx_in)[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32)
        return speech_token, speech_token_len