File size: 9,688 Bytes
3c92819
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
import asyncio
import io
import os
import struct
import time
import wave
from typing import Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from .constants import (
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    DEFAULT_MAX_TOKENS,
    DEFAULT_REPETITION_PENALTY,
    AUDIO_SAMPLE_RATE,
)
from .model_loader import Maya1Model
from .pipeline import Maya1Pipeline
from .prompt_builder import Maya1PromptBuilder
from .snac_decoder import SNACDecoder
from .streaming_pipeline import Maya1SlidingWindowPipeline

# Timeout settings (seconds) — upper bound on a single non-streaming generation.
GENERATE_TIMEOUT = 60

# Load environment variables from a local .env file, if present.
load_dotenv()

# Initialize FastAPI app. Interactive docs (Swagger/ReDoc) are deliberately
# disabled by setting both docs_url and redoc_url to None.
app = FastAPI(
    title="Maya1 TTS API",
    description="Open source TTS inference for Maya1",
    version="1.0.0",
    docs_url=None,
    redoc_url=None,
)

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers under the Fetch/CORS spec — confirm whether
# credentialed cross-origin requests are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global state — all five are populated by startup_event() before the
# server accepts traffic; None until then.
model = None
prompt_builder = None
snac_decoder = None
pipeline = None
streaming_pipeline = None


# ============================================================================
# Startup/Shutdown
# ============================================================================

@app.on_event("startup")
async def startup_event():
    """Build the model stack once at process start and publish it via module globals."""
    global model, prompt_builder, snac_decoder, pipeline, streaming_pipeline

    banner = "=" * 60
    print("\n" + banner)
    print(" Starting Maya1 TTS API Server")
    print(banner + "\n")

    # Core model plus a prompt builder bound to its tokenizer.
    model = Maya1Model()
    prompt_builder = Maya1PromptBuilder(model.tokenizer, model)

    # SNAC audio decoder with request batching; its batch processor runs
    # as a background async task and is stopped again in shutdown_event.
    snac_decoder = SNACDecoder(enable_batching=True, max_batch_size=64, batch_timeout_ms=15)
    await snac_decoder.start_batch_processor()

    # Non-streaming and sliding-window streaming pipelines share components.
    pipeline = Maya1Pipeline(model, prompt_builder, snac_decoder)
    streaming_pipeline = Maya1SlidingWindowPipeline(model, prompt_builder, snac_decoder)

    print("\n" + banner)
    print("Maya1 TTS API Server Ready")
    print(banner + "\n")


@app.on_event("shutdown")
async def shutdown_event():
    """Stop the SNAC batch processor (if it was ever started) on server shutdown."""
    print("\nShutting down Maya1 TTS API Server")

    decoder = snac_decoder
    if decoder and decoder.is_running:
        await decoder.stop_batch_processor()


# ============================================================================
# Utility Functions
# ============================================================================

def create_wav_header(sample_rate: int = 24000, channels: int = 1, bits_per_sample: int = 16, data_size: int = 0) -> bytes:
    """Build a 44-byte PCM WAV (RIFF) header.

    Args:
        sample_rate: Samples per second.
        channels: Number of interleaved channels.
        bits_per_sample: Bits per sample (16 => signed little-endian PCM).
        data_size: Size of the PCM payload in bytes. Pass 0 (the default)
            for live streams of unknown length; the header then advertises
            the 0xFFFFFFFF sentinel so players keep reading until EOF
            instead of stopping after a declared zero-byte data chunk.

    Returns:
        The RIFF/fmt/data header bytes (no audio data appended).
    """
    byte_rate = sample_rate * channels * bits_per_sample // 8
    block_align = channels * bits_per_sample // 8

    if data_size > 0:
        riff_size = 36 + data_size
        chunk_size = data_size
    else:
        # Unknown length (streaming): use the max-uint32 placeholder so
        # spec-compliant parsers don't treat the stream as empty.
        riff_size = 0xFFFFFFFF
        chunk_size = 0xFFFFFFFF

    return struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF',
        riff_size,
        b'WAVE',
        b'fmt ',
        16,               # fmt chunk size for PCM
        1,                # audio format tag: 1 = uncompressed PCM
        channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
        b'data',
        chunk_size,
    )


# ============================================================================
# Request/Response Models
# ============================================================================

class TTSRequest(BaseModel):
    """TTS generation request.

    Carries the voice description and text to synthesize, plus optional
    sampling parameters that default to the package-level constants.
    """
    # Natural-language description of the desired voice.
    description: str = Field(
        ...,
        description="Voice description (e.g., 'Male voice in their 30s with american accent')"
    )
    # Text to speak; inline <emotion> tags are passed through to the model.
    text: str = Field(
        ...,
        description="Text to synthesize (can include <emotion> tags)"
    )
    temperature: Optional[float] = Field(
        default=DEFAULT_TEMPERATURE,
        description="Sampling temperature"
    )
    top_p: Optional[float] = Field(
        default=DEFAULT_TOP_P,
        description="Nucleus sampling"
    )
    max_tokens: Optional[int] = Field(
        default=DEFAULT_MAX_TOKENS,
        description="Maximum tokens to generate"
    )
    repetition_penalty: Optional[float] = Field(
        default=DEFAULT_REPETITION_PENALTY,
        description="Repetition penalty"
    )
    # Only the seed is range-validated (ge=0); other sampling parameters are
    # accepted as-is and validated, if at all, by the pipeline.
    seed: Optional[int] = Field(
        default=None,
        description="Random seed for reproducibility",
        ge=0,
    )
    # False => one complete WAV response; True => chunked audio stream.
    stream: bool = Field(
        default=False,
        description="Stream audio (True) or return complete WAV (False)"
    )


# ============================================================================
# Endpoints
# ============================================================================

@app.get("/")
async def root():
    """Service banner: name, version, status, and the routes a client can call."""
    endpoints = {
        "generate": "/v1/tts/generate (POST)",
        "health": "/health (GET)",
    }
    return {
        "service": "Maya1 TTS API",
        "version": "1.0.0",
        "status": "running",
        "model": "Maya1-Voice (open source)",
        "endpoints": endpoints,
    }


@app.get("/health")
async def health_check():
    """Liveness probe: static status fields plus the current server time."""
    payload = {"status": "healthy", "model": "Maya1-Voice"}
    payload["timestamp"] = time.time()
    return payload


# ============================================================================
# TTS Generation Endpoint
# ============================================================================

@app.post("/v1/tts/generate")
async def generate_tts(request: TTSRequest):
    """Generate TTS audio: dispatch to the streaming or complete-WAV path."""
    # Both helpers share the exact same keyword interface, so pick one
    # up front and make a single call.
    handler = _generate_tts_streaming if request.stream else _generate_tts_complete

    try:
        return await handler(
            description=request.description,
            text=request.text,
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_tokens,
            repetition_penalty=request.repetition_penalty,
            seed=request.seed,
        )
    except HTTPException:
        # Already a proper HTTP error — pass it through untouched.
        raise
    except Exception as e:
        print(f" Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


async def _generate_tts_complete(
    description: str,
    text: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
    repetition_penalty: float,
    seed: Optional[int],
):
    """Generate a complete WAV file (non-streaming).

    Runs the full pipeline under GENERATE_TIMEOUT, wraps the resulting raw
    16-bit mono PCM in a WAV container, and returns it as a download.

    Raises:
        HTTPException: 504 on generation timeout, 500 if the pipeline
            produced no audio.
    """
    # Keep the try body minimal: only the awaited call can time out.
    try:
        audio_bytes = await asyncio.wait_for(
            pipeline.generate_speech(
                description=description,
                text=text,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
            ),
            timeout=GENERATE_TIMEOUT,
        )
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail="Generation timeout")

    if audio_bytes is None:
        # Raise the HTTP error directly instead of a bare Exception; the
        # client-visible detail matches the previous behavior.
        raise HTTPException(status_code=500, detail="Audio generation failed")

    # Wrap the raw PCM (mono, 16-bit, AUDIO_SAMPLE_RATE) in a WAV container.
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wav_file.setframerate(AUDIO_SAMPLE_RATE)
        wav_file.writeframes(audio_bytes)

    wav_buffer.seek(0)

    return StreamingResponse(
        wav_buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=output.wav"},
    )


async def _generate_tts_streaming(
    description: str,
    text: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
    repetition_penalty: float,
    seed: Optional[int],
):
    """Generate streaming audio.

    Sends a WAV header (unknown-length data chunk) followed by raw PCM
    chunks as the sliding-window pipeline produces them, and logs the
    time to first audio byte (TTFB).
    """
    start_time = time.time()
    first_audio_time = None

    async def audio_stream_generator():
        """Yield the WAV header, then PCM chunks; log in-stream errors."""
        nonlocal first_audio_time

        # Header first so players can begin decoding immediately.
        yield create_wav_header(sample_rate=AUDIO_SAMPLE_RATE, channels=1, bits_per_sample=16)

        # Errors past this point occur while the server iterates the
        # response body — after HTTP headers are sent — so they cannot be
        # turned into an HTTP error response. Log here, where the handler
        # can actually fire, and re-raise to terminate the stream.
        try:
            async for audio_chunk in streaming_pipeline.generate_speech_stream(
                description=description,
                text=text,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
            ):
                if first_audio_time is None:
                    first_audio_time = time.time()
                    ttfb_ms = (first_audio_time - start_time) * 1000
                    print(f"⏱️  TTFB: {ttfb_ms:.1f}ms")

                yield audio_chunk
        except Exception as e:
            print(f"Streaming error: {e}")
            raise

    return StreamingResponse(
        audio_stream_generator(),
        media_type="audio/wav",
        headers={"Cache-Control": "no-cache"}
    )


# For running directly (e.g. `python -m package.server`)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")