import os
import re
import time

import torch
from huggingface_hub import hf_hub_download
from scipy.io import wavfile
from unidecode import unidecode

from .utils.text import clean_text
from .vocos.decoder import SopranoDecoder


class SopranoTTS:
    def __init__(self,
                 backend='auto',
                 device='cuda',
                 cache_size_mb=10,
                 decoder_batch_size=1):
        RECOGNIZED_DEVICES = ['cuda']
        RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']
        assert device in RECOGNIZED_DEVICES, f"unrecognized device {device}, device must be in {RECOGNIZED_DEVICES}"
        if backend == 'auto':
            # Note: only 'cuda' passes the assert above, so the cpu branch is
            # currently unreachable; it is kept for a future cpu device option.
            if device == 'cpu':
                backend = 'transformers'
            else:
                # Prefer lmdeploy when it is installed; fall back to transformers.
                try:
                    import lmdeploy
                    backend = 'lmdeploy'
                except ImportError:
                    backend = 'transformers'
        print(f"Using backend {backend}.")
        assert backend in RECOGNIZED_BACKENDS, f"unrecognized backend {backend}, backend must be in {RECOGNIZED_BACKENDS}"
        if backend == 'lmdeploy':
            from .backends.lmdeploy import LMDeployModel
            self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
        elif backend == 'transformers':
            from .backends.transformers import TransformersModel
            self.pipeline = TransformersModel(device=device)
        # Fetch the vocoder decoder weights from the Hugging Face Hub.
        self.decoder = SopranoDecoder().cuda()
        decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
        self.decoder.load_state_dict(torch.load(decoder_path))
        self.decoder_batch_size = decoder_batch_size
        self.RECEPTIVE_FIELD = 4   # Decoder receptive field, in tokens
        self.TOKEN_SIZE = 2048     # Number of audio samples per token
        self.infer("Hello world!")  # Warmup run

    def _preprocess_text(self, texts, min_length=30):
        '''
        Adds the prompt format and a (text, sentence) index to each sentence.
        Enforces a minimum sentence length by merging short sentences into
        their neighbors.
        '''
        res = []
        for text_idx, text in enumerate(texts):
            text = text.strip()
            cleaned_text = clean_text(text)
            # Split on sentence-final punctuation followed by whitespace.
            sentences = re.split(r"(?<=[.!?])\s+", cleaned_text)
            processed = [{"text": sentence, "text_idx": text_idx} for sentence in sentences]
            if min_length > 0 and len(processed) > 1:
                merged = []
                i = 0
                while i < len(processed):
                    cur = processed[i]
                    if len(cur["text"]) < min_length:
                        if merged:
                            # Fold a short sentence into the previous one.
                            merged[-1]["text"] = (merged[-1]["text"] + " " + cur["text"]).strip()
                        elif i + 1 < len(processed):
                            # No previous sentence yet: prepend to the next one.
                            processed[i + 1]["text"] = (cur["text"] + " " + processed[i + 1]["text"]).strip()
                        else:
                            # Nothing to merge with: keep it as-is.
                            merged.append(cur)
                    else:
                        merged.append(cur)
                    i += 1
                processed = merged
            # Tag each sentence with its position within its source text.
            sentence_idxes = {}
            for item in processed:
                if item['text_idx'] not in sentence_idxes:
                    sentence_idxes[item['text_idx']] = 0
                res.append((f'[STOP][TEXT]{item["text"]}[START]', item["text_idx"], sentence_idxes[item['text_idx']]))
                sentence_idxes[item['text_idx']] += 1
        return res
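
    # Illustrative example (not from the source): given
    #   _preprocess_text(["Hi. This sentence is long enough to stand alone."])
    # the short leading "Hi." has no previous sentence, so it is prepended to
    # the next one, yielding roughly
    #   [('[STOP][TEXT]Hi. This sentence is long enough to stand alone.[START]', 0, 0)]
    # i.e. one (prompt, text_idx, sentence_idx) tuple per surviving sentence.
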
    def infer(self,
              text,
              out_path=None,
              top_p=0.95,
              temperature=0.3,
              repetition_penalty=1.2):
        '''
        Synthesizes a single text. Optionally writes a 32 kHz WAV to out_path.
        Returns the audio as a 1-D tensor.
        '''
        results = self.infer_batch([text],
                                   top_p=top_p,
                                   temperature=temperature,
                                   repetition_penalty=repetition_penalty,
                                   out_dir=None)[0]
        if out_path:
            wavfile.write(out_path, 32000, results.cpu().numpy())
        return results

    def infer_batch(self,
                    texts,
                    out_dir=None,
                    top_p=0.95,
                    temperature=0.3,
                    repetition_penalty=1.2):
        sentence_data = self._preprocess_text(texts)
        prompts = [item[0] for item in sentence_data]
        responses = self.pipeline.infer(prompts,
                                        top_p=top_p,
                                        temperature=temperature,
                                        repetition_penalty=repetition_penalty)
        hidden_states = []
        for i, response in enumerate(responses):
            if response['finish_reason'] != 'stop':
                print(f"Warning: sentence {i} did not complete generation, likely due to hallucination.")
            hidden_states.append(response['hidden_state'])
        # Sort sentences by descending token length so each decoder batch
        # needs minimal padding (lengths[0] below is the batch maximum).
        combined = list(zip(hidden_states, sentence_data))
        combined.sort(key=lambda x: -x[0].size(0))
        hidden_states, sentence_data = zip(*combined)
        # Preallocate one audio slot per sentence of each input text.
        num_texts = len(texts)
        audio_concat = [[] for _ in range(num_texts)]
        for sentence in sentence_data:
            audio_concat[sentence[1]].append(None)
        for idx in range(0, len(hidden_states), self.decoder_batch_size):
            batch_hidden_states = []
            lengths = [h.size(0) for h in hidden_states[idx:idx + self.decoder_batch_size]]
            N = len(lengths)
            for i in range(N):
                # Left-pad each sequence with zeros up to the batch maximum length.
                batch_hidden_states.append(torch.cat([
                    torch.zeros((1, 512, lengths[0] - lengths[i]), device='cuda'),
                    hidden_states[idx + i].unsqueeze(0).transpose(1, 2).cuda().to(torch.float32),
                ], dim=2))
            batch_hidden_states = torch.cat(batch_hidden_states)
            with torch.no_grad():
                audio = self.decoder(batch_hidden_states)
            for i in range(N):
                text_id = sentence_data[idx + i][1]
                sentence_id = sentence_data[idx + i][2]
                # Keep only the last (lengths[i] - 1) * TOKEN_SIZE samples,
                # discarding audio produced from the zero padding.
                audio_concat[text_id][sentence_id] = audio[i].squeeze()[-(lengths[i] * self.TOKEN_SIZE - self.TOKEN_SIZE):]
        audio_concat = [torch.cat(x).cpu() for x in audio_concat]
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            for i in range(len(audio_concat)):
                wavfile.write(f"{out_dir}/{i}.wav", 32000, audio_concat[i].numpy())
        return audio_concat
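
    # Illustrative note (numbers are hypothetical): with decoder_batch_size=2 and
    # token lengths [7, 5], the second sequence gets a (1, 512, 2) zero prefix so
    # both stack to (2, 512, 7); the trim above then keeps the last
    # (5 - 1) * 2048 = 8192 samples of the shorter sentence.
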
    def infer_stream(self,
                     text,
                     chunk_size=1,
                     top_p=0.95,
                     temperature=0.3,
                     repetition_penalty=1.2):
        start_time = time.time()
        sentence_data = self._preprocess_text([text])
        first_chunk = True
        for sentence, _, _ in sentence_data:
            responses = self.pipeline.stream_infer(sentence,
                                                   top_p=top_p,
                                                   temperature=temperature,
                                                   repetition_penalty=repetition_penalty)
            hidden_states_buffer = []
            chunk_counter = chunk_size
            for token in responses:
                finished = token['finish_reason'] is not None
                if not finished:
                    hidden_states_buffer.append(token['hidden_state'][-1])
                # Keep just enough context for the decoder's receptive field.
                hidden_states_buffer = hidden_states_buffer[-(2 * self.RECEPTIVE_FIELD + chunk_size):]
                if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
                    if finished or chunk_counter == chunk_size:
                        batch_hidden_states = torch.stack(hidden_states_buffer)
                        inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).cuda().to(torch.float32)
                        with torch.no_grad():
                            audio = self.decoder(inp)[0]
                        if finished:
                            # Flush the held-back tail plus any tokens generated
                            # since the last emitted chunk.
                            audio_chunk = audio[-((self.RECEPTIVE_FIELD + chunk_counter - 1) * self.TOKEN_SIZE - self.TOKEN_SIZE):]
                        else:
                            # Emit chunk_size tokens' worth of samples, holding back
                            # the trailing receptive-field region that later context
                            # can still change.
                            audio_chunk = audio[-((self.RECEPTIVE_FIELD + chunk_size) * self.TOKEN_SIZE - self.TOKEN_SIZE):-(self.RECEPTIVE_FIELD * self.TOKEN_SIZE - self.TOKEN_SIZE)]
                        chunk_counter = 0
                        if first_chunk:
                            print(f"Streaming latency: {1000 * (time.time() - start_time):.2f} ms")
                            first_chunk = False
                        yield audio_chunk.cpu()
                    chunk_counter += 1
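
# Minimal usage sketch (illustrative, not shipped with the module). Because this
# file uses relative imports, it has to run inside its package, e.g. via
# `python -m <package>.tts` (module path hypothetical); it also assumes a CUDA
# device, matching the class's device constraint above.
if __name__ == '__main__':
    tts = SopranoTTS()  # defaults: backend='auto', device='cuda'
    # One-shot synthesis to a 32 kHz WAV file.
    tts.infer("Soprano is a lightweight text to speech model.", out_path='out.wav')
    # Streaming synthesis: audio chunks arrive as 1-D CPU tensors.
    chunks = list(tts.infer_stream("Streaming synthesis works sentence by sentence."))
    wavfile.write('stream.wav', 32000, torch.cat(chunks).numpy())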