File size: 5,979 Bytes
3c92819
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
Maya1 Streaming Pipeline - Sliding Window Approach
Implements sliding window technique for smooth streaming without artifacts.
"""

import asyncio
from typing import AsyncGenerator, Optional
from vllm import SamplingParams

from .constants import (
    CODE_END_TOKEN_ID,
    SNAC_MIN_ID,
    SNAC_MAX_ID,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    DEFAULT_MAX_TOKENS,
    DEFAULT_MIN_TOKENS,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_SEED,
)


class Maya1SlidingWindowPipeline:
    """
    Streaming TTS pipeline using sliding window approach.

    Generated SNAC tokens are decoded in overlapping windows of
    ``WINDOW_SIZE`` tokens that advance by ``YIELD_STRIDE`` tokens. Only the
    centred ``MIDDLE_SAMPLES`` samples of each decoded window are emitted, so
    consecutive chunks join sample-for-sample. When generation ends, the
    samples that follow the centred span of the last emitted window are
    flushed as one final chunk so the end of the utterance is not cut off.
    """

    # Sliding window configuration
    WINDOW_SIZE = 28  # 4 frames (7 tokens per frame)
    YIELD_STRIDE = 7  # Yield every 1 frame
    MIDDLE_SAMPLES = 2048  # Keep middle 2048 samples from each decode
    BYTES_PER_SAMPLE = 2  # Decoder output is treated as int16 PCM

    def __init__(self, model, prompt_builder, snac_decoder):
        """
        Initialize sliding window streaming pipeline.

        Args:
            model: Maya1Model instance
            prompt_builder: Maya1PromptBuilder instance
            snac_decoder: SNACDecoder instance
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.snac_decoder = snac_decoder
        print("Sliding window pipeline initialized")

    def _split_window_audio(self, audio_bytes):
        """
        Split one decoded window into its centred span and what follows it.

        Args:
            audio_bytes: PCM bytes returned by the decoder for one window.

        Returns:
            Tuple ``(middle, tail)``: ``middle`` is the centred
            ``MIDDLE_SAMPLES`` samples, ``tail`` is every whole sample after
            that span (empty when nothing follows it).
        """
        sample_width = self.BYTES_PER_SAMPLE
        total_samples = len(audio_bytes) // sample_width

        # Clamp at zero: if the decoder returns fewer samples than the
        # centred span, a negative start would slice from the end of the
        # buffer and emit an arbitrary fragment instead of the audio itself.
        start_sample = max(0, (total_samples - self.MIDDLE_SAMPLES) // 2)

        start_byte = start_sample * sample_width
        end_byte = start_byte + self.MIDDLE_SAMPLES * sample_width

        middle = audio_bytes[start_byte:end_byte]
        # Stop at the last whole sample so a stray odd byte is never emitted.
        tail = audio_bytes[end_byte:total_samples * sample_width]
        return middle, tail

    async def generate_speech_stream(
        self,
        description: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate speech audio with sliding window streaming.

        Args:
            description: Voice description
            text: Text to synthesize (may include <emotion> tags)
            temperature: Sampling temperature
            top_p: Nucleus sampling
            max_tokens: Max SNAC tokens to generate
            repetition_penalty: Prevent loops
            seed: Random seed; ``DEFAULT_SEED`` is used when ``None``

        Yields:
            Audio bytes (int16 PCM, 24kHz mono). One chunk per decoded
            window, followed by one final chunk holding the samples after
            the centred span of the last emitted window (if any).
        """
        # Build prompt
        prompt = self.prompt_builder.build_prefix(description, text)

        # Configure sampling
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],
            seed=seed if seed is not None else DEFAULT_SEED,
        )

        snac_buffer = []        # every valid SNAC token collected so far
        window_start = 0        # index in snac_buffer of the next window
        chunk_count = 0         # chunks yielded to the caller
        total_tokens_seen = 0   # length of the cumulative token list last seen
        pending_tail = b""      # un-emitted samples after the last middle span

        async for output in self.model.generate_stream(prompt, sampling_params):
            # token_ids is cumulative; process only what is new this round.
            generated_token_ids = output.outputs[0].token_ids
            new_tokens = generated_token_ids[total_tokens_seen:]
            total_tokens_seen = len(generated_token_ids)

            finished = False
            for token_id in new_tokens:
                if token_id == CODE_END_TOKEN_ID:
                    # End of speech: ignore anything after the end marker.
                    finished = True
                    break
                # Tokens outside the SNAC id range are skipped.
                if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID:
                    snac_buffer.append(token_id)

            # Decode every full window that the buffer now covers.
            while len(snac_buffer) >= window_start + self.WINDOW_SIZE:
                window = snac_buffer[window_start:window_start + self.WINDOW_SIZE]
                window_start += self.YIELD_STRIDE

                audio_bytes = await self.snac_decoder.decode_single_async(window)
                if not audio_bytes:
                    # Nothing decoded for this window; keep the tail of the
                    # previously emitted window as the pending flush.
                    continue

                # NOTE(review): the samples *before* the centred span of the
                # very first window are never emitted (unchanged from the
                # original behaviour) - confirm that this is intended.
                audio_chunk, pending_tail = self._split_window_audio(audio_bytes)

                chunk_count += 1
                if chunk_count == 1:
                    print(" First chunk ready")

                yield audio_chunk

            if finished:
                break

        # Final chunk. The loop above only exits once fewer than WINDOW_SIZE
        # tokens remain past window_start, so no further full window can be
        # decoded here. Instead, flush the samples that follow the centred
        # span of the last emitted window: they are contiguous with the
        # previous chunk and no later window exists to supply them.
        if pending_tail:
            chunk_count += 1
            yield pending_tail

        # 7 tokens per frame (see WINDOW_SIZE).
        # NOTE(review): 6.86 is kept from the original as the frames-per-second
        # figure for this estimate - confirm it against the codec frame rate.
        frames = len(snac_buffer) // 7
        duration = frames / 6.86
        print(f"Streamed {chunk_count} chunks (~{duration:.1f}s audio)")