Veena commited on
Commit
002a88c
·
1 Parent(s): 30a893c

Update Maya1 Gradio app with preset characters

Browse files
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ .Python
6
+ *.so
7
+ *.egg
8
+ *.egg-info/
9
+ dist/
10
+ build/
11
+ .cache/
12
+ .pytest_cache/
13
+ *.wav
14
+ *.mp3
15
+ .DS_Store
16
+
maya1/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ Maya1 TTS Inference System
3
+ Open-source inference for description-conditioned TTS with emotion control.
4
+ """
5
+
6
+ __version__ = "1.0.0"
7
+ __author__ = "Maya Research AI"
maya1/api_v2.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import wave
4
+ import time
5
+ from typing import Optional
6
+ from fastapi import FastAPI, HTTPException
7
+ from fastapi.responses import StreamingResponse
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel, Field
10
+ from dotenv import load_dotenv
11
+
12
+ from .model_loader import Maya1Model
13
+ from .prompt_builder import Maya1PromptBuilder
14
+ from .snac_decoder import SNACDecoder
15
+ from .pipeline import Maya1Pipeline
16
+ from .streaming_pipeline import Maya1SlidingWindowPipeline
17
+ from .constants import (
18
+ DEFAULT_TEMPERATURE,
19
+ DEFAULT_TOP_P,
20
+ DEFAULT_MAX_TOKENS,
21
+ DEFAULT_REPETITION_PENALTY,
22
+ AUDIO_SAMPLE_RATE,
23
+ )
24
+
25
+ # Timeout settings (seconds)
26
+ GENERATE_TIMEOUT = 60
27
+
28
+ # Load environment variables
29
+ load_dotenv()
30
+
31
+ # Initialize FastAPI app
32
+ app = FastAPI(
33
+ title="Maya1 TTS API",
34
+ description="Open source TTS inference for Maya1",
35
+ version="1.0.0",
36
+ docs_url=None,
37
+ redoc_url=None,
38
+ )
39
+
40
+ app.add_middleware(
41
+ CORSMiddleware,
42
+ allow_origins=["*"],
43
+ allow_credentials=True,
44
+ allow_methods=["*"],
45
+ allow_headers=["*"],
46
+ )
47
+
48
+ # Global state
49
+ model = None
50
+ prompt_builder = None
51
+ snac_decoder = None
52
+ pipeline = None
53
+ streaming_pipeline = None
54
+
55
+
56
+ # ============================================================================
57
+ # Startup/Shutdown
58
+ # ============================================================================
59
+
60
+ @app.on_event("startup")
61
+ async def startup_event():
62
+ """Initialize model on startup."""
63
+ global model, prompt_builder, snac_decoder, pipeline, streaming_pipeline
64
+
65
+ print("\n" + "="*60)
66
+ print(" Starting Maya1 TTS API Server")
67
+ print("="*60 + "\n")
68
+
69
+ # Initialize components
70
+ model = Maya1Model()
71
+ prompt_builder = Maya1PromptBuilder(model.tokenizer, model)
72
+
73
+ # Initialize SNAC decoder
74
+ snac_decoder = SNACDecoder(enable_batching=True, max_batch_size=64, batch_timeout_ms=15)
75
+ await snac_decoder.start_batch_processor()
76
+
77
+ # Initialize pipelines
78
+ pipeline = Maya1Pipeline(model, prompt_builder, snac_decoder)
79
+ streaming_pipeline = Maya1SlidingWindowPipeline(model, prompt_builder, snac_decoder)
80
+
81
+ print("\n" + "="*60)
82
+ print("Maya1 TTS API Server Ready")
83
+ print("="*60 + "\n")
84
+
85
+
86
+ @app.on_event("shutdown")
87
+ async def shutdown_event():
88
+ """Cleanup on shutdown."""
89
+ print("\nShutting down Maya1 TTS API Server")
90
+
91
+ if snac_decoder and snac_decoder.is_running:
92
+ await snac_decoder.stop_batch_processor()
93
+
94
+
95
+ # ============================================================================
96
+ # Utility Functions
97
+ # ============================================================================
98
+
99
+ def create_wav_header(sample_rate: int = 24000, channels: int = 1, bits_per_sample: int = 16, data_size: int = 0) -> bytes:
100
+ """Create WAV file header."""
101
+ import struct
102
+
103
+ byte_rate = sample_rate * channels * bits_per_sample // 8
104
+ block_align = channels * bits_per_sample // 8
105
+
106
+ header = struct.pack(
107
+ '<4sI4s4sIHHIIHH4sI',
108
+ b'RIFF',
109
+ 36 + data_size,
110
+ b'WAVE',
111
+ b'fmt ',
112
+ 16,
113
+ 1,
114
+ channels,
115
+ sample_rate,
116
+ byte_rate,
117
+ block_align,
118
+ bits_per_sample,
119
+ b'data',
120
+ data_size
121
+ )
122
+
123
+ return header
124
+
125
+
126
+ # ============================================================================
127
+ # Request/Response Models
128
+ # ============================================================================
129
+
130
+ class TTSRequest(BaseModel):
131
+ """TTS generation request."""
132
+ description: str = Field(
133
+ ...,
134
+ description="Voice description (e.g., 'Male voice in their 30s with american accent')"
135
+ )
136
+ text: str = Field(
137
+ ...,
138
+ description="Text to synthesize (can include <emotion> tags)"
139
+ )
140
+ temperature: Optional[float] = Field(
141
+ default=DEFAULT_TEMPERATURE,
142
+ description="Sampling temperature"
143
+ )
144
+ top_p: Optional[float] = Field(
145
+ default=DEFAULT_TOP_P,
146
+ description="Nucleus sampling"
147
+ )
148
+ max_tokens: Optional[int] = Field(
149
+ default=DEFAULT_MAX_TOKENS,
150
+ description="Maximum tokens to generate"
151
+ )
152
+ repetition_penalty: Optional[float] = Field(
153
+ default=DEFAULT_REPETITION_PENALTY,
154
+ description="Repetition penalty"
155
+ )
156
+ seed: Optional[int] = Field(
157
+ default=None,
158
+ description="Random seed for reproducibility",
159
+ ge=0,
160
+ )
161
+ stream: bool = Field(
162
+ default=False,
163
+ description="Stream audio (True) or return complete WAV (False)"
164
+ )
165
+
166
+
167
+ # ============================================================================
168
+ # Endpoints
169
+ # ============================================================================
170
+
171
+ @app.get("/")
172
+ async def root():
173
+ """Root endpoint."""
174
+ return {
175
+ "service": "Maya1 TTS API",
176
+ "version": "1.0.0",
177
+ "status": "running",
178
+ "model": "Maya1-Voice (open source)",
179
+ "endpoints": {
180
+ "generate": "/v1/tts/generate (POST)",
181
+ "health": "/health (GET)",
182
+ },
183
+ }
184
+
185
+
186
+ @app.get("/health")
187
+ async def health_check():
188
+ """Health check endpoint."""
189
+ return {
190
+ "status": "healthy",
191
+ "model": "Maya1-Voice",
192
+ "timestamp": time.time(),
193
+ }
194
+
195
+
196
+ # ============================================================================
197
+ # TTS Generation Endpoint
198
+ # ============================================================================
199
+
200
+ @app.post("/v1/tts/generate")
201
+ async def generate_tts(request: TTSRequest):
202
+ """Generate TTS audio from description and text."""
203
+
204
+ try:
205
+ # Route to streaming or non-streaming
206
+ if request.stream:
207
+ return await _generate_tts_streaming(
208
+ description=request.description,
209
+ text=request.text,
210
+ temperature=request.temperature,
211
+ top_p=request.top_p,
212
+ max_tokens=request.max_tokens,
213
+ repetition_penalty=request.repetition_penalty,
214
+ seed=request.seed,
215
+ )
216
+ else:
217
+ return await _generate_tts_complete(
218
+ description=request.description,
219
+ text=request.text,
220
+ temperature=request.temperature,
221
+ top_p=request.top_p,
222
+ max_tokens=request.max_tokens,
223
+ repetition_penalty=request.repetition_penalty,
224
+ seed=request.seed,
225
+ )
226
+
227
+ except HTTPException:
228
+ raise
229
+ except Exception as e:
230
+ print(f" Error: {e}")
231
+ raise HTTPException(status_code=500, detail=str(e))
232
+
233
+
234
+ async def _generate_tts_complete(
235
+ description: str,
236
+ text: str,
237
+ temperature: float,
238
+ top_p: float,
239
+ max_tokens: int,
240
+ repetition_penalty: float,
241
+ seed: Optional[int],
242
+ ):
243
+ """Generate complete WAV file (non-streaming)."""
244
+
245
+ try:
246
+ import asyncio
247
+
248
+ # Generate audio
249
+ audio_bytes = await asyncio.wait_for(
250
+ pipeline.generate_speech(
251
+ description=description,
252
+ text=text,
253
+ temperature=temperature,
254
+ top_p=top_p,
255
+ max_tokens=max_tokens,
256
+ repetition_penalty=repetition_penalty,
257
+ seed=seed,
258
+ ),
259
+ timeout=GENERATE_TIMEOUT
260
+ )
261
+
262
+ if audio_bytes is None:
263
+ raise Exception("Audio generation failed")
264
+
265
+ # Create WAV file
266
+ wav_buffer = io.BytesIO()
267
+ with wave.open(wav_buffer, 'wb') as wav_file:
268
+ wav_file.setnchannels(1)
269
+ wav_file.setsampwidth(2)
270
+ wav_file.setframerate(AUDIO_SAMPLE_RATE)
271
+ wav_file.writeframes(audio_bytes)
272
+
273
+ wav_buffer.seek(0)
274
+
275
+ return StreamingResponse(
276
+ wav_buffer,
277
+ media_type="audio/wav",
278
+ headers={"Content-Disposition": "attachment; filename=output.wav"}
279
+ )
280
+
281
+ except asyncio.TimeoutError:
282
+ raise HTTPException(status_code=504, detail="Generation timeout")
283
+
284
+
285
+ async def _generate_tts_streaming(
286
+ description: str,
287
+ text: str,
288
+ temperature: float,
289
+ top_p: float,
290
+ max_tokens: int,
291
+ repetition_penalty: float,
292
+ seed: Optional[int],
293
+ ):
294
+ """Generate streaming audio."""
295
+ start_time = time.time()
296
+ first_audio_time = None
297
+
298
+ async def audio_stream_generator():
299
+ """Generate audio stream with WAV header."""
300
+ nonlocal first_audio_time
301
+
302
+ # Send WAV header first
303
+ yield create_wav_header(sample_rate=AUDIO_SAMPLE_RATE, channels=1, bits_per_sample=16)
304
+
305
+ # Stream audio chunks
306
+ async for audio_chunk in streaming_pipeline.generate_speech_stream(
307
+ description=description,
308
+ text=text,
309
+ temperature=temperature,
310
+ top_p=top_p,
311
+ max_tokens=max_tokens,
312
+ repetition_penalty=repetition_penalty,
313
+ seed=seed,
314
+ ):
315
+ if first_audio_time is None:
316
+ first_audio_time = time.time()
317
+ ttfb_ms = (first_audio_time - start_time) * 1000
318
+ print(f"⏱️ TTFB: {ttfb_ms:.1f}ms")
319
+
320
+ yield audio_chunk
321
+
322
+ try:
323
+ return StreamingResponse(
324
+ audio_stream_generator(),
325
+ media_type="audio/wav",
326
+ headers={"Cache-Control": "no-cache"}
327
+ )
328
+
329
+ except Exception as e:
330
+ print(f"Streaming error: {e}")
331
+ raise HTTPException(status_code=500, detail=str(e))
332
+
333
+
334
+ # For running directly
335
+ if __name__ == "__main__":
336
+ import uvicorn
337
+ uvicorn.run(
338
+ app,
339
+ host="0.0.0.0",
340
+ port=8000,
341
+ log_level="info"
342
+ )
maya1/constants.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Maya1 Constants
3
+ Token IDs and special tokens used in the model.
4
+ Matches training configuration exactly.
5
+ """
6
+
7
+ # Special control tokens
8
+ SOH_ID = 128259 # Start of Human turn
9
+ EOH_ID = 128260 # End of Human turn
10
+ SOA_ID = 128261 # Start of AI turn
11
+ EOA_ID = 128262 # End of AI turn (not used in maya1)
12
+ PAD_ID = 128263 # Padding token
13
+
14
+ # Text tokens
15
+ BOS_ID = 128000 # Begin of sequence (Llama BOS)
16
+ TEXT_EOT_ID = 128009 # End of text (appears in prefix, not a stop token!)
17
+
18
+ # Audio tokens
19
+ CODE_START_TOKEN_ID = 128257 # SOS - Start of Speech
20
+ CODE_END_TOKEN_ID = 128258 # EOS - End of Speech (audio stop token)
21
+ CODE_TOKEN_OFFSET = 128266 # Start of SNAC codes
22
+
23
+ # SNAC token range
24
+ SNAC_MIN_ID = 128266
25
+ SNAC_MAX_ID = 156937 # 128266 + (7 * 4096) - 1
26
+
27
+ # Stop tokens for generation
28
+ # CRITICAL: Only use CODE_END_TOKEN_ID (128258) for audio generation
29
+ # TEXT_EOT_ID (128009) appears in prefix and should NOT stop generation
30
+ TRAINING_STOP_TOKEN_IDS = [CODE_END_TOKEN_ID] # [128258]
31
+ ALL_POSSIBLE_STOP_TOKENS = [TEXT_EOT_ID, CODE_END_TOKEN_ID] # For reference only
32
+
33
+ # 20 Extended Emotion Tags (must be single tokens)
34
+ ALL_EMOTION_TAGS = [
35
+ '<angry>',
36
+ '<appalled>',
37
+ '<chuckle>',
38
+ '<cry>',
39
+ '<curious>',
40
+ '<disappointed>',
41
+ '<excited>',
42
+ '<exhale>',
43
+ '<gasp>',
44
+ '<giggle>',
45
+ '<gulp>',
46
+ '<laugh>',
47
+ '<laugh_harder>',
48
+ '<mischievous>',
49
+ '<sarcastic>',
50
+ '<scream>',
51
+ '<sigh>',
52
+ '<sing>',
53
+ '<snort>',
54
+ '<whisper>',
55
+ ]
56
+
57
+ # Model configuration
58
+ DEFAULT_MODEL_PATH = "maya-research/maya1"
59
+ DEFAULT_CHECKPOINT = "checkpoint-25000"
60
+ DEFAULT_MAX_MODEL_LEN = 8192
61
+
62
+ # SNAC configuration
63
+ SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz"
64
+ SNAC_SAMPLE_RATE = 24000
65
+ SNAC_TOKENS_PER_FRAME = 7
66
+ SNAC_LEVELS = 3
67
+
68
+ # Audio configuration
69
+ AUDIO_SAMPLE_RATE = 24000
70
+ AUDIO_CHANNELS = 1
71
+ AUDIO_BITS_PER_SAMPLE = 16
72
+
73
+ # Generation defaults
74
+ DEFAULT_TEMPERATURE = 0.4 # Lower temp for more stable generation
75
+ DEFAULT_TOP_P = 0.9
76
+ DEFAULT_MAX_TOKENS = 2048 # Reasonable default for most use cases
77
+ DEFAULT_MIN_TOKENS = 28 # At least 4 SNAC frames
78
+ DEFAULT_REPETITION_PENALTY = 1.1
79
+ DEFAULT_SEED = None # None = random, set integer for reproducibility
80
+
81
+ # IMPORTANT: Emotion tags consume audio time!
82
+ # <laugh> = ~4-6 seconds (~300-400 tokens)
83
+ # <excited>, <chuckle> = ~1-2 seconds (~50-150 tokens)
84
+
85
+ # Recommended max_tokens by use case:
86
+ # - Short phrases (< 10 words): 150-250 tokens (~3-5s)
87
+ # - Medium text (10-30 words): 250-500 tokens (~5-10s)
88
+ # - Long text (30+ words): 500-1500 tokens (~10-30s)
89
+ # - Very long text: 1500-2000 tokens (~30-42s)
90
+ # Note: 1 second ≈ 48 tokens (7 tokens/frame * 6.86 frames/sec)
91
+
92
+ # Streaming configuration
93
+ STREAM_BUFFER_SIZE = 28 # 4 frames (process every 28 tokens)
94
+ SNAC_BATCH_SIZE = 64
95
+ SNAC_BATCH_TIMEOUT_MS = 15
maya1/model_loader.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Maya1 Model Loader
3
+ Loads Maya1 model with vLLM engine and validates emotion tags.
4
+ """
5
+
6
+ import os
7
+ from transformers import AutoTokenizer
8
+ from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
9
+ from .constants import (
10
+ ALL_EMOTION_TAGS,
11
+ DEFAULT_MAX_MODEL_LEN,
12
+ SOH_ID, EOH_ID, SOA_ID, BOS_ID, TEXT_EOT_ID, CODE_START_TOKEN_ID,
13
+ )
14
+
15
+
16
+ class Maya1Model:
17
+ """Maya1 TTS Model with vLLM inference engine."""
18
+
19
+ def __init__(
20
+ self,
21
+ model_path: str = None,
22
+ dtype: str = "bfloat16",
23
+ max_model_len: int = DEFAULT_MAX_MODEL_LEN,
24
+ gpu_memory_utilization: float = 0.85,
25
+ tensor_parallel_size: int = 1,
26
+ **engine_kwargs
27
+ ):
28
+ """
29
+ Initialize Maya1 model with vLLM.
30
+
31
+ Args:
32
+ model_path: Path to checkpoint (local or HF repo)
33
+ dtype: Model precision (bfloat16 recommended)
34
+ max_model_len: Maximum sequence length
35
+ gpu_memory_utilization: GPU memory fraction
36
+ tensor_parallel_size: Number of GPUs
37
+ """
38
+ # Use provided path or environment variable or default
39
+ if model_path is None:
40
+ model_path = os.environ.get(
41
+ 'MAYA1_MODEL_PATH',
42
+ os.path.expanduser('~/models/maya1-voice')
43
+ )
44
+
45
+ self.model_path = model_path
46
+ self.dtype = dtype
47
+
48
+ print(f"Initializing Maya1 Model")
49
+ print(f"Model: {model_path}")
50
+
51
+ # Load tokenizer
52
+ self.tokenizer = AutoTokenizer.from_pretrained(
53
+ model_path,
54
+ trust_remote_code=True,
55
+ )
56
+
57
+ print(f"Tokenizer loaded: {len(self.tokenizer)} tokens")
58
+
59
+ # Validate emotion tags
60
+ self._validate_emotion_tags()
61
+
62
+ # Precompute special token strings
63
+ self._init_special_tokens()
64
+
65
+ # Initialize vLLM engine
66
+ print(f"Initializing vLLM engine...")
67
+ engine_args = AsyncEngineArgs(
68
+ model=model_path,
69
+ tokenizer=model_path,
70
+ dtype=dtype,
71
+ max_model_len=max_model_len,
72
+ gpu_memory_utilization=gpu_memory_utilization,
73
+ tensor_parallel_size=tensor_parallel_size,
74
+ trust_remote_code=True,
75
+ disable_log_stats=False,
76
+ **engine_kwargs
77
+ )
78
+
79
+ self.engine = AsyncLLMEngine.from_engine_args(engine_args)
80
+
81
+ print(f"Maya1 Model ready\n")
82
+
83
+ def _validate_emotion_tags(self):
84
+ """Validate that all 20 emotion tags are single tokens."""
85
+ failed_tags = []
86
+ for tag in ALL_EMOTION_TAGS:
87
+ token_ids = self.tokenizer.encode(tag, add_special_tokens=False)
88
+ if len(token_ids) != 1:
89
+ failed_tags.append((tag, len(token_ids)))
90
+
91
+ if failed_tags:
92
+ print(f"ERROR: {len(failed_tags)} emotion tags are NOT single tokens!")
93
+ raise AssertionError(f"Emotion tags validation failed")
94
+
95
+ print(f"All {len(ALL_EMOTION_TAGS)} emotion tags validated")
96
+
97
+ def _init_special_tokens(self):
98
+ """Precompute special token strings for fast prefix building."""
99
+ self.soh_token = self.tokenizer.decode([SOH_ID])
100
+ self.bos_token = self.tokenizer.bos_token
101
+ self.eot_token = self.tokenizer.decode([TEXT_EOT_ID])
102
+ self.eoh_token = self.tokenizer.decode([EOH_ID])
103
+ self.soa_token = self.tokenizer.decode([SOA_ID])
104
+ self.sos_token = self.tokenizer.decode([CODE_START_TOKEN_ID])
105
+
106
+ async def generate(self, prompt: str, sampling_params: SamplingParams):
107
+ """
108
+ Generate tokens from prompt (non-streaming).
109
+ Args:
110
+ prompt: Input prompt
111
+ sampling_params: vLLM sampling parameters
112
+ Returns:
113
+ Generated output from vLLM
114
+ """
115
+ request_id = f"req_{id(prompt)}"
116
+
117
+ # Collect results from async generator
118
+ final_output = None
119
+ async for output in self.engine.generate(
120
+ prompt=prompt,
121
+ sampling_params=sampling_params,
122
+ request_id=request_id
123
+ ):
124
+ final_output = output
125
+
126
+ return [final_output] if final_output else []
127
+
128
+ async def generate_stream(self, prompt: str, sampling_params: SamplingParams):
129
+ """
130
+ Generate tokens from prompt (streaming).
131
+ Args:
132
+ prompt: Input prompt
133
+ sampling_params: vLLM sampling parameters
134
+ Yields:
135
+ Generated outputs from vLLM
136
+ """
137
+ request_id = f"req_{id(prompt)}"
138
+
139
+ # Stream from engine
140
+ async for output in self.engine.generate(
141
+ prompt=prompt,
142
+ sampling_params=sampling_params,
143
+ request_id=request_id
144
+ ):
145
+ yield output
maya1/pipeline.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Maya1 Generation Pipeline
3
+ End-to-end pipeline for TTS generation (non-streaming).
4
+ """
5
+
6
+ import asyncio
7
+ from typing import Optional, List
8
+ from vllm import SamplingParams
9
+
10
+ from .constants import (
11
+ CODE_END_TOKEN_ID,
12
+ CODE_START_TOKEN_ID,
13
+ SNAC_MIN_ID,
14
+ SNAC_MAX_ID,
15
+ DEFAULT_TEMPERATURE,
16
+ DEFAULT_TOP_P,
17
+ DEFAULT_MAX_TOKENS,
18
+ DEFAULT_MIN_TOKENS,
19
+ DEFAULT_REPETITION_PENALTY,
20
+ DEFAULT_SEED,
21
+ )
22
+
23
+
24
+ class Maya1Pipeline:
25
+ """End-to-end TTS pipeline for Maya1."""
26
+
27
+ def __init__(self, model, prompt_builder, snac_decoder):
28
+ """
29
+ Initialize pipeline.
30
+ Args:
31
+ model: Maya1Model instance
32
+ prompt_builder: Maya1PromptBuilder instance
33
+ snac_decoder: SNACDecoder instance
34
+ """
35
+ self.model = model
36
+ self.prompt_builder = prompt_builder
37
+ self.snac_decoder = snac_decoder
38
+ print(f"✅ Maya1Pipeline initialized")
39
+
40
+ async def generate_speech(
41
+ self,
42
+ description: str,
43
+ text: str,
44
+ temperature: float = DEFAULT_TEMPERATURE,
45
+ top_p: float = DEFAULT_TOP_P,
46
+ max_tokens: int = DEFAULT_MAX_TOKENS,
47
+ repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
48
+ seed: Optional[int] = None,
49
+ ) -> Optional[bytes]:
50
+ """
51
+ Generate speech audio (non-streaming).
52
+ Args:
53
+ description: Voice description
54
+ text: Text to synthesize (may include <emotion> tags)
55
+ temperature: Sampling temperature
56
+ top_p: Nucleus sampling
57
+ max_tokens: Max SNAC tokens to generate
58
+ repetition_penalty: Prevent loops
59
+ seed: Random seed for reproducibility
60
+
61
+ Returns:
62
+ Audio bytes (int16 PCM, 24kHz mono) or None if failed
63
+ """
64
+ # Build prompt
65
+ prompt = self.prompt_builder.build_prefix(description, text)
66
+
67
+ # Configure sampling
68
+ sampling_params = SamplingParams(
69
+ temperature=temperature,
70
+ top_p=top_p,
71
+ max_tokens=max_tokens,
72
+ min_tokens=DEFAULT_MIN_TOKENS,
73
+ repetition_penalty=repetition_penalty,
74
+ stop_token_ids=[CODE_END_TOKEN_ID],
75
+ seed=seed if seed is not None else DEFAULT_SEED,
76
+ )
77
+
78
+ # Generate tokens
79
+ outputs = await self.model.generate(prompt, sampling_params)
80
+
81
+ if not outputs or len(outputs) == 0:
82
+ return None
83
+
84
+ output = outputs[0]
85
+ generated_token_ids = output.outputs[0].token_ids
86
+
87
+ # Extract SNAC codes
88
+ snac_codes = self._extract_snac_codes(generated_token_ids)
89
+
90
+ if not snac_codes:
91
+ return None
92
+
93
+ # Decode to audio
94
+ audio_bytes = await self.snac_decoder.decode_single_async(snac_codes)
95
+
96
+ if audio_bytes:
97
+ frames = len(snac_codes) // 7
98
+ duration_sec = frames / 6.86
99
+ print(f" Generated {frames} frames (~{duration_sec:.1f}s audio)")
100
+
101
+ return audio_bytes
102
+
103
+ def _extract_snac_codes(self, token_ids: List[int]) -> List[int]:
104
+ # Find SOS and EOS positions
105
+ try:
106
+ sos_idx = token_ids.index(CODE_START_TOKEN_ID)
107
+ except ValueError:
108
+ sos_idx = -1
109
+
110
+ try:
111
+ eos_idx = token_ids.index(CODE_END_TOKEN_ID)
112
+ except ValueError:
113
+ eos_idx = len(token_ids)
114
+
115
+ # Extract tokens between SOS and EOS
116
+ if sos_idx >= 0:
117
+ snac_tokens = token_ids[sos_idx + 1:eos_idx]
118
+ else:
119
+ # If no SOS found, take everything before EOS
120
+ snac_tokens = token_ids[:eos_idx]
121
+
122
+ # Filter to only valid SNAC token IDs
123
+ snac_codes = [
124
+ token_id for token_id in snac_tokens
125
+ if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID
126
+ ]
127
+
128
+ return snac_codes
maya1/prompt_builder.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Maya1 Prompt Builder
3
+ Builds formatted prompts for description-conditioned TTS.
4
+ Format: <SOH><BOS><description="..."> text<EOT><EOH><SOA><SOS>
5
+ """
6
+
7
+ from .constants import ALL_EMOTION_TAGS
8
+
9
+
10
+ class Maya1PromptBuilder:
11
+ """Builds prompts in the format expected by Maya1 model."""
12
+
13
+ def __init__(self, tokenizer, model):
14
+ self.tokenizer = tokenizer
15
+ self.model = model
16
+
17
+ def build_prefix(self, description: str, text: str) -> str:
18
+ # Format as: <description="..."> text
19
+ formatted_text = f'<description="{description}"> {text}'
20
+ # Build full prefix with special tokens
21
+ prompt = (
22
+ self.model.soh_token +
23
+ self.model.bos_token +
24
+ formatted_text +
25
+ self.model.eot_token +
26
+ self.model.eoh_token +
27
+ self.model.soa_token +
28
+ self.model.sos_token
29
+ )
30
+
31
+ return prompt
maya1/snac_decoder.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import asyncio
4
+ from typing import List, Optional, Tuple
5
+ from snac import SNAC
6
+
7
+ from .constants import (
8
+ CODE_END_TOKEN_ID,
9
+ CODE_TOKEN_OFFSET,
10
+ SNAC_MODEL_NAME,
11
+ SNAC_SAMPLE_RATE,
12
+ SNAC_TOKENS_PER_FRAME,
13
+ )
14
+
15
+
16
+ class SNACDecoder:
17
+ """
18
+ SNAC Decoder for maya1.
19
+ Unpacks 7-token SNAC frames and decodes to audio waveforms.
20
+ Unpacking logic is the EXACT INVERSE of training preprocessing.
21
+ Supports async batching for concurrent requests.
22
+ CRITICAL: Any mismatch in unpacking will produce garbage audio.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ device: str = "cuda",
28
+ compile_decoder: bool = False,
29
+ enable_batching: bool = False,
30
+ max_batch_size: int = 64,
31
+ batch_timeout_ms: int = 15,
32
+ ):
33
+ """
34
+ Initialize SNAC decoder.
35
+
36
+ Args:
37
+ device: Device for SNAC model (cuda/cpu)
38
+ compile_decoder: Use torch.compile for speedup
39
+ enable_batching: Enable async batching
40
+ max_batch_size: Max sequences to batch together
41
+ batch_timeout_ms: Max wait time before processing batch
42
+ """
43
+ self.device = device
44
+ self.enable_batching = enable_batching
45
+ self.max_batch_size = max_batch_size
46
+ self.batch_timeout_ms = batch_timeout_ms
47
+
48
+ print(f"Loading SNAC 24kHz model to {device}...")
49
+ self.snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(device)
50
+
51
+ if compile_decoder:
52
+ print(f"Compiling SNAC decoder with torch.compile...")
53
+ self._compile_model()
54
+
55
+ # Batching infrastructure
56
+ if enable_batching:
57
+ self.request_queue = asyncio.Queue()
58
+ self.batch_processor_task = None
59
+ self._running = False
60
+ print(f"Batching enabled (max_batch={max_batch_size}, timeout={batch_timeout_ms}ms)")
61
+
62
+ print(f"SNAC decoder initialized")
63
+
64
+ def _compile_model(self):
65
+ """Compile SNAC decoder with torch.compile"""
66
+ # Warm up with various sizes
67
+ for frames in [4, 16, 32]:
68
+ dummy_codes = [
69
+ torch.randint(0, 4096, (1, frames), device=self.device),
70
+ torch.randint(0, 4096, (1, frames * 2), device=self.device),
71
+ torch.randint(0, 4096, (1, frames * 4), device=self.device),
72
+ ]
73
+ with torch.inference_mode():
74
+ z_q = self.snac_model.quantizer.from_codes(dummy_codes)
75
+ _ = self.snac_model.decoder(z_q)
76
+
77
+ # Apply compilation
78
+ self.snac_model.decoder = torch.compile(
79
+ self.snac_model.decoder,
80
+ mode="max-autotune"
81
+ )
82
+ self.snac_model.quantizer = torch.compile(
83
+ self.snac_model.quantizer,
84
+ mode="reduce-overhead"
85
+ )
86
+
87
+ print(f"SNAC decoder compiled")
88
+
89
+ def unpack_snac_from_7(self, vocab_ids: List[int]) -> List[List[int]]:
90
+ """
91
+ Unpack 7-token SNAC frames to 3 hierarchical levels.
92
+
93
+ This is the EXACT INVERSE of the training preprocessing function
94
+ `pack_snac_to_7_and_offset()`.
95
+
96
+ Frame structure:
97
+ [slot0, slot1, slot2, slot3, slot4, slot5, slot6]
98
+
99
+ Unpacking:
100
+ - slot0: L1[i]
101
+ - slot1: L2[2*i] (even index)
102
+ - slot2: L3[4*i + 0]
103
+ - slot3: L3[4*i + 1]
104
+ - slot4: L2[2*i + 1] (odd index)
105
+ - slot5: L3[4*i + 2]
106
+ - slot6: L3[4*i + 3]
107
+
108
+ Args:
109
+ vocab_ids: List of SNAC token IDs (128266-156937)
110
+ Must be divisible by 7
111
+
112
+ Returns:
113
+ [L1, L2, L3] where:
114
+ L1: n elements (coarse level)
115
+ L2: 2n elements (medium level)
116
+ L3: 4n elements (fine level)
117
+ """
118
+ # Strip EOS token if present
119
+ if vocab_ids and vocab_ids[-1] == CODE_END_TOKEN_ID:
120
+ vocab_ids = vocab_ids[:-1]
121
+
122
+ # Ensure complete frames (divisible by 7)
123
+ frames = len(vocab_ids) // SNAC_TOKENS_PER_FRAME
124
+ vocab_ids = vocab_ids[:frames * SNAC_TOKENS_PER_FRAME]
125
+
126
+ if frames == 0:
127
+ return [[], [], []]
128
+
129
+ l1, l2, l3 = [], [], []
130
+
131
+ for i in range(frames):
132
+ # Extract 7 slots for this frame
133
+ slots = vocab_ids[i*7:(i+1)*7]
134
+
135
+ # Subtract offset (128266) and mod 4096 to get original codes
136
+ # Each level uses 4096 codes (0-4095)
137
+ l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096)
138
+ l2.extend([
139
+ (slots[1] - CODE_TOKEN_OFFSET) % 4096, # Even index
140
+ (slots[4] - CODE_TOKEN_OFFSET) % 4096, # Odd index
141
+ ])
142
+ l3.extend([
143
+ (slots[2] - CODE_TOKEN_OFFSET) % 4096,
144
+ (slots[3] - CODE_TOKEN_OFFSET) % 4096,
145
+ (slots[5] - CODE_TOKEN_OFFSET) % 4096,
146
+ (slots[6] - CODE_TOKEN_OFFSET) % 4096,
147
+ ])
148
+
149
+ return [l1, l2, l3]
150
+
151
+ @torch.inference_mode()
152
+ def decode(
153
+ self,
154
+ snac_tokens: List[int],
155
+ trim_warmup: bool = True,
156
+ trim_amount: Optional[int] = None,
157
+ use_sliding_window: bool = False
158
+ ) -> Optional[np.ndarray]:
159
+ """
160
+ Decode SNAC tokens to audio waveform.
161
+
162
+ Args:
163
+ snac_tokens: List of SNAC token IDs (7*n tokens)
164
+ trim_warmup: Whether to trim SNAC warmup samples (default: True)
165
+ trim_amount: Number of samples to trim (default: 2048 for first chunk, 0 for others)
166
+ Can be set to a smaller value (e.g., 512) for intermediate chunks
167
+ use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming)
168
+
169
+ Returns:
170
+ Audio waveform as numpy array (float32, 24kHz mono)
171
+ Shape: (samples,)
172
+ Returns None if not enough tokens
173
+ """
174
+ if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
175
+ print(f"Not enough SNAC tokens: {len(snac_tokens)} < {SNAC_TOKENS_PER_FRAME}")
176
+ return None
177
+
178
+ # Unpack to 3 levels
179
+ levels = self.unpack_snac_from_7(snac_tokens)
180
+
181
+ if not levels[0]: # No frames after unpacking
182
+ return None
183
+
184
+ # Convert to tensors
185
+ codes = [
186
+ torch.tensor(level, dtype=torch.long, device=self.device).unsqueeze(0)
187
+ for level in levels
188
+ ]
189
+
190
+ # Decode through SNAC
191
+ z_q = self.snac_model.quantizer.from_codes(codes)
192
+ audio = self.snac_model.decoder(z_q)
193
+
194
+ # Extract audio (remove padding if any)
195
+ # SNAC decoder outputs: [batch, 1, samples]
196
+ audio = audio[0, 0].cpu().numpy()
197
+
198
+ # Sliding window mode: only keep middle 2048 samples
199
+ # This eliminates popping/cracking when using overlapping 28-token windows
200
+ if use_sliding_window:
201
+ if len(audio) >= 4096:
202
+ audio = audio[2048:4096] # Keep middle portion only
203
+ else:
204
+ # For shorter audio, keep everything (final chunk)
205
+ pass
206
+ else:
207
+ # Standard mode: trim warm-up samples
208
+ # Default: 2048 samples for first chunk, 0 for subsequent chunks
209
+ # Can be customized via trim_amount parameter
210
+ if trim_warmup:
211
+ if trim_amount is None:
212
+ trim_amount = 2048 # Default full trim
213
+
214
+ if len(audio) > trim_amount:
215
+ audio = audio[trim_amount:]
216
+
217
+ return audio
218
+
219
+ def decode_to_bytes(
220
+ self,
221
+ snac_tokens: List[int],
222
+ trim_warmup: bool = True,
223
+ use_sliding_window: bool = False
224
+ ) -> Optional[bytes]:
225
+ """
226
+ Decode SNAC tokens to audio bytes (int16 PCM).
227
+
228
+ Args:
229
+ snac_tokens: List of SNAC token IDs
230
+ trim_warmup: Whether to trim SNAC warmup samples (default: True)
231
+ use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming)
232
+
233
+ Returns:
234
+ Audio as bytes (int16 PCM, 24kHz mono)
235
+ Returns None if decode fails
236
+ """
237
+ audio = self.decode(snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window)
238
+
239
+ if audio is None:
240
+ return None
241
+
242
+ # Convert float32 to int16 PCM
243
+ audio_int16 = (audio * 32767).astype(np.int16)
244
+
245
+ return audio_int16.tobytes()
246
+
247
+ def validate_tokens(self, snac_tokens: List[int]) -> bool:
248
+ """
249
+ Validate SNAC tokens before decoding.
250
+ Args:
251
+ snac_tokens: List of SNAC token IDs
252
+ Returns:
253
+ True if valid, False otherwise
254
+ """
255
+ # Check minimum length
256
+ if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
257
+ print(f"Too few tokens: {len(snac_tokens)}")
258
+ return False
259
+
260
+ # Check divisibility by 7
261
+ if len(snac_tokens) % SNAC_TOKENS_PER_FRAME != 0:
262
+ print(f" Warning: Token count {len(snac_tokens)} not divisible by 7")
263
+ print(f" Will truncate to {(len(snac_tokens) // 7) * 7}")
264
+
265
+ # Check token range
266
+ for i, token_id in enumerate(snac_tokens):
267
+ if token_id < CODE_TOKEN_OFFSET or token_id > 156937:
268
+ print(f" Invalid token at position {i}: {token_id}")
269
+ print(f" Expected range: [{CODE_TOKEN_OFFSET}, 156937]")
270
+ return False
271
+
272
+ return True
273
+
274
+ # ========== Async Batching Methods ==========
275
+
276
+ @property
277
+ def is_running(self) -> bool:
278
+ """Check if batch processor is running."""
279
+ return self._running if self.enable_batching else False
280
+
281
+ async def start_batch_processor(self):
282
+ """Start the background batch processor task."""
283
+ if not self.enable_batching:
284
+ return
285
+
286
+ if self._running:
287
+ print("Batch processor already running")
288
+ return
289
+
290
+ self._running = True
291
+ self.batch_processor_task = asyncio.create_task(self._batch_processor_loop())
292
+ print("Batch processor started")
293
+
294
+ async def stop_batch_processor(self):
295
+ """Stop the background batch processor task."""
296
+ if not self.enable_batching:
297
+ return
298
+
299
+ if not self._running:
300
+ return
301
+
302
+ self._running = False
303
+
304
+ if self.batch_processor_task:
305
+ self.batch_processor_task.cancel()
306
+ try:
307
+ await self.batch_processor_task
308
+ except asyncio.CancelledError:
309
+ pass
310
+
311
+ print("Batch processor stopped")
312
+
313
+ async def decode_single_async(
314
+ self,
315
+ snac_tokens: List[int],
316
+ trim_warmup: bool = True,
317
+ use_sliding_window: bool = False
318
+ ) -> Optional[bytes]:
319
+ """
320
+ Async decode for batching support.
321
+
322
+ Queues the request and waits for batched processing.
323
+
324
+ Args:
325
+ snac_tokens: List of SNAC token IDs
326
+ trim_warmup: Whether to trim SNAC warmup samples (default: True)
327
+ use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming)
328
+
329
+ Returns:
330
+ Audio bytes or None if decode fails
331
+ """
332
+ if not self.enable_batching:
333
+ # Fallback to synchronous decode
334
+ return self.decode_to_bytes(snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window)
335
+
336
+ # Create future for result
337
+ result_future = asyncio.Future()
338
+
339
+ # Add to queue (include trim_warmup and sliding_window flags)
340
+ await self.request_queue.put((snac_tokens, trim_warmup, use_sliding_window, result_future))
341
+
342
+ # Wait for result
343
+ return await result_future
344
+
345
+ async def _batch_processor_loop(self):
346
+ """Background task that processes batched decode requests."""
347
+ while self._running:
348
+ try:
349
+ # Collect batch
350
+ batch = await self._collect_batch()
351
+
352
+ if not batch:
353
+ continue
354
+
355
+ # Process batch
356
+ await self._process_batch(batch)
357
+
358
+ except asyncio.CancelledError:
359
+ break
360
+ except Exception as e:
361
+ print(f"Batch processor error: {e}")
362
+ import traceback
363
+ traceback.print_exc()
364
+
365
+ async def _collect_batch(self) -> List[Tuple[List[int], bool, bool, asyncio.Future]]:
366
+ """
367
+ Collect requests into a batch.
368
+ Waits for timeout or until batch is full.
369
+ Returns:
370
+ List of (tokens, trim_warmup, use_sliding_window, future) tuples
371
+ """
372
+ batch = []
373
+ timeout_sec = self.batch_timeout_ms / 1000.0
374
+
375
+ try:
376
+ # Wait for first request (blocking)
377
+ first_item = await asyncio.wait_for(
378
+ self.request_queue.get(),
379
+ timeout=timeout_sec
380
+ )
381
+ batch.append(first_item)
382
+
383
+ # Collect more requests (non-blocking)
384
+ while len(batch) < self.max_batch_size:
385
+ try:
386
+ item = await asyncio.wait_for(
387
+ self.request_queue.get(),
388
+ timeout=timeout_sec
389
+ )
390
+ batch.append(item)
391
+ except asyncio.TimeoutError:
392
+ break # Timeout reached, process what we have
393
+
394
+ except asyncio.TimeoutError:
395
+ # No requests in timeout period
396
+ pass
397
+
398
+ return batch
399
+
400
+ @torch.inference_mode()
401
+ async def _process_batch(self, batch: List[Tuple[List[int], bool, bool, asyncio.Future]]):
402
+ """
403
+ Process a batch of decode requests.
404
+ Args:
405
+ batch: List of (tokens, trim_warmup, use_sliding_window, future) tuples
406
+ """
407
+ if not batch:
408
+ return
409
+
410
+ # Extract components
411
+ token_sequences = [item[0] for item in batch]
412
+ trim_warmup_flags = [item[1] for item in batch]
413
+ sliding_window_flags = [item[2] for item in batch]
414
+ futures = [item[3] for item in batch]
415
+
416
+ lengths = [len(tokens) for tokens in token_sequences]
417
+ can_batch_efficiently = len(set(lengths)) == 1
418
+
419
+ if can_batch_efficiently and len(batch) > 1:
420
+ # Efficient batching: all same length
421
+ try:
422
+ audio_bytes_list = await self._decode_batch_same_length(
423
+ token_sequences, trim_warmup_flags, sliding_window_flags
424
+ )
425
+
426
+ # Set results
427
+ for future, audio_bytes in zip(futures, audio_bytes_list):
428
+ if not future.done():
429
+ future.set_result(audio_bytes)
430
+
431
+ except Exception as e:
432
+ # Set exceptions
433
+ for future in futures:
434
+ if not future.done():
435
+ future.set_exception(e)
436
+ else:
437
+ # Sequential decode (different lengths or single item)
438
+ for tokens, trim_warmup, use_sliding_window, future in batch:
439
+ try:
440
+ audio_bytes = self.decode_to_bytes(
441
+ tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window
442
+ )
443
+ if not future.done():
444
+ future.set_result(audio_bytes)
445
+ except Exception as e:
446
+ if not future.done():
447
+ future.set_exception(e)
448
+
449
+ async def _decode_batch_same_length(
450
+ self,
451
+ token_sequences: List[List[int]],
452
+ trim_warmup_flags: List[bool],
453
+ sliding_window_flags: List[bool]
454
+ ) -> List[Optional[bytes]]:
455
+ """
456
+ Decode multiple sequences with same length in parallel.
457
+
458
+ Args:
459
+ token_sequences: List of token sequences (all same length)
460
+ trim_warmup_flags: List of trim_warmup flags for each sequence
461
+ sliding_window_flags: List of use_sliding_window flags for each sequence
462
+
463
+ Returns:
464
+ List of audio bytes
465
+ """
466
+ if not token_sequences:
467
+ return []
468
+
469
+ # Unpack all sequences
470
+ unpacked_list = [self.unpack_snac_from_7(tokens) for tokens in token_sequences]
471
+
472
+ # Check all have valid frames
473
+ valid_indices = [i for i, levels in enumerate(unpacked_list) if levels[0]]
474
+
475
+ if not valid_indices:
476
+ return [None] * len(token_sequences)
477
+
478
+ # Stack into batched tensors
479
+ batch_size = len(valid_indices)
480
+ frames = len(unpacked_list[valid_indices[0]][0])
481
+
482
+ # Build batched codes [batch, frames], [batch, 2*frames], [batch, 4*frames]
483
+ codes = [
484
+ torch.stack([
485
+ torch.tensor(unpacked_list[i][level_idx], dtype=torch.long, device=self.device)
486
+ for i in valid_indices
487
+ ], dim=0)
488
+ for level_idx in range(3)
489
+ ]
490
+
491
+ # Batched decode
492
+ z_q = self.snac_model.quantizer.from_codes(codes)
493
+ audio_batch = self.snac_model.decoder(z_q) # [batch, 1, samples]
494
+
495
+ # Extract and convert to bytes
496
+ audio_bytes_list = [None] * len(token_sequences)
497
+
498
+ for batch_idx, orig_idx in enumerate(valid_indices):
499
+ audio = audio_batch[batch_idx, 0].detach().cpu().numpy()
500
+
501
+ # Apply sliding window or trim warmup based on flags
502
+ if sliding_window_flags[orig_idx]:
503
+ # Sliding window mode: keep middle 2048 samples only
504
+ if len(audio) >= 4096:
505
+ audio = audio[2048:4096]
506
+ else:
507
+ # Standard mode: trim warm-up if requested
508
+ if trim_warmup_flags[orig_idx] and len(audio) > 2048:
509
+ audio = audio[2048:]
510
+
511
+ # Convert to int16
512
+ audio_int16 = (audio * 32767).astype(np.int16)
513
+ audio_bytes_list[orig_idx] = audio_int16.tobytes()
514
+
515
+ return audio_bytes_list
maya1/streaming_pipeline.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Maya1 Streaming Pipeline - Sliding Window Approach
3
+ Implements sliding window technique for smooth streaming without artifacts.
4
+ """
5
+
6
+ import asyncio
7
+ from typing import AsyncGenerator, Optional
8
+ from vllm import SamplingParams
9
+
10
+ from .constants import (
11
+ CODE_END_TOKEN_ID,
12
+ SNAC_MIN_ID,
13
+ SNAC_MAX_ID,
14
+ DEFAULT_TEMPERATURE,
15
+ DEFAULT_TOP_P,
16
+ DEFAULT_MAX_TOKENS,
17
+ DEFAULT_MIN_TOKENS,
18
+ DEFAULT_REPETITION_PENALTY,
19
+ DEFAULT_SEED,
20
+ )
21
+
22
+
23
+ class Maya1SlidingWindowPipeline:
24
+ """
25
+ Streaming TTS pipeline using sliding window approach.
26
+ Decodes overlapping 28-token windows (4 frames) and keeps only
27
+ the middle 2048 samples for smooth audio continuity.
28
+ """
29
+
30
+ # Sliding window configuration
31
+ WINDOW_SIZE = 28 # 4 frames (7 tokens per frame)
32
+ YIELD_STRIDE = 7 # Yield every 1 frame
33
+ MIDDLE_SAMPLES = 2048 # Keep middle 2048 samples from each decode
34
+
35
+ def __init__(self, model, prompt_builder, snac_decoder):
36
+ """
37
+ Initialize sliding window streaming pipeline.
38
+
39
+ Args:
40
+ model: Maya1Model instance
41
+ prompt_builder: Maya1PromptBuilder instance
42
+ snac_decoder: SNACDecoder instance
43
+ """
44
+ self.model = model
45
+ self.prompt_builder = prompt_builder
46
+ self.snac_decoder = snac_decoder
47
+ print(f"Sliding window pipeline initialized")
48
+
49
+ async def generate_speech_stream(
50
+ self,
51
+ description: str,
52
+ text: str,
53
+ temperature: float = DEFAULT_TEMPERATURE,
54
+ top_p: float = DEFAULT_TOP_P,
55
+ max_tokens: int = DEFAULT_MAX_TOKENS,
56
+ repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
57
+ seed: Optional[int] = None,
58
+ ) -> AsyncGenerator[bytes, None]:
59
+ """
60
+ Generate speech audio with sliding window streaming.
61
+
62
+ Args:
63
+ description: Voice description
64
+ text: Text to synthesize (may include <emotion> tags)
65
+ temperature: Sampling temperature
66
+ top_p: Nucleus sampling
67
+ max_tokens: Max SNAC tokens to generate
68
+ repetition_penalty: Prevent loops
69
+ seed: Random seed
70
+
71
+ Yields:
72
+ Audio bytes (int16 PCM, 24kHz mono)
73
+ """
74
+ # Build prompt
75
+ prompt = self.prompt_builder.build_prefix(description, text)
76
+
77
+ # Configure sampling
78
+ sampling_params = SamplingParams(
79
+ temperature=temperature,
80
+ top_p=top_p,
81
+ max_tokens=max_tokens,
82
+ min_tokens=DEFAULT_MIN_TOKENS,
83
+ repetition_penalty=repetition_penalty,
84
+ stop_token_ids=[CODE_END_TOKEN_ID],
85
+ seed=seed if seed is not None else DEFAULT_SEED,
86
+ )
87
+
88
+ # Stream tokens
89
+ snac_buffer = []
90
+ last_yield_position = 0
91
+ chunk_count = 0
92
+ total_tokens_seen = 0
93
+
94
+ async for output in self.model.generate_stream(prompt, sampling_params):
95
+ # Get latest generated tokens (cumulative list)
96
+ generated_token_ids = output.outputs[0].token_ids
97
+
98
+ # Process only NEW tokens since last iteration
99
+ new_tokens = generated_token_ids[total_tokens_seen:]
100
+ total_tokens_seen = len(generated_token_ids)
101
+
102
+ # Collect SNAC codes from new tokens
103
+ for token_id in new_tokens:
104
+ # Stop if we hit EOS
105
+ if token_id == CODE_END_TOKEN_ID:
106
+ break
107
+
108
+ # Only collect valid SNAC tokens
109
+ if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID:
110
+ snac_buffer.append(token_id)
111
+
112
+ # Yield audio when we have enough tokens for a window
113
+ while len(snac_buffer) >= last_yield_position + self.WINDOW_SIZE:
114
+ # Get window of 28 tokens
115
+ window_start = last_yield_position
116
+ window_end = window_start + self.WINDOW_SIZE
117
+ window = snac_buffer[window_start:window_end]
118
+
119
+ if len(window) == self.WINDOW_SIZE:
120
+ # Decode window to audio
121
+ audio_bytes = await self.snac_decoder.decode_single_async(window)
122
+
123
+ if audio_bytes:
124
+ # Extract middle portion of audio
125
+ audio_samples = len(audio_bytes) // 2
126
+ middle_start_sample = (audio_samples - self.MIDDLE_SAMPLES) // 2
127
+ middle_end_sample = middle_start_sample + self.MIDDLE_SAMPLES
128
+
129
+ # Convert to byte positions
130
+ middle_start_byte = middle_start_sample * 2
131
+ middle_end_byte = middle_end_sample * 2
132
+
133
+ # Extract middle chunk
134
+ audio_chunk = audio_bytes[middle_start_byte:middle_end_byte]
135
+
136
+ chunk_count += 1
137
+ if chunk_count == 1:
138
+ print(f" First chunk ready")
139
+
140
+ yield audio_chunk
141
+
142
+ # Move forward by stride
143
+ last_yield_position += self.YIELD_STRIDE
144
+
145
+ # Check if generation is done
146
+ if CODE_END_TOKEN_ID in new_tokens:
147
+ break
148
+
149
+ # Final chunk: decode remaining tokens
150
+ remaining_tokens = len(snac_buffer) - last_yield_position
151
+ if remaining_tokens >= self.WINDOW_SIZE:
152
+ window = snac_buffer[-self.WINDOW_SIZE:]
153
+ audio_bytes = await self.snac_decoder.decode_single_async(window)
154
+ if audio_bytes:
155
+ yield audio_bytes[-self.MIDDLE_SAMPLES * 2:]
156
+
157
+ frames = len(snac_buffer) // 7
158
+ duration = frames / 6.86
159
+ print(f"Streamed {chunk_count} chunks (~{duration:.1f}s audio)")