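"""Frontend feature extraction for CosyVoice: mel spectrograms for the
vocoder, CAM++ (campplus) speaker embeddings, and discrete speech tokens
from an ONNX speech tokenizer."""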
import numpy as np
import onnxruntime
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import whisper
from typing import Dict

from stepvocoder.cosyvoice2.matcha.audio import mel_spectrogram

class CosyVoiceFrontEnd(object):

    def __init__(self,
                 mel_conf: Dict,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 onnx_provider: str = 'CUDAExecutionProvider'):
        super().__init__()
        assert onnx_provider in ['CUDAExecutionProvider', 'CPUExecutionProvider'], 'invalid onnx provider'
        self.mel_conf = mel_conf
        self.sample_rate = mel_conf['sampling_rate']
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        # The CAM++ speaker-embedding model is small, so it always runs on CPU.
        self.campplus_session = onnxruntime.InferenceSession(
            campplus_model, sess_options=option,
            providers=['CPUExecutionProvider'])
        # Use the requested provider for the speech tokenizer, falling back to
        # CPU when CUDA is not available.
        if onnx_provider == 'CUDAExecutionProvider' and not torch.cuda.is_available():
            onnx_provider = 'CPUExecutionProvider'
        self.speech_tokenizer_session = onnxruntime.InferenceSession(
            speech_tokenizer_model, sess_options=option,
            providers=[onnx_provider])

    def extract_speech_feat(self, audio: torch.Tensor, audio_sr: int):
        """Extract a mel spectrogram at the vocoder sample rate."""
        if audio_sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=self.sample_rate)
            audio_sr = self.sample_rate
        speech_feat = mel_spectrogram(y=audio, **self.mel_conf).transpose(1, 2)  # (b=1, t, num_mels)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.long)
        return speech_feat, speech_feat_len

    def extract_spk_embedding(self, audio: torch.Tensor, audio_sr: int):
        """Extract a CAM++ speaker embedding from 16 kHz audio."""
        if audio_sr != 16000:
            audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=16000)
            audio_sr = 16000
        # 80-dim Kaldi fbank features, mean-normalized over time.
        feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)
        onnx_in = {
            self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()
        }
        embedding = self.campplus_session.run(None, onnx_in)[0].flatten().tolist()
        embedding = torch.tensor([embedding])  # (1, embed_dim)
        return embedding

    def extract_speech_token(self, audio: torch.Tensor, audio_sr: int):
        """Tokenize 16 kHz audio (at most 30 s) into discrete speech tokens."""
        if audio_sr != 16000:
            audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=16000)
            audio_sr = 16000
        assert (
            audio.shape[1] / 16000 <= 30
        ), 'speech token extraction does not support audio longer than 30s'
        # Whisper-style 128-bin log-mel features feed the ONNX tokenizer.
        feat = whisper.log_mel_spectrogram(audio, n_mels=128)
        onnx_in = {
            self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
            self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32),
        }
        speech_token = self.speech_tokenizer_session.run(None, onnx_in)[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32)
        return speech_token, speech_token_len
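

# A minimal usage sketch of the three extractors above. The model paths,
# input wav, and mel parameters below are hypothetical placeholders chosen
# for illustration, not values shipped with this repo; substitute your own
# checkpoints and a mel_conf matching your vocoder.
if __name__ == '__main__':
    mel_conf = dict(n_fft=1920, num_mels=80, sampling_rate=24000,
                    hop_size=480, win_size=1920, fmin=0, fmax=8000, center=False)
    frontend = CosyVoiceFrontEnd(
        mel_conf=mel_conf,
        campplus_model='campplus.onnx',                  # hypothetical path
        speech_tokenizer_model='speech_tokenizer.onnx',  # hypothetical path
    )
    audio, sr = torchaudio.load('prompt.wav')            # (1, T) mono waveform
    feat, feat_len = frontend.extract_speech_feat(audio, sr)
    spk_emb = frontend.extract_spk_embedding(audio, sr)
    tokens, token_len = frontend.extract_speech_token(audio, sr)
    print(feat.shape, spk_emb.shape, tokens.shape)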