from .vocos.decoder import SopranoDecoder
from .utils.text import clean_text
import torch
import re
from scipy.io import wavfile
from huggingface_hub import hf_hub_download
import os
import time


class SopranoTTS:
    '''Text-to-speech pipeline: an LM backend generates hidden states per
    sentence, and the Soprano decoder converts them into 32 kHz audio.'''

    def __init__(self, backend='auto', device='cuda', cache_size_mb=10, decoder_batch_size=1):
        RECOGNIZED_DEVICES = ['cuda']
        RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']
        assert device in RECOGNIZED_DEVICES, f"unrecognized device {device}, device must be in {RECOGNIZED_DEVICES}"
        if backend == 'auto':
            if device == 'cpu':
                backend = 'transformers'
            else:
                # Prefer lmdeploy when it is installed; otherwise fall back
                # to transformers.
                try:
                    import lmdeploy
                    backend = 'lmdeploy'
                except ImportError:
                    backend = 'transformers'
        print(f"Using backend {backend}.")
        assert backend in RECOGNIZED_BACKENDS, f"unrecognized backend {backend}, backend must be in {RECOGNIZED_BACKENDS}"
        if backend == 'lmdeploy':
            from .backends.lmdeploy import LMDeployModel
            self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
        elif backend == 'transformers':
            from .backends.transformers import TransformersModel
            self.pipeline = TransformersModel(device=device)
        self.decoder = SopranoDecoder().cuda()
        decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
        self.decoder.load_state_dict(torch.load(decoder_path))
        self.decoder_batch_size = decoder_batch_size
        self.RECEPTIVE_FIELD = 4    # Decoder receptive field, in tokens
        self.TOKEN_SIZE = 2048      # Number of audio samples per token
        self.infer("Hello world!")  # Warmup run

    def _preprocess_text(self, texts, min_length=30):
        '''
        Splits each text into sentences, wraps them in the prompt format,
        and tags each with its (text index, sentence index).
        Enforces a minimum sentence length by merging short sentences into
        their neighbors.
        '''
        res = []
        for text_idx, text in enumerate(texts):
            text = text.strip()
            cleaned_text = clean_text(text)
            sentences = re.split(r"(?<=[.!?])\s+", cleaned_text)
            processed = [{"text": sentence, "text_idx": text_idx} for sentence in sentences]
            if min_length > 0 and len(processed) > 1:
                merged = []
                i = 0
                while i < len(processed):
                    cur = processed[i]
                    if len(cur["text"]) < min_length:
                        if merged:
                            # Merge a short sentence into the previous one.
                            merged[-1]["text"] = (merged[-1]["text"] + " " + cur["text"]).strip()
                        elif i + 1 < len(processed):
                            # No previous sentence: prepend to the next one.
                            processed[i + 1]["text"] = (cur["text"] + " " + processed[i + 1]["text"]).strip()
                        else:
                            merged.append(cur)
                    else:
                        merged.append(cur)
                    i += 1
                processed = merged
            sentence_idxes = {}
            for item in processed:
                if item['text_idx'] not in sentence_idxes:
                    sentence_idxes[item['text_idx']] = 0
                res.append((f'[STOP][TEXT]{item["text"]}[START]', item["text_idx"], sentence_idxes[item['text_idx']]))
                sentence_idxes[item['text_idx']] += 1
        return res

    def infer(self, text, out_path=None, top_p=0.95, temperature=0.3, repetition_penalty=1.2):
        results = self.infer_batch([text], top_p=top_p, temperature=temperature,
                                   repetition_penalty=repetition_penalty, out_dir=None)[0]
        if out_path:
            wavfile.write(out_path, 32000, results.cpu().numpy())
        return results

    def infer_batch(self, texts, out_dir=None, top_p=0.95, temperature=0.3, repetition_penalty=1.2):
        sentence_data = self._preprocess_text(texts)
        prompts = [x[0] for x in sentence_data]
        responses = self.pipeline.infer(prompts, top_p=top_p, temperature=temperature,
                                        repetition_penalty=repetition_penalty)
        hidden_states = []
        for response in responses:
            if response['finish_reason'] != 'stop':
                print("Warning: some sentences did not complete generation, likely due to hallucination.")
            hidden_states.append(response['hidden_state'])
        # Sort sentences by descending hidden-state length so each decoder
        # batch pads against its own longest member, minimizing wasted padding.
        combined = list(zip(hidden_states, sentence_data))
        combined.sort(key=lambda x: -x[0].size(0))
        hidden_states, sentence_data = zip(*combined)
        num_texts = len(texts)
        audio_concat = [[] for _ in range(num_texts)]
        for sentence in sentence_data:
            audio_concat[sentence[1]].append(None)
        for idx in range(0, len(hidden_states), self.decoder_batch_size):
            batch_hidden_states = []
            lengths = [x.size(0) for x in hidden_states[idx:idx + self.decoder_batch_size]]
            N = len(lengths)
            for i in range(N):
                # Left-pad each item with zeros to match the longest (first)
                # item in this batch.
                batch_hidden_states.append(torch.cat([
                    torch.zeros((1, 512, lengths[0] - lengths[i]), device='cuda'),
                    hidden_states[idx + i].unsqueeze(0).transpose(1, 2).cuda().to(torch.float32),
                ], dim=2))
            batch_hidden_states = torch.cat(batch_hidden_states)
            with torch.no_grad():
                audio = self.decoder(batch_hidden_states)
            for i in range(N):
                text_id = sentence_data[idx + i][1]
                sentence_id = sentence_data[idx + i][2]
                # Keep the trailing (length - 1) tokens of audio, dropping the
                # zero-padded prefix and the first token.
                audio_concat[text_id][sentence_id] = audio[i].squeeze()[-(lengths[i] * self.TOKEN_SIZE - self.TOKEN_SIZE):]
        # Reassemble each text's sentences in their original order.
        audio_concat = [torch.cat(x).cpu() for x in audio_concat]
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            for i in range(len(audio_concat)):
                wavfile.write(f"{out_dir}/{i}.wav", 32000, audio_concat[i].numpy())
        return audio_concat

    def infer_stream(self, text, chunk_size=1, top_p=0.95, temperature=0.3, repetition_penalty=1.2):
        start_time = time.time()
        sentence_data = self._preprocess_text([text])
        first_chunk = True
        for sentence, _, _ in sentence_data:
            responses = self.pipeline.stream_infer(sentence, top_p=top_p, temperature=temperature,
                                                   repetition_penalty=repetition_penalty)
            hidden_states_buffer = []
            chunk_counter = chunk_size
            for token in responses:
                finished = token['finish_reason'] is not None
                if not finished:
                    hidden_states_buffer.append(token['hidden_state'][-1])
                    # Keep only the context the decoder needs: two receptive
                    # fields plus one chunk of tokens.
                    hidden_states_buffer = hidden_states_buffer[-(2 * self.RECEPTIVE_FIELD + chunk_size):]
                if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
                    if finished or chunk_counter == chunk_size:
                        batch_hidden_states = torch.stack(hidden_states_buffer)
                        inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).cuda().to(torch.float32)
                        with torch.no_grad():
                            audio = self.decoder(inp)[0]
                        if finished:
                            # Flush whatever audio has not been emitted yet.
                            audio_chunk = audio[-((self.RECEPTIVE_FIELD + chunk_counter - 1) * self.TOKEN_SIZE - self.TOKEN_SIZE):]
                        else:
                            # Emit the chunk that sits one receptive field
                            # behind the newest tokens, so the decoder has
                            # enough right-context for clean audio.
                            audio_chunk = audio[-((self.RECEPTIVE_FIELD + chunk_size) * self.TOKEN_SIZE - self.TOKEN_SIZE):-(self.RECEPTIVE_FIELD * self.TOKEN_SIZE - self.TOKEN_SIZE)]
                        chunk_counter = 0
                        if first_chunk:
                            print(f"Streaming latency: {1000 * (time.time() - start_time):.2f} ms")
                            first_chunk = False
                        yield audio_chunk.cpu()
                    chunk_counter += 1
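

# Minimal usage sketch. This module uses package-relative imports, so it
# cannot run as a standalone script; the `soprano` package name below is an
# assumption, not confirmed by this file, as is the `play` audio sink.
#
#     from soprano import SopranoTTS
#
#     tts = SopranoTTS(backend='auto', device='cuda')
#
#     # Single utterance -> 32 kHz waveform tensor, optionally saved to disk.
#     audio = tts.infer("Hello there!", out_path='hello.wav')
#
#     # Batched inference: one waveform per input text, written to out/<i>.wav.
#     waveforms = tts.infer_batch(["First text.", "Second text."], out_dir='out')
#
#     # Streaming: yields audio chunks as tokens are generated.
#     for chunk in tts.infer_stream("Stream this sentence as audio."):
#         play(chunk)  # hypothetical playback function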