""" |
|
|
Maya1 Generation Pipeline |
|
|
End-to-end pipeline for TTS generation (non-streaming). |
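
Example (illustrative sketch, run from an async context; assumes a model,
prompt builder, and SNAC decoder constructed elsewhere in this package):

    pipeline = Maya1Pipeline(model, prompt_builder, snac_decoder)
    audio_bytes = await pipeline.generate_speech(
        description="Young adult female voice, warm and conversational",
        text="Hello there, nice to meet you!",
    )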
"""

import asyncio
from typing import Optional, List

from vllm import SamplingParams

from .constants import (
    CODE_END_TOKEN_ID,
    CODE_START_TOKEN_ID,
    SNAC_MIN_ID,
    SNAC_MAX_ID,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    DEFAULT_MAX_TOKENS,
    DEFAULT_MIN_TOKENS,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_SEED,
)


class Maya1Pipeline:
    """End-to-end TTS pipeline for Maya1."""

    def __init__(self, model, prompt_builder, snac_decoder):
        """
        Initialize the pipeline.

        Args:
            model: Maya1Model instance
            prompt_builder: Maya1PromptBuilder instance
            snac_decoder: SNACDecoder instance
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.snac_decoder = snac_decoder
        print("✅ Maya1Pipeline initialized")
    async def generate_speech(
        self,
        description: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> Optional[bytes]:
        """
        Generate speech audio (non-streaming).

        Args:
            description: Voice description.
            text: Text to synthesize (may include <emotion> tags).
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            max_tokens: Maximum number of SNAC tokens to generate.
            repetition_penalty: Penalty that discourages repetition loops.
            seed: Random seed for reproducibility.

        Returns:
            Audio bytes (int16 PCM, 24 kHz mono), or None if generation failed.
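
        Example (illustrative; assumes an initialized pipeline and an async
        context):

            audio_bytes = await pipeline.generate_speech(
                description="Calm male narrator voice",
                text="Welcome to the show.",
                temperature=0.7,
                seed=42,
            )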
        """
        # Build the prompt prefix from the voice description and target text.
        prompt = self.prompt_builder.build_prefix(description, text)

        # Configure sampling; generation stops at the SNAC end-of-codes token.
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],
            seed=seed if seed is not None else DEFAULT_SEED,
        )

        # Run the model to produce SNAC audio tokens.
        outputs = await self.model.generate(prompt, sampling_params)

        if not outputs or len(outputs) == 0:
            return None

        output = outputs[0]
        generated_token_ids = output.outputs[0].token_ids

        # Keep only the SNAC codes between the start/end markers.
        snac_codes = self._extract_snac_codes(generated_token_ids)

        if not snac_codes:
            return None

        # Decode SNAC codes to PCM audio.
        audio_bytes = await self.snac_decoder.decode_single_async(snac_codes)

        if audio_bytes:
            frames = len(snac_codes) // 7  # 7 SNAC codes per frame
            duration_sec = frames / 6.86   # ~6.86 frames per second of audio
            print(f" Generated {frames} frames (~{duration_sec:.1f}s audio)")

        return audio_bytes
    def _extract_snac_codes(self, token_ids: List[int]) -> List[int]:
        """
        Extract SNAC codes from generated token IDs.

        Slices the span between the CODE_START and CODE_END marker tokens
        (falling back to the start of the sequence when no start marker is
        found) and keeps only IDs within the SNAC code range.
        """
        try:
            sos_idx = token_ids.index(CODE_START_TOKEN_ID)
        except ValueError:
            sos_idx = -1

        try:
            eos_idx = token_ids.index(CODE_END_TOKEN_ID)
        except ValueError:
            eos_idx = len(token_ids)

        if sos_idx >= 0:
            snac_tokens = token_ids[sos_idx + 1:eos_idx]
        else:
            # No start marker found; use everything up to the end marker.
            snac_tokens = token_ids[:eos_idx]

        # Keep only token IDs that fall within the SNAC code range.
        snac_codes = [
            token_id for token_id in snac_tokens
            if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID
        ]

        return snac_codes