import os
import re
import time

import torch
from huggingface_hub import hf_hub_download
from scipy.io import wavfile

from .utils.text import clean_text
from .vocos.decoder import SopranoDecoder


class SopranoTTS:
def __init__(self,
backend='auto',
device='cuda',
cache_size_mb=10,
decoder_batch_size=1):
RECOGNIZED_DEVICES = ['cuda']
RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']
assert device in RECOGNIZED_DEVICES, f"unrecognized device {device}, device must be in {RECOGNIZED_DEVICES}"
        if backend == 'auto':
            if device == 'cpu':  # unreachable while RECOGNIZED_DEVICES is ['cuda']; kept for future CPU support
                backend = 'transformers'
            else:
                try:
                    import lmdeploy  # noqa: F401 -- only probing for availability
                    backend = 'lmdeploy'
                except ImportError:
                    backend = 'transformers'
        print(f"Using backend {backend}.")
assert backend in RECOGNIZED_BACKENDS, f"unrecognized backend {backend}, backend must be in {RECOGNIZED_BACKENDS}"
if backend == 'lmdeploy':
from .backends.lmdeploy import LMDeployModel
self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
elif backend == 'transformers':
from .backends.transformers import TransformersModel
self.pipeline = TransformersModel(device=device)
        self.decoder = SopranoDecoder().cuda()
        decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
        self.decoder.load_state_dict(torch.load(decoder_path))
        self.decoder.eval()  # inference only
        self.decoder_batch_size = decoder_batch_size
        self.RECEPTIVE_FIELD = 4  # Decoder receptive field, in tokens
        self.TOKEN_SIZE = 2048  # Audio samples per token: 2048 / 32000 Hz = 64 ms
        self.infer("Hello world!")  # warmup
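
    # Example (sketch): explicit configuration. cache_size_mb is only passed to
    # the lmdeploy backend and is ignored by the transformers backend.
    #
    #   tts = SopranoTTS(backend='lmdeploy', device='cuda',
    #                    cache_size_mb=512, decoder_batch_size=8)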
def _preprocess_text(self, texts, min_length=30):
        '''
        Splits each text into sentences, adds the prompt format, and tags each
        sentence with its (text index, sentence index) position.
        Enforces a minimum sentence length by merging short sentences.
        '''
res = []
for text_idx, text in enumerate(texts):
text = text.strip()
cleaned_text = clean_text(text)
sentences = re.split(r"(?<=[.!?])\s+", cleaned_text)
processed = []
for sentence in sentences:
processed.append({
"text": sentence,
"text_idx": text_idx,
})
            if min_length > 0 and len(processed) > 1:
                # Merge sentences shorter than min_length into a neighbor:
                # prefer the previous sentence, otherwise the next one.
                merged = []
                i = 0
                while i < len(processed):
                    cur = processed[i]
                    if len(cur["text"]) < min_length:
                        if merged:
                            merged[-1]["text"] = (merged[-1]["text"] + " " + cur["text"]).strip()
                        elif i + 1 < len(processed):
                            processed[i + 1]["text"] = (cur["text"] + " " + processed[i + 1]["text"]).strip()
                        else:
                            merged.append(cur)
                    else:
                        merged.append(cur)
                    i += 1
                processed = merged
            sentence_idxes = {}  # running sentence counter per input text
            for item in processed:
                if item['text_idx'] not in sentence_idxes:
                    sentence_idxes[item['text_idx']] = 0
                res.append((f'[STOP][TEXT]{item["text"]}[START]', item["text_idx"], sentence_idxes[item['text_idx']]))
                sentence_idxes[item['text_idx']] += 1
return res
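
    # For illustration (exact splits depend on clean_text's normalization), two
    # input texts such as ["Hi. This is a longer sentence.", "Bye."] preprocess
    # to roughly:
    #   [('[STOP][TEXT]Hi. This is a longer sentence.[START]', 0, 0),
    #    ('[STOP][TEXT]Bye.[START]', 1, 0)]
    # where "Hi." is merged forward because it is shorter than min_length.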
def infer(self,
text,
out_path=None,
top_p=0.95,
temperature=0.3,
repetition_penalty=1.2):
results = self.infer_batch([text],
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
out_dir=None)[0]
if out_path:
wavfile.write(out_path, 32000, results.cpu().numpy())
return results
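
    # Example (sketch): single-utterance synthesis, assuming this package is
    # importable as `soprano` (hypothetical name).
    #
    #   from soprano import SopranoTTS
    #   tts = SopranoTTS()  # picks lmdeploy if installed, else transformers
    #   audio = tts.infer("Hello world!", out_path="hello.wav")
    #   # returns a 1-D float tensor; also writes hello.wav at 32 kHz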
def infer_batch(self,
texts,
out_dir=None,
top_p=0.95,
temperature=0.3,
repetition_penalty=1.2):
        sentence_data = self._preprocess_text(texts)
        prompts = [prompt for prompt, _, _ in sentence_data]
responses = self.pipeline.infer(prompts,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty)
        hidden_states = []
        for i, response in enumerate(responses):
            if response['finish_reason'] != 'stop':
                print(f"Warning: sentence {i} did not complete generation, likely due to hallucination.")
            hidden_states.append(response['hidden_state'])
        # Sort sentences by hidden-state length (descending) so that each
        # decoder batch contains similar lengths and padding is minimized.
        combined = list(zip(hidden_states, sentence_data))
        combined.sort(key=lambda x: -x[0].size(0))
        hidden_states, sentence_data = zip(*combined)
        num_texts = len(texts)
        # One slot per sentence, so decoded audio can be re-assembled in order.
        audio_concat = [[] for _ in range(num_texts)]
        for sentence in sentence_data:
            audio_concat[sentence[1]].append(None)
        for idx in range(0, len(hidden_states), self.decoder_batch_size):
            batch_hidden_states = []
            lengths = [h.size(0) for h in hidden_states[idx:idx+self.decoder_batch_size]]
            N = len(lengths)
            for i in range(N):
                # Left-pad every item to the longest length in the batch
                # (lengths[0], since sentences are sorted by descending length).
                # 512 is the hidden-state dimension the decoder expects.
                batch_hidden_states.append(torch.cat([
                    torch.zeros((1, 512, lengths[0]-lengths[i]), device='cuda'),
                    hidden_states[idx+i].unsqueeze(0).transpose(1,2).cuda().to(torch.float32),
                ], dim=2))
            batch_hidden_states = torch.cat(batch_hidden_states)
            with torch.no_grad():
                audio = self.decoder(batch_hidden_states)
            for i in range(N):
                text_id = sentence_data[idx+i][1]
                sentence_id = sentence_data[idx+i][2]
                # Keep the final (lengths[i] - 1) * TOKEN_SIZE samples, i.e.
                # everything not generated from the padded prefix or the first token.
                audio_concat[text_id][sentence_id] = audio[i].squeeze()[-(lengths[i]*self.TOKEN_SIZE-self.TOKEN_SIZE):]
        audio_concat = [torch.cat(x).cpu() for x in audio_concat]
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            for i in range(len(audio_concat)):
                wavfile.write(f"{out_dir}/{i}.wav", 32000, audio_concat[i].numpy())
        return audio_concat
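
    # Example (sketch): batch synthesis; writes out/0.wav, out/1.wav, ...
    #
    #   waves = tts.infer_batch(["First text.", "Second text."], out_dir="out")
    #   # waves[i] is a 1-D float tensor at 32 kHz corresponding to texts[i]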
def infer_stream(self,
text,
chunk_size=1,
top_p=0.95,
temperature=0.3,
repetition_penalty=1.2):
start_time = time.time()
sentence_data = self._preprocess_text([text])
first_chunk = True
for sentence, _, _ in sentence_data:
responses = self.pipeline.stream_infer(sentence,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty)
            hidden_states_buffer = []
            chunk_counter = chunk_size
            for token in responses:
                finished = token['finish_reason'] is not None
                if not finished:
                    hidden_states_buffer.append(token['hidden_state'][-1])
                # Keep only the context the decoder needs: the chunk being
                # emitted plus 2 * RECEPTIVE_FIELD tokens of surrounding context.
                hidden_states_buffer = hidden_states_buffer[-(2*self.RECEPTIVE_FIELD+chunk_size):]
                if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
                    if finished or chunk_counter == chunk_size:
                        if not hidden_states_buffer:
                            break  # nothing buffered (e.g. generation stopped immediately)
                        batch_hidden_states = torch.stack(hidden_states_buffer)
                        inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).cuda().to(torch.float32)
                        with torch.no_grad():
                            audio = self.decoder(inp)[0]
                        if finished:
                            # Flush every sample not yet emitted.
                            audio_chunk = audio[-((self.RECEPTIVE_FIELD+chunk_counter-1)*self.TOKEN_SIZE-self.TOKEN_SIZE):]
                        else:
                            # Emit chunk_size tokens of audio, offset from the end by
                            # the receptive field so the decoder has right-hand context.
                            audio_chunk = audio[-((self.RECEPTIVE_FIELD+chunk_size)*self.TOKEN_SIZE-self.TOKEN_SIZE):-(self.RECEPTIVE_FIELD*self.TOKEN_SIZE-self.TOKEN_SIZE)]
                        chunk_counter = 0
                        if first_chunk:
                            print(f"Streaming latency: {1000*(time.time()-start_time):.2f} ms")
                            first_chunk = False
                        yield audio_chunk.cpu()
                    chunk_counter += 1
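
# Example (sketch): consuming the stream by concatenating chunks as they
# arrive. Writing with scipy here; a real-time consumer would hand each
# chunk to an audio callback instead.
#
#   chunks = list(tts.infer_stream("A long paragraph of text.", chunk_size=2))
#   wavfile.write("stream.wav", 32000, torch.cat(chunks).numpy())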