"""
Maya1 Streaming Pipeline - Sliding Window Approach
Implements sliding window technique for smooth streaming without artifacts.
"""
import asyncio
from typing import AsyncGenerator, Optional
from vllm import SamplingParams
from .constants import (
CODE_END_TOKEN_ID,
SNAC_MIN_ID,
SNAC_MAX_ID,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DEFAULT_MAX_TOKENS,
DEFAULT_MIN_TOKENS,
DEFAULT_REPETITION_PENALTY,
DEFAULT_SEED,
)


class Maya1SlidingWindowPipeline:
    """
    Streaming TTS pipeline using a sliding window approach.

    Decodes overlapping 28-token windows (4 frames) and keeps only
    the middle 2048 samples of each decode for smooth audio continuity.
    """

    # Sliding window configuration
    WINDOW_SIZE = 28       # 4 frames (7 tokens per frame)
    YIELD_STRIDE = 7       # Yield every 1 frame
    MIDDLE_SAMPLES = 2048  # Keep middle 2048 samples from each decode
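
    # Rationale: each decode covers 4 overlapping frames, but only the centre
    # MIDDLE_SAMPLES samples are kept before the window slides forward by one
    # frame. Emitted audio therefore always comes from the middle of a decode
    # rather than its edges, where chunk-boundary artifacts tend to appear.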

    def __init__(self, model, prompt_builder, snac_decoder):
        """
        Initialize the sliding window streaming pipeline.

        Args:
            model: Maya1Model instance
            prompt_builder: Maya1PromptBuilder instance
            snac_decoder: SNACDecoder instance
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.snac_decoder = snac_decoder
        print("Sliding window pipeline initialized")

    async def generate_speech_stream(
        self,
        description: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate speech audio with sliding window streaming.

        Args:
            description: Voice description
            text: Text to synthesize (may include <emotion> tags)
            temperature: Sampling temperature
            top_p: Nucleus sampling threshold
            max_tokens: Maximum number of SNAC tokens to generate
            repetition_penalty: Penalty that discourages repetition loops
            seed: Random seed for reproducible sampling

        Yields:
            Audio bytes (int16 PCM, 24kHz mono)
        """
        # Build prompt
        prompt = self.prompt_builder.build_prefix(description, text)

        # Configure sampling
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],
            seed=seed if seed is not None else DEFAULT_SEED,
        )
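        # Note: when no seed is supplied, DEFAULT_SEED is used, so repeated
        # requests with identical inputs should produce identical audio; pass
        # a per-request seed for varied but reproducible output.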

        # Stream tokens
        snac_buffer = []
        last_yield_position = 0
        chunk_count = 0
        total_tokens_seen = 0
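        # snac_buffer accumulates every SNAC token id generated so far,
        # last_yield_position marks the start of the next window to decode,
        # and total_tokens_seen tracks how much of vLLM's cumulative output
        # has already been consumed.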

        async for output in self.model.generate_stream(prompt, sampling_params):
            # Get latest generated tokens (cumulative list)
            generated_token_ids = output.outputs[0].token_ids

            # Process only NEW tokens since last iteration
            new_tokens = generated_token_ids[total_tokens_seen:]
            total_tokens_seen = len(generated_token_ids)

            # Collect SNAC codes from new tokens
            for token_id in new_tokens:
                # Stop if we hit EOS
                if token_id == CODE_END_TOKEN_ID:
                    break
                # Only collect valid SNAC tokens
                if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID:
                    snac_buffer.append(token_id)

            # Yield audio when we have enough tokens for a window
            while len(snac_buffer) >= last_yield_position + self.WINDOW_SIZE:
                # Get window of 28 tokens
                window_start = last_yield_position
                window_end = window_start + self.WINDOW_SIZE
                window = snac_buffer[window_start:window_end]

                if len(window) == self.WINDOW_SIZE:
                    # Decode window to audio
                    audio_bytes = await self.snac_decoder.decode_single_async(window)

                    if audio_bytes:
                        # Extract middle portion of audio
                        audio_samples = len(audio_bytes) // 2
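                        # int16 PCM: 2 bytes per sample, so sample counts and
                        # byte offsets are converted with a factor of 2.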
                        middle_start_sample = (audio_samples - self.MIDDLE_SAMPLES) // 2
                        middle_end_sample = middle_start_sample + self.MIDDLE_SAMPLES

                        # Convert to byte positions
                        middle_start_byte = middle_start_sample * 2
                        middle_end_byte = middle_end_sample * 2

                        # Extract middle chunk
                        audio_chunk = audio_bytes[middle_start_byte:middle_end_byte]

                        chunk_count += 1
                        if chunk_count == 1:
                            print(" First chunk ready")

                        yield audio_chunk

                # Move forward by stride
                last_yield_position += self.YIELD_STRIDE
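                # Consecutive windows overlap by WINDOW_SIZE - YIELD_STRIDE
                # = 21 tokens, so each decode shares most of its context with
                # the previous one; only the middle slice of each is emitted.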

            # Check if generation is done
            if CODE_END_TOKEN_ID in new_tokens:
                break
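
        # Tokens past the last full window, plus the trailing edge of the last
        # decode, are never emitted by the loop above; if at least a full
        # window of tokens is still pending, decode it once more and yield
        # its tail.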
        # Final chunk: decode remaining tokens
        remaining_tokens = len(snac_buffer) - last_yield_position
        if remaining_tokens >= self.WINDOW_SIZE:
            window = snac_buffer[-self.WINDOW_SIZE:]
            audio_bytes = await self.snac_decoder.decode_single_async(window)
            if audio_bytes:
                yield audio_bytes[-self.MIDDLE_SAMPLES * 2:]

        frames = len(snac_buffer) // 7
        duration = frames / 6.86
        print(f"Streamed {chunk_count} chunks (~{duration:.1f}s audio)")