Tim Luka Horstmann
committed on
Commit
·
7ee4aae
1
Parent(s):
0e9cc30
Rate limiting
Browse files- app.py +29 -9
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -3,7 +3,7 @@ import json
|
|
| 3 |
import time
|
| 4 |
import numpy as np
|
| 5 |
from sentence_transformers import SentenceTransformer
|
| 6 |
-
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
| 7 |
from fastapi.responses import StreamingResponse, Response
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from pydantic import BaseModel
|
|
@@ -18,6 +18,9 @@ from google import genai
|
|
| 18 |
from google.genai import types
|
| 19 |
import httpx
|
| 20 |
from elevenlabs import ElevenLabs, VoiceSettings
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# Set up logging
|
| 23 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -25,6 +28,18 @@ logger = logging.getLogger(__name__)
|
|
| 25 |
|
| 26 |
app = FastAPI()
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# Add CORS middleware to handle cross-origin requests
|
| 29 |
app.add_middleware(
|
| 30 |
CORSMiddleware,
|
|
@@ -331,20 +346,22 @@ def get_ram_usage():
|
|
| 331 |
}
|
| 332 |
|
| 333 |
@app.post("/api/predict")
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
|
|
|
| 337 |
return StreamingResponse(stream_response(query, history), media_type="text/event-stream")
|
| 338 |
|
| 339 |
@app.post("/api/tts")
|
| 340 |
-
|
|
|
|
| 341 |
"""Convert text to speech using ElevenLabs API"""
|
| 342 |
if not elevenlabs_client:
|
| 343 |
raise HTTPException(status_code=503, detail="TTS service not available")
|
| 344 |
|
| 345 |
try:
|
| 346 |
# Clean the text for TTS (remove markdown and special characters)
|
| 347 |
-
clean_text =
|
| 348 |
|
| 349 |
if not clean_text:
|
| 350 |
raise HTTPException(status_code=400, detail="No text provided for TTS")
|
|
@@ -381,11 +398,13 @@ async def text_to_speech(request: TTSRequest):
|
|
| 381 |
raise HTTPException(status_code=500, detail=f"TTS conversion failed: {str(e)}")
|
| 382 |
|
| 383 |
@app.get("/health")
|
| 384 |
-
|
|
|
|
| 385 |
return {"status": "healthy"}
|
| 386 |
|
| 387 |
@app.get("/model_info")
|
| 388 |
-
|
|
|
|
| 389 |
base_info = {
|
| 390 |
"embedding_model": sentence_transformer_model,
|
| 391 |
"faiss_index_size": len(cv_chunks),
|
|
@@ -411,7 +430,8 @@ async def model_info():
|
|
| 411 |
return base_info
|
| 412 |
|
| 413 |
@app.get("/ram_usage")
|
| 414 |
-
|
|
|
|
| 415 |
"""Endpoint to get current RAM usage."""
|
| 416 |
try:
|
| 417 |
ram_stats = get_ram_usage()
|
|
|
|
| 3 |
import time
|
| 4 |
import numpy as np
|
| 5 |
from sentence_transformers import SentenceTransformer
|
| 6 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
| 7 |
from fastapi.responses import StreamingResponse, Response
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from pydantic import BaseModel
|
|
|
|
| 18 |
from google.genai import types
|
| 19 |
import httpx
|
| 20 |
from elevenlabs import ElevenLabs, VoiceSettings
|
| 21 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
| 22 |
+
from slowapi.util import get_remote_address
|
| 23 |
+
from slowapi.errors import RateLimitExceeded
|
| 24 |
|
| 25 |
# Set up logging
|
| 26 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 28 |
|
| 29 |
app = FastAPI()
|
| 30 |
|
| 31 |
+
# Initialize rate limiter
|
| 32 |
+
limiter = Limiter(key_func=get_remote_address)
|
| 33 |
+
app.state.limiter = limiter
|
| 34 |
+
|
| 35 |
+
# Custom rate limit exceeded handler with logging
|
| 36 |
+
async def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
|
| 37 |
+
client_ip = get_remote_address(request)
|
| 38 |
+
logger.warning(f"Rate limit exceeded for IP {client_ip} on endpoint {request.url.path}")
|
| 39 |
+
return await _rate_limit_exceeded_handler(request, exc)
|
| 40 |
+
|
| 41 |
+
app.add_exception_handler(RateLimitExceeded, custom_rate_limit_handler)
|
| 42 |
+
|
| 43 |
# Add CORS middleware to handle cross-origin requests
|
| 44 |
app.add_middleware(
|
| 45 |
CORSMiddleware,
|
|
|
|
| 346 |
}
|
| 347 |
|
| 348 |
@app.post("/api/predict")
|
| 349 |
+
@limiter.limit("5/minute")  # Allow 5 chat requests per minute per IP
|
| 350 |
+
async def predict(request: Request, query_request: QueryRequest):
|
| 351 |
+
query = query_request.query
|
| 352 |
+
history = query_request.history
|
| 353 |
return StreamingResponse(stream_response(query, history), media_type="text/event-stream")
|
| 354 |
|
| 355 |
@app.post("/api/tts")
|
| 356 |
+
@limiter.limit("5/minute") # Allow 5 TTS requests per minute per IP (more restrictive as TTS is more expensive)
|
| 357 |
+
async def text_to_speech(request: Request, tts_request: TTSRequest):
|
| 358 |
"""Convert text to speech using ElevenLabs API"""
|
| 359 |
if not elevenlabs_client:
|
| 360 |
raise HTTPException(status_code=503, detail="TTS service not available")
|
| 361 |
|
| 362 |
try:
|
| 363 |
# Clean the text for TTS (remove markdown and special characters)
|
| 364 |
+
clean_text = tts_request.text.replace("**", "").replace("*", "").replace("\n", " ").strip()
|
| 365 |
|
| 366 |
if not clean_text:
|
| 367 |
raise HTTPException(status_code=400, detail="No text provided for TTS")
|
|
|
|
| 398 |
raise HTTPException(status_code=500, detail=f"TTS conversion failed: {str(e)}")
|
| 399 |
|
| 400 |
@app.get("/health")
|
| 401 |
+
@limiter.limit("30/minute") # Allow frequent health checks
|
| 402 |
+
async def health_check(request: Request):
|
| 403 |
return {"status": "healthy"}
|
| 404 |
|
| 405 |
@app.get("/model_info")
|
| 406 |
+
@limiter.limit("10/minute") # Limit model info requests
|
| 407 |
+
async def model_info(request: Request):
|
| 408 |
base_info = {
|
| 409 |
"embedding_model": sentence_transformer_model,
|
| 410 |
"faiss_index_size": len(cv_chunks),
|
|
|
|
| 430 |
return base_info
|
| 431 |
|
| 432 |
@app.get("/ram_usage")
|
| 433 |
+
@limiter.limit("20/minute") # Allow moderate monitoring requests
|
| 434 |
+
async def ram_usage(request: Request):
|
| 435 |
"""Endpoint to get current RAM usage."""
|
| 436 |
try:
|
| 437 |
ram_stats = get_ram_usage()
|
requirements.txt
CHANGED
|
@@ -10,4 +10,5 @@ google-genai
|
|
| 10 |
asyncio
|
| 11 |
elevenlabs
|
| 12 |
httpx
|
| 13 |
-
llama-cpp-python==0.2.85
|
|
|
|
|
|
| 10 |
asyncio
|
| 11 |
elevenlabs
|
| 12 |
httpx
|
| 13 |
+
llama-cpp-python==0.2.85
|
| 14 |
+
slowapi==0.1.9
|