# Source: maya1-txt2speech / maya1/pipeline.py
# Author: Rajkumar Pramanik ("RJproz")
# Commit: 3c92819 ("initial commit"), 3.93 kB
"""
Maya1 Generation Pipeline
End-to-end pipeline for TTS generation (non-streaming).
"""
import asyncio
from typing import Optional, List
from vllm import SamplingParams
from .constants import (
CODE_END_TOKEN_ID,
CODE_START_TOKEN_ID,
SNAC_MIN_ID,
SNAC_MAX_ID,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DEFAULT_MAX_TOKENS,
DEFAULT_MIN_TOKENS,
DEFAULT_REPETITION_PENALTY,
DEFAULT_SEED,
)
class Maya1Pipeline:
    """End-to-end (non-streaming) TTS pipeline for Maya1.

    Combines three caller-supplied collaborators: a prompt builder that
    turns (description, text) into a model prompt, a model that generates
    token IDs, and a SNAC decoder that turns audio codes into PCM bytes.
    """

    # Output audio format, per the generate_speech return contract:
    # int16 PCM (2 bytes per sample), 24 kHz, mono.
    _SAMPLE_RATE_HZ = 24000
    _BYTES_PER_SAMPLE = 2
    # Number of SNAC codes that make up one audio frame.
    _CODES_PER_FRAME = 7

    def __init__(self, model, prompt_builder, snac_decoder):
        """
        Initialize pipeline.

        Args:
            model: Maya1Model instance; must provide an awaitable
                ``generate(prompt, sampling_params)``.
            prompt_builder: Maya1PromptBuilder instance; must provide
                ``build_prefix(description, text)``.
            snac_decoder: SNACDecoder instance; must provide an awaitable
                ``decode_single_async(snac_codes)``.
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.snac_decoder = snac_decoder
        print("✅ Maya1Pipeline initialized")

    async def generate_speech(
        self,
        description: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> Optional[bytes]:
        """
        Generate speech audio (non-streaming).

        Args:
            description: Voice description
            text: Text to synthesize (may include <emotion> tags)
            temperature: Sampling temperature
            top_p: Nucleus sampling
            max_tokens: Max SNAC tokens to generate
            repetition_penalty: Prevent loops
            seed: Random seed for reproducibility. When None, DEFAULT_SEED
                is used, so identical calls are sampled with the same seed.

        Returns:
            Audio bytes (int16 PCM, 24kHz mono), or None if the model
            produced no completion, no SNAC codes could be extracted, or
            the decoder returned nothing.
        """
        prompt = self.prompt_builder.build_prefix(description, text)

        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            # Stop generation once the model emits the end-of-codes marker.
            stop_token_ids=[CODE_END_TOKEN_ID],
            seed=seed if seed is not None else DEFAULT_SEED,
        )

        outputs = await self.model.generate(prompt, sampling_params)
        # Guard both levels: no request outputs at all, or a request output
        # carrying no completions (indexing it would raise IndexError).
        if not outputs or not outputs[0].outputs:
            return None
        generated_token_ids = outputs[0].outputs[0].token_ids

        snac_codes = self._extract_snac_codes(generated_token_ids)
        if not snac_codes:
            return None

        audio_bytes = await self.snac_decoder.decode_single_async(snac_codes)
        if audio_bytes:
            frames = len(snac_codes) // self._CODES_PER_FRAME
            # Measure the duration from the decoded PCM itself instead of a
            # hard-coded frame rate, so the log matches the actual audio.
            duration_sec = len(audio_bytes) / (
                self._BYTES_PER_SAMPLE * self._SAMPLE_RATE_HZ
            )
            print(f" Generated {frames} frames (~{duration_sec:.1f}s audio)")
        return audio_bytes

    def _extract_snac_codes(self, token_ids: List[int]) -> List[int]:
        """
        Extract SNAC audio codes from a generated token sequence.

        Takes the span after the first CODE_START_TOKEN_ID (or from the
        beginning if absent) up to the first CODE_END_TOKEN_ID that follows
        it (or to the end if absent), then keeps only IDs within
        [SNAC_MIN_ID, SNAC_MAX_ID].

        Args:
            token_ids: Generated token IDs; any sequence supporting
                ``index(value, start)`` and slicing (e.g. list or tuple).

        Returns:
            SNAC code token IDs in generation order (possibly empty).
        """
        try:
            start = token_ids.index(CODE_START_TOKEN_ID) + 1
        except ValueError:
            start = 0

        # Look for the end marker only after the start marker; a stray end
        # token before the start would otherwise produce an empty slice.
        try:
            end = token_ids.index(CODE_END_TOKEN_ID, start)
        except ValueError:
            end = len(token_ids)

        return [
            token_id
            for token_id in token_ids[start:end]
            if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID
        ]