# maya1-txt2speech / maya1/streaming_pipeline.py
# Author: Rajkumar Pramanik "RJproz"
# Commit: initial commit (3c92819)
# File size: 5.98 kB
"""
Maya1 Streaming Pipeline - Sliding Window Approach
Implements sliding window technique for smooth streaming without artifacts.
"""
import asyncio
from typing import AsyncGenerator, Optional
from vllm import SamplingParams
from .constants import (
CODE_END_TOKEN_ID,
SNAC_MIN_ID,
SNAC_MAX_ID,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DEFAULT_MAX_TOKENS,
DEFAULT_MIN_TOKENS,
DEFAULT_REPETITION_PENALTY,
DEFAULT_SEED,
)
class Maya1SlidingWindowPipeline:
    """
    Streaming TTS pipeline using a sliding window approach.

    Decodes overlapping 28-token windows (4 frames) and keeps only
    the middle 2048 samples of each decode for smooth audio continuity.
    Once generation ends, the not-yet-emitted tail of the last decoded
    window is flushed so the end of the utterance is not cut off.
    """

    # Sliding window configuration
    WINDOW_SIZE = 28  # 4 frames (7 tokens per frame)
    YIELD_STRIDE = 7  # Yield every 1 frame
    MIDDLE_SAMPLES = 2048  # Keep middle 2048 samples from each decode

    def __init__(self, model, prompt_builder, snac_decoder):
        """
        Initialize sliding window streaming pipeline.

        Args:
            model: Maya1Model instance (must provide ``generate_stream``)
            prompt_builder: Maya1PromptBuilder instance (must provide
                ``build_prefix``)
            snac_decoder: SNACDecoder instance (must provide
                ``decode_single_async``)
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.snac_decoder = snac_decoder
        print("Sliding window pipeline initialized")

    async def generate_speech_stream(
        self,
        description: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate speech audio with sliding window streaming.

        Args:
            description: Voice description
            text: Text to synthesize (may include <emotion> tags)
            temperature: Sampling temperature
            top_p: Nucleus sampling
            max_tokens: Max SNAC tokens to generate
            repetition_penalty: Prevent loops
            seed: Random seed; DEFAULT_SEED is used when None

        Yields:
            Audio bytes (int16 PCM, 24kHz mono). One chunk is produced per
            decoded window (the middle MIDDLE_SAMPLES samples), followed by
            one final chunk holding the remainder of the last decoded
            window, if any.
        """
        bytes_per_sample = 2  # int16 PCM

        # Build prompt
        prompt = self.prompt_builder.build_prefix(description, text)

        # Configure sampling
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],
            seed=seed if seed is not None else DEFAULT_SEED,
        )

        # Streaming state
        snac_buffer = []  # every valid SNAC token seen so far
        last_yield_position = 0  # start index of the next window to decode
        chunk_count = 0
        total_tokens_seen = 0
        # Audio that follows the middle slice of the most recent decode.
        # The sliding loop can never leave a full window undecoded, so this
        # tail is the only audio still owed to the caller when the stream ends.
        pending_tail = b""

        async for output in self.model.generate_stream(prompt, sampling_params):
            # token_ids is cumulative; process only what is new this step
            generated_token_ids = output.outputs[0].token_ids
            new_tokens = generated_token_ids[total_tokens_seen:]
            total_tokens_seen = len(generated_token_ids)

            # Collect SNAC codes, stopping at the end-of-speech token
            reached_end = False
            for token_id in new_tokens:
                if token_id == CODE_END_TOKEN_ID:
                    reached_end = True
                    break
                # Only collect valid SNAC tokens
                if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID:
                    snac_buffer.append(token_id)

            # Decode every full window that is now available
            while len(snac_buffer) >= last_yield_position + self.WINDOW_SIZE:
                window_end = last_yield_position + self.WINDOW_SIZE
                window = snac_buffer[last_yield_position:window_end]
                audio_bytes = await self.snac_decoder.decode_single_async(window)

                if audio_bytes:
                    # Locate the middle MIDDLE_SAMPLES samples. Clamp the
                    # start at 0 so a shorter-than-expected decode is sliced
                    # from its beginning instead of wrapping around via a
                    # negative index.
                    total_samples = len(audio_bytes) // bytes_per_sample
                    middle_start = max(
                        0, (total_samples - self.MIDDLE_SAMPLES) // 2
                    )
                    middle_end = middle_start + self.MIDDLE_SAMPLES
                    start_byte = middle_start * bytes_per_sample
                    end_byte = middle_end * bytes_per_sample

                    audio_chunk = audio_bytes[start_byte:end_byte]
                    pending_tail = audio_bytes[end_byte:]

                    chunk_count += 1
                    if chunk_count == 1:
                        print(" First chunk ready")
                    yield audio_chunk
                else:
                    # A failed decode invalidates the previous window's tail:
                    # it would no longer line up with the emitted audio.
                    pending_tail = b""

                # Move forward by stride
                last_yield_position += self.YIELD_STRIDE

            # Check if generation is done
            if reached_end:
                break

        # Final chunk: flush the remainder of the last decoded window.
        # NOTE(review): streams that never reach WINDOW_SIZE SNAC tokens
        # still produce no audio; whether the decoder accepts shorter
        # windows is not visible here - confirm before decoding them.
        if pending_tail:
            chunk_count += 1
            yield pending_tail

        frames = len(snac_buffer) // 7
        # NOTE(review): 6.86 frames/s does not obviously match 2048 samples
        # per frame at 24 kHz (~11.7 frames/s) - confirm; log-only value.
        duration = frames / 6.86
        print(f"Streamed {chunk_count} chunks (~{duration:.1f}s audio)")