import os
import re
import time

import torch
from huggingface_hub import hf_hub_download
from scipy.io import wavfile

from .utils.text import clean_text
from .vocos.decoder import SopranoDecoder
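
# Example usage (illustrative; assumes a CUDA device and access to the
# ekwek/Soprano-80M checkpoint on the Hugging Face Hub):
#
#   tts = SopranoTTS()
#   audio = tts.infer("Hello world!", out_path="hello.wav")      # 1-D tensor, 32 kHz
#   clips = tts.infer_batch(["First text.", "Second text."], out_dir="out")
#   for chunk in tts.infer_stream("Streaming speech synthesis."):
#       ...  # consume audio chunks (chunk_size tokens, 64 ms each) as they arrive
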
class SopranoTTS:
def __init__(self,
backend='auto',
device='cuda',
cache_size_mb=10,
decoder_batch_size=1):
RECOGNIZED_DEVICES = ['cuda']
RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']
assert device in RECOGNIZED_DEVICES, f"unrecognized device {device}, device must be in {RECOGNIZED_DEVICES}"
        if backend == 'auto':
            if device == 'cpu':  # reserved for future CPU support; unreachable while RECOGNIZED_DEVICES is ['cuda']
                backend = 'transformers'
            else:
                # Prefer lmdeploy when it is installed, otherwise fall back to transformers.
                try:
                    import lmdeploy  # noqa: F401 -- imported only to probe availability
                    backend = 'lmdeploy'
                except ImportError:
                    backend = 'transformers'
print(f"Using backend {backend}.")
assert backend in RECOGNIZED_BACKENDS, f"unrecognized backend {backend}, backend must be in {RECOGNIZED_BACKENDS}"
if backend == 'lmdeploy':
from .backends.lmdeploy import LMDeployModel
self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
elif backend == 'transformers':
from .backends.transformers import TransformersModel
self.pipeline = TransformersModel(device=device)
        # Fetch the pretrained decoder weights from the Hugging Face Hub (cached after the first download).
        self.decoder = SopranoDecoder().cuda()
        decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
        self.decoder.load_state_dict(torch.load(decoder_path))
        self.decoder.eval()  # inference mode, in case the decoder uses dropout or batch norm
        self.decoder_batch_size = decoder_batch_size
        self.RECEPTIVE_FIELD = 4    # Decoder receptive field, in tokens
        self.TOKEN_SIZE = 2048      # Audio samples per token: 2048 samples = 64 ms at 32 kHz
        self.infer("Hello world!")  # Warm-up pass so the first real call avoids one-time initialization cost
def _preprocess_text(self, texts, min_length=30):
        '''
        Adds the prompt format and (text_idx, sentence_idx) bookkeeping to each
        sentence, and enforces a minimum sentence length by merging short
        sentences into their neighbors.
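
        Example (illustrative; assumes clean_text leaves this input unchanged):
            >>> tts._preprocess_text(["Hi. This is a longer second sentence."])
            [('[STOP][TEXT]Hi. This is a longer second sentence.[START]', 0, 0)]
        The short "Hi." is merged into the following sentence because it is
        under min_length characters.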
'''
res = []
for text_idx, text in enumerate(texts):
text = text.strip()
cleaned_text = clean_text(text)
sentences = re.split(r"(?<=[.!?])\s+", cleaned_text)
processed = []
for sentence in sentences:
processed.append({
"text": sentence,
"text_idx": text_idx,
})
        if min_length > 0 and len(processed) > 1:
            merged = []
            i = 0
            while i < len(processed):
                cur = processed[i]
                if len(cur["text"]) < min_length:
                    # Merge a too-short sentence backward into the previous kept
                    # sentence, or forward into the next one if nothing is kept yet.
                    if merged:
                        merged[-1]["text"] = (merged[-1]["text"] + " " + cur["text"]).strip()
                    elif i + 1 < len(processed):
                        processed[i + 1]["text"] = (cur["text"] + " " + processed[i + 1]["text"]).strip()
                    else:
                        merged.append(cur)
                else:
                    merged.append(cur)
                i += 1
            processed = merged
        # Assign a running sentence index within each input text.
        sentence_idxes = {}
        for item in processed:
            sentence_idx = sentence_idxes.setdefault(item['text_idx'], 0)
            res.append((f'[STOP][TEXT]{item["text"]}[START]', item['text_idx'], sentence_idx))
            sentence_idxes[item['text_idx']] += 1
return res
def infer(self,
text,
out_path=None,
top_p=0.95,
temperature=0.3,
repetition_penalty=1.2):
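        '''
        Synthesize a single text. Returns a 1-D CPU tensor of audio samples
        at 32 kHz; if out_path is given, the audio is also written there as a WAV file.
        '''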
results = self.infer_batch([text],
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
out_dir=None)[0]
if out_path:
wavfile.write(out_path, 32000, results.cpu().numpy())
return results
def infer_batch(self,
texts,
out_dir=None,
top_p=0.95,
temperature=0.3,
repetition_penalty=1.2):
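        '''
        Synthesize a batch of texts. Returns a list of 1-D CPU tensors of audio
        samples at 32 kHz, one per input text; if out_dir is given, each result
        is also written there as <index>.wav.
        '''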
sentence_data = self._preprocess_text(texts)
        prompts = [x[0] for x in sentence_data]
responses = self.pipeline.infer(prompts,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty)
        hidden_states = []
        for i, response in enumerate(responses):
            if response['finish_reason'] != 'stop':
                print(f"Warning: sentence {i} did not complete generation, likely due to hallucination.")
            hidden_states.append(response['hidden_state'])
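        # Sort sentences by descending hidden-state length so each decoder batch
        # groups similar lengths, minimizing the zero-padding added below.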
combined = list(zip(hidden_states, sentence_data))
combined.sort(key=lambda x: -x[0].size(0))
hidden_states, sentence_data = zip(*combined)
num_texts = len(texts)
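        # Preallocate one slot per sentence so decoded audio can be written back
        # in original text/sentence order even though decoding runs in sorted order.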
audio_concat = [[] for _ in range(num_texts)]
for sentence in sentence_data:
audio_concat[sentence[1]].append(None)
        for idx in range(0, len(hidden_states), self.decoder_batch_size):
            batch_hidden_states = []
            lengths = [hs.size(0) for hs in hidden_states[idx:idx + self.decoder_batch_size]]
            N = len(lengths)
            for i in range(N):
                # Left-pad each sequence with zeros (512 = hidden-state feature size)
                # to match the longest (first) sequence in the batch.
                batch_hidden_states.append(torch.cat([
                    torch.zeros((1, 512, lengths[0] - lengths[i]), device='cuda'),
                    hidden_states[idx + i].unsqueeze(0).transpose(1, 2).cuda().to(torch.float32),
                ], dim=2))
batch_hidden_states = torch.cat(batch_hidden_states)
with torch.no_grad():
audio = self.decoder(batch_hidden_states)
            for i in range(N):
                text_id = sentence_data[idx + i][1]
                sentence_id = sentence_data[idx + i][2]
                # Keep the last (lengths[i] - 1) tokens' worth of samples, trimming
                # off both the zero-padding and the first token of audio.
                audio_concat[text_id][sentence_id] = audio[i].squeeze()[-(lengths[i] * self.TOKEN_SIZE - self.TOKEN_SIZE):]
audio_concat = [torch.cat(x).cpu() for x in audio_concat]
if out_dir:
os.makedirs(out_dir, exist_ok=True)
            for i, clip in enumerate(audio_concat):
                wavfile.write(f"{out_dir}/{i}.wav", 32000, clip.numpy())
return audio_concat
def infer_stream(self,
text,
chunk_size=1,
top_p=0.95,
temperature=0.3,
repetition_penalty=1.2):
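        '''
        Stream synthesis for a single text, yielding CPU audio chunks as they
        are decoded. Interior chunks hold chunk_size tokens of audio
        (chunk_size * TOKEN_SIZE samples at 32 kHz); the final chunk flushes
        whatever remains.
        '''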
start_time = time.time()
sentence_data = self._preprocess_text([text])
first_chunk = True
for sentence, _, _ in sentence_data:
responses = self.pipeline.stream_infer(sentence,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty)
            hidden_states_buffer = []
            chunk_counter = chunk_size
            for token in responses:
                finished = token['finish_reason'] is not None
                if not finished:
                    hidden_states_buffer.append(token['hidden_state'][-1])
                # Keep a sliding window just large enough to give the decoder
                # RECEPTIVE_FIELD tokens of context on each side of an emitted chunk.
                hidden_states_buffer = hidden_states_buffer[-(2 * self.RECEPTIVE_FIELD + chunk_size):]
                if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
                    if finished or chunk_counter == chunk_size:
                        batch_hidden_states = torch.stack(hidden_states_buffer)
                        inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).cuda().to(torch.float32)
                        with torch.no_grad():
                            audio = self.decoder(inp)[0]
                        if finished:
                            # Final chunk: flush all samples generated since the last yield.
                            audio_chunk = audio[-((self.RECEPTIVE_FIELD + chunk_counter - 1) * self.TOKEN_SIZE - self.TOKEN_SIZE):]
                        else:
                            # Interior chunk: emit chunk_size tokens of audio, discarding the
                            # trailing tokens that only provide decoder context.
                            audio_chunk = audio[-((self.RECEPTIVE_FIELD + chunk_size) * self.TOKEN_SIZE - self.TOKEN_SIZE):-(self.RECEPTIVE_FIELD * self.TOKEN_SIZE - self.TOKEN_SIZE)]
                        chunk_counter = 0
if first_chunk:
print(f"Streaming latency: {1000*(time.time()-start_time):.2f} ms")
first_chunk = False
yield audio_chunk.cpu()
chunk_counter += 1