from functools import cached_property, reduce
from typing import List, Optional, Union
from copy import deepcopy
from collections import defaultdict
import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from hyperpyyaml import load_hyperpyyaml
from stepvocoder.cosyvoice2.cli.frontend import CosyVoiceFrontEnd
from stepvocoder.cosyvoice2.flow.flow import CausalMaskedDiffWithXvec
from stepvocoder.cosyvoice2.hifigan.generator import HiFTGenerator
from stepvocoder.cosyvoice2.bigvgan.bigvgan import BigVGAN
# from stepvocoder.cosyvoice2.utils.common import fade_in_out
import threading
"""perform fade_in_out in tensor style
"""
def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
mel_overlap_len = int(window.shape[0] / 2)
fade_in_mel = fade_in_mel.clone()
fade_in_mel[..., :mel_overlap_len] = \
fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel
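# A minimal sketch of the crossfade above (window length and shapes are illustrative
# assumptions, not values used elsewhere in this module):
#   window = torch.from_numpy(np.hamming(2 * 4))   # 4-sample overlap
#   prev_tail = torch.ones(1, 8)                   # cached tail of the previous chunk
#   new_chunk = torch.zeros(1, 16)                 # freshly vocoded chunk
#   smoothed = fade_in_out(new_chunk, prev_tail, window)
#   # smoothed[..., :4] ramps down from prev_tail into new_chunk; the rest is unchanged.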
# torch._dynamo.config.cache_size_limit = 128
# torch._dynamo.config.accumulated_cache_size_limit = 128
"""
A wrapper for managing stream caches.
"""
class CosyVoice_stream_impl_(torch.nn.Module):
def __init__(self,
flow: CausalMaskedDiffWithXvec,
hift: Union[HiFTGenerator, BigVGAN],
chunk_size_list: List = [15, 24, 48], # (0.6s, 0.96s, 1.92s)
mel_cache_len: int = 8,
n_timesteps: int = 10, # for both stream/non-stream
):
super().__init__()
self.flow = flow
self.hift = hift
self.n_timesteps = n_timesteps
        # NOTE: the output sample rate is hard coded below (480 samples per 50 Hz mel frame -> 24 kHz)
        # self.sample_rate = hift.sampling_rate
self.token_lookahead = flow.pre_lookahead_len
# stream conf
self.mel_cache_len = mel_cache_len
if isinstance(self.hift, BigVGAN):
            # BigVGAN consumes 3 left and 3 right mel frames as context
            self.source_cache_len = int((mel_cache_len - 6) * 480)  # 50hz mel -> 24k wave
elif isinstance(self.hift, HiFTGenerator):
self.source_cache_len = int(mel_cache_len * 480) # 50hz mel -> 24k wave
else:
raise ValueError(f'unsupported vocoder type {type(self.hift)}')
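        # crossfade window over the cached overlap: the rising first half fades the new
        # chunk in, the falling second half fades the cached previous tail out (see fade_in_out)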
self.register_buffer('speech_window', torch.from_numpy(np.hamming(2 * self.source_cache_len)), persistent=False)
# session management
self.speech_token_dict = defaultdict(list)
self.chunk_size_list = chunk_size_list
self.chunk_size_dict = {}
self.b_first_chunk_dict = {} # indicate if it's the first chunk of this session
# hifigan cache
self.hift_cache_dict = {}
# model att/cnn cache
self.chunk_cache_dict = {}
self.estimator_prompt_length_dict = {}
# speaker embedding cache
self.spk_embedding_cache_dict = {}
# setup lock
self.setup_lock = threading.Lock()
@cached_property
def device(self):
return next(self.hift.parameters()).device
@cached_property
def dtype(self):
return next(self.hift.parameters()).dtype
"""NOTE Non-stream interface.
"""
def token2wav_nonstream(self,
token: torch.Tensor,
prompt_token: torch.Tensor,
prompt_feat: torch.Tensor,
embedding: torch.Tensor,
):
def _make_len(ts:torch.Tensor):
return torch.tensor([ts.shape[1]], dtype=torch.long, device=ts.device)
# [02, 02, 06, 06, 06] -> [[02, 02, PAD], [06, 06, 06]]
token = self._reshape(
token.squeeze().tolist()
).unsqueeze(0)
prompt_token = self._reshape(
prompt_token.squeeze().tolist()
).unsqueeze(0)
# align prompt mel
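        # (mel frames run at twice the time-aligned token rate, hence the 2x length)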
prompt_feat = F.interpolate(
prompt_feat.transpose(1, 2),
size=prompt_token.shape[1]*2,
mode='nearest'
).transpose(1, 2)
token, prompt_token, prompt_feat, embedding = map(
lambda ts: ts.to(self.device),
(token, prompt_token, prompt_feat, embedding),
)
# inference flow
mel = self.flow.inference(
token,
_make_len(token),
prompt_token,
_make_len(prompt_token),
prompt_feat.to(self.dtype),
_make_len(prompt_feat),
embedding.to(self.dtype),
self.n_timesteps,
)
# inference vocoder
with torch.no_grad():
            if isinstance(self.hift, BigVGAN):
                # BigVGAN expects 3 extra frames of context on each side
                mel = F.pad(mel, (3, 3), mode='reflect')
speech = self.hift.inference(mel).squeeze(0) # [1,1,T] -> [1,T]
elif isinstance(self.hift, HiFTGenerator):
speech, _ = self.hift.inference(mel)
else:
raise ValueError(f'unsupported vocoder type {type(self.hift)}')
speech = speech.cpu().to(torch.float32)
return speech
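    # A hedged usage sketch of the non-stream path (tensor shapes are assumptions
    # inferred from the code above, with `impl` an instance of this class):
    #   speech = impl.token2wav_nonstream(
    #       token=mixed_tokens,        # (1, 5*k) int64 mixed vq0206 sequence
    #       prompt_token=prompt_tok,   # (1, 5*m) int64 mixed vq0206 sequence
    #       prompt_feat=prompt_mel,    # (1, T_mel, num_mels) float
    #       embedding=spk_emb,         # speaker embedding from the frontend
    #   )                              # -> (1, T_wav) float32 waveform on CPU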
"""NOTE Internal method, do not call this method!
Handle device & dtype transfer.
"""
def _setup_cache(self,
token: torch.Tensor,
mel: torch.Tensor,
spk: torch.Tensor,
session_id: str,
):
# att/cnn-cache
with self.setup_lock:
cache = self.flow.setup_cache(
token.to(self.device),
mel.to(self.device, self.dtype),
spk.to(self.device, self.dtype),
self.n_timesteps,
)
        # clone().detach() every tensor in the cache dict
cache = {k: (v.clone().detach() if isinstance(v, torch.Tensor) else v) for k, v in cache.items()}
self.chunk_cache_dict[session_id] = cache
self.estimator_prompt_length_dict[session_id] = mel.shape[1]
self.b_first_chunk_dict[session_id] = True
# spk embedding
self.spk_embedding_cache_dict[session_id] = spk.to(self.device, self.dtype).clone()
# hift cache
self.hift_cache_dict[session_id] = dict(
mel = torch.zeros(1, mel.shape[2], 0, device=self.device, dtype=self.dtype),
source = torch.zeros(1, 1, 0, device=self.device, dtype=self.dtype),
speech = torch.zeros(1, 0, device=self.device, dtype=self.dtype),
)
return
"""NOTE Internal method, do not call this method!
Handle device transfer.
"""
def _token2wav_stream(self,
token: torch.Tensor,
session_id: str,
last_chunk: bool,
):
assert session_id in self.chunk_cache_dict, 'call setup_cache first to obtain cache'
# fetch cache & speaker embedding
cache = self.chunk_cache_dict[session_id]
embedding = self.spk_embedding_cache_dict[session_id]
# inference this chunk
mel, new_cache = self.flow.inference_chunk(
token.to(self.device), # int64
embedding,
cache,
last_chunk,
self.n_timesteps,
)
        # NOTE(sfy) truncate attention cache to prompt_length + ~2s of left context
        # (the prompt occupies the leading positions, the most recent frames the trailing ones)
        left_context_length = int(2 * 48)
        estimator_att_cache = new_cache['estimator_att_cache']
        prompt_length = self.estimator_prompt_length_dict[session_id]
        if estimator_att_cache.shape[4] > (prompt_length + left_context_length):
            new_cache['estimator_att_cache'] = torch.cat([
                estimator_att_cache[:, :, :, :, :prompt_length],
                estimator_att_cache[:, :, :, :, -left_context_length:],
            ], dim=4)
self.chunk_cache_dict[session_id] = {k: v.clone().detach() for k, v in new_cache.items()}
# vocoder cache
hift_cache_mel = self.hift_cache_dict[session_id]['mel']
hift_cache_source = self.hift_cache_dict[session_id]['source']
hift_cache_speech = self.hift_cache_dict[session_id]['speech']
mel = torch.concat([hift_cache_mel, mel], dim=2)
# inference vocoder
with torch.no_grad():
if isinstance(self.hift, BigVGAN):
if self.b_first_chunk_dict[session_id] and mel.shape[2] > 0:
print(f'[INFO] first chunk mel len: {mel.shape[2]}')
self.b_first_chunk_dict[session_id] = False
mel = F.pad(mel, (3,0), mode='reflect')
if last_chunk:
mel = F.pad(mel, (0,3), mode='reflect')
speech = self.hift.inference(mel).squeeze(0) # [1,1,T] -> [1,T]
source = torch.zeros(1, 1, 0, device=self.device, dtype=self.dtype) # dummy source
elif isinstance(self.hift, HiFTGenerator):
speech, source = self.hift.inference(mel, hift_cache_source)
# overlap speech smooth
if hift_cache_speech.shape[-1] > 0:
speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
# update vocoder cache
self.hift_cache_dict[session_id] = dict(
mel = mel[..., -self.mel_cache_len:].clone().detach(),
source = source[:, :, -self.source_cache_len:].clone().detach(),
speech = speech[:, -self.source_cache_len:].clone().detach(),
)
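        # hold back the last source_cache_len samples (except on the final chunk): they are
        # re-synthesised from the cached mel and crossfaded into the head of the next chunk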
if not last_chunk:
speech = speech[:, :-self.source_cache_len]
return speech.cpu().to(torch.float32)
@staticmethod
def _reshape(mix_seq: List[int])->torch.Tensor:
        # assert len(mix_seq) % 5 == 0, len(mix_seq)
        # NOTE pad instead of asserting, to avoid a hard error here
        # (the final speech is wrong in that case anyway, so we do not care about it)
        if len(mix_seq) % 5 > 0:
            pad_len = 5 - (len(mix_seq) % 5)
            mix_seq = mix_seq + [0, 0, 0, 1024, 1024, 1024][-pad_len:]
num_groups = len(mix_seq) // 5
vq02 = reduce(
lambda x, y: x+y,
[mix_seq[i*5: i*5+2] + [1024] for i in range(num_groups)]
)
vq06 = reduce(
lambda x, y: x+y,
[mix_seq[i*5+2: i*5+5] for i in range(num_groups)]
)
vq0206 = torch.stack([
torch.tensor(vq02, dtype=torch.long),
torch.tensor(vq06, dtype=torch.long)-1024+1025,
], dim=1)
return vq0206
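    # Worked example of _reshape (token values are illustrative): with a/c vq02 tokens
    # and b/d vq06 tokens,
    #   _reshape([a, a, b, b, b, c, c, d, d, d])
    # returns a (6, 2) LongTensor
    #   [[a, b+1], [a, b+1], [1024, b+1], [c, d+1], [c, d+1], [1024, d+1]]
    # i.e. each group of 5 mixed tokens becomes 3 time-aligned frames, 1024 pads the
    # vq02 column and the vq06 column is offset by +1 (-1024+1025).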
"""NOTE Stream interface. Called whenever one token is generated.
NOTE(sfy) not need to transfer device or dtype
This is a specialized version for vq0206, we change the mixed sequence to time-aligned sequence.
eg.: [02, 02, 06, 06, 06] -> [[02, 02, PAD], [06, 06, 06]]
"""
def token2wav_stream(self,
token: List[int], # vq0206 mixed seq tokens
prompt_token: torch.Tensor,
prompt_feat: torch.Tensor,
embedding: torch.Tensor,
session_id: str,
last_chunk: bool,
)->Optional[torch.Tensor]:
        # FIXME hard coded: every 3 time-aligned frames correspond to 5 mixed vq0206 tokens
        def _mixed_len(l: int):
            return (l // 3) * 5
# init chunk size tracking
if session_id not in self.chunk_size_dict:
self.chunk_size_dict[session_id] = deepcopy(self.chunk_size_list)
# add token
self.speech_token_dict[session_id].extend(token)
# waiting to setup cache
mix_token_lookahead_len = _mixed_len(self.token_lookahead)
if session_id not in self.chunk_cache_dict:
if len(self.speech_token_dict[session_id]) >= mix_token_lookahead_len:
# [02, 02, 06, 06, 06] -> [[02, 02, PAD], [06, 06, 06]]
lookahead_token = self._reshape(
self.speech_token_dict[session_id][:mix_token_lookahead_len]
).unsqueeze(0) # (1, t, 2)
prompt_token = self._reshape(
prompt_token.squeeze().tolist()
).unsqueeze(0)
# align prompt mel
prompt_feat = F.interpolate(
prompt_feat.transpose(1, 2),
size=prompt_token.shape[1]*2,
mode='nearest'
).transpose(1, 2)
self._setup_cache(
torch.cat([prompt_token, lookahead_token], dim=1),
prompt_feat,
embedding,
session_id,
)
return None
# deal with remaining tokens
if last_chunk:
this_token = self.speech_token_dict[session_id]
else:
# cut to one chunk
this_token = None
mix_token_chunk_len = _mixed_len(self.chunk_size_dict[session_id][0])
if len(self.speech_token_dict[session_id]) >= (mix_token_chunk_len+mix_token_lookahead_len):
this_token = self.speech_token_dict[session_id][:(mix_token_chunk_len+mix_token_lookahead_len)]
self.speech_token_dict[session_id] = self.speech_token_dict[session_id][mix_token_chunk_len:]
# go synthesis
if this_token is not None:
# [02, 02, 06, 06, 06] -> [[02, 02, PAD], [06, 06, 06]]
this_token = self._reshape(this_token).unsqueeze(0)
this_speech = self._token2wav_stream(
this_token,
session_id,
last_chunk,
)
# update chunk size
if len(self.chunk_size_dict[session_id]) > 1:
self.chunk_size_dict[session_id].pop(0)
else:
this_speech = None
# clear all caches
if last_chunk:
self.clean_up(session_id)
return this_speech
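    # A hedged streaming usage sketch (the token generator, sink and session id are
    # assumptions; prompt tensors come from the frontend as in the non-stream path):
    #   for mixed_tokens in token_generator():
    #       chunk = impl.token2wav_stream(mixed_tokens, prompt_tok, prompt_mel, spk_emb,
    #                                     session_id='sess-0', last_chunk=False)
    #       if chunk is not None:
    #           sink(chunk)
    #   tail = impl.token2wav_stream([], prompt_tok, prompt_mel, spk_emb,
    #                                session_id='sess-0', last_chunk=True)  # flush + clean up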
def clean_up(self, session_id: str):
self.chunk_size_dict.pop(session_id, None)
self.hift_cache_dict.pop(session_id, None)
self.chunk_cache_dict.pop(session_id, None)
self.estimator_prompt_length_dict.pop(session_id, None)
self.spk_embedding_cache_dict.pop(session_id, None)
self.speech_token_dict.pop(session_id, None)
torch.cuda.empty_cache()
"""Keep compatible with cosyvoice1
"""
class CosyVoice:
def __init__(self,
model_dir:str,
chunk_size_list: List = [15, 24, 48], # (0.6s, 0.96s, 1.92s)
mel_cache_len: int = 8,
n_timesteps: int = 10,
enable_cuda_graph: bool = True,
dtype=torch.float32,
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.dtype = dtype
# initiate streaming wrapper
self.model_dir = model_dir
with open("{}/cosyvoice.yaml".format(model_dir), "r") as f:
configs = load_hyperpyyaml(f)
flow, hift = configs['flow'], configs['hift']
mel_conf = configs['mel_conf']
flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location='cpu'))
flow = flow.eval()
hift.load_state_dict(torch.load(f"{model_dir}/hift.pt", map_location='cpu'))
hift = hift.eval()
cosy_impl = CosyVoice_stream_impl_(flow, hift, chunk_size_list, mel_cache_len, n_timesteps)
self.cosy_impl = cosy_impl.to(self.device, self.dtype)
if enable_cuda_graph:
self.cosy_impl.flow.scatter_cuda_graph(enable_cuda_graph)
self.cosy_impl.hift._init_cuda_graph()
# feature frontend
self.frontend = CosyVoiceFrontEnd(
mel_conf,
campplus_model='{}/campplus.onnx'.format(model_dir),
speech_tokenizer_model='{}/speech_tokenizer_v1.onnx'.format(model_dir),
)
# Just proxy
def token2wav_nonstream(self,
token: torch.Tensor, # vq0206 mixed seq
prompt_token: torch.Tensor,
prompt_feat: torch.Tensor,
embedding: torch.Tensor,
)->torch.Tensor:
return self.cosy_impl.token2wav_nonstream(
token,
prompt_token,
prompt_feat,
embedding,
)
# Just proxy
def token2wav_stream(self,
token: List[int], # vq0206 mixed seq tokens
prompt_token: torch.Tensor,
prompt_feat: torch.Tensor,
embedding: torch.Tensor,
session_id: str,
last_chunk: bool,
)->Optional[torch.Tensor]:
return self.cosy_impl.token2wav_stream(
token,
prompt_token,
prompt_feat,
embedding,
session_id,
last_chunk,
)
def clean_up(self, session_id: str):
self.cosy_impl.clean_up(session_id)
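# A hedged end-to-end sketch of the compatibility wrapper (the model directory path,
# the upstream token source and the output filename are assumptions):
#   cosy = CosyVoice('/path/to/model_dir', enable_cuda_graph=False)
#   # prompt_tok / prompt_mel / spk_emb extracted via cosy.frontend, mixed_tokens from the LM
#   wav = cosy.token2wav_nonstream(mixed_tokens, prompt_tok, prompt_mel, spk_emb)
#   torchaudio.save('output.wav', wav, 24000)  # 24 kHz per the mel->wave comments above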