ai-girlfriend / main.py
jeevzz's picture
Update main.py
ea6263c verified
raw
history blame
3.23 kB
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import List, Optional
import os
import chat, voice, database
import traceback
app = FastAPI()

# CORS policy: every origin, method and header is accepted, with credentials on.
# NOTE(review): per the CORS spec (and Starlette's CORSMiddleware docs), a
# wildcard origin does not combine cleanly with allow_credentials=True; confirm
# the intended policy and consider an explicit list of frontend origins.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
@app.on_event("startup")
async def startup_event():
    """Initialise the database when the application starts up.

    Delegates entirely to ``database.init_db()``.
    """
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of a ``lifespan`` handler passed to ``FastAPI(...)``; migrating
    # requires changing where ``app`` is constructed.
    database.init_db()
class ChatRequest(BaseModel):
    """Request body for the chat endpoints: a new user message plus optional history."""

    # The user's new message, forwarded to chat.get_chat_response.
    message: str
    # history is now optional/deprecated: the session-based endpoint reads prior
    # turns from the database instead. It is kept for backward compatibility with
    # the legacy /chat endpoint, which passes it straight to the model.
    # Presumably a list of {"role": ..., "content": ...} dicts — TODO confirm
    # against chat.get_chat_response.
    history: List[dict] = []
class SessionCreateRequest(BaseModel):
    """Request body for creating a chat session."""

    # Display name of the session; defaults to "New Chat" when omitted.
    name: str = "New Chat"
    # Identifier of the user who will own the session (required).
    user_id: str
class VoiceRequest(BaseModel):
    """Request body for text-to-speech synthesis."""

    # Text passed to voice.generate_audio to be spoken.
    text: str
    # Voice identifier passed to voice.generate_audio; presumably one of the
    # values listed by the /voices endpoint — TODO confirm.
    voice: str
@app.get("/")
async def root():
    """Return a static message identifying this API."""
    payload = {"message": "AI Girlfriend API"}
    return payload
@app.get("/sessions/{user_id}")
async def get_sessions(user_id: str):
    """Return whatever ``database.get_sessions`` reports for ``user_id``."""
    # NOTE(review): GET /sessions/{x} treats x as a user id, while
    # DELETE /sessions/{x} treats x as a session id; consider disambiguating
    # the routes (e.g. /users/{user_id}/sessions).
    sessions = database.get_sessions(user_id)
    return sessions
@app.post("/sessions")
async def create_session(request: SessionCreateRequest):
    """Create a chat session for ``request.user_id`` named ``request.name``.

    Returns the result of ``database.create_session`` unchanged.
    """
    owner, title = request.user_id, request.name
    return database.create_session(owner, title)
@app.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Delete the session ``session_id`` and return a confirmation message.

    The confirmation is returned unconditionally; this handler does not check
    whether the session existed.
    """
    database.delete_session(session_id)
    confirmation = {"message": "Session deleted"}
    return confirmation
@app.get("/sessions/{session_id}/messages")
async def get_session_messages(session_id: str):
    """Return the stored messages of ``session_id`` as given by ``database.get_messages``."""
    messages = database.get_messages(session_id)
    return messages
@app.post("/sessions/{session_id}/chat")
async def chat_session_endpoint(session_id: str, request: ChatRequest):
    """Answer ``request.message`` within a stored session and persist the exchange.

    Steps:
    1. Load the session's prior messages from the database (``request.history``
       is not consulted here).
    2. Pass them with the new message to the chat model.
    3. Append the user message, then the model's reply, to the session.

    Returns ``{"response": <reply>}``. Any exception is printed to stderr and
    re-raised as an HTTP 500 whose detail is the exception text.
    """
    try:
        # Reduce each stored row to just the role/content pair the model consumes.
        prior_turns = [
            {"role": row["role"], "content": row["content"]}
            for row in database.get_messages(session_id)
        ]
        # NOTE(review): this synchronous call runs inside an async handler; if it
        # does network/model I/O it blocks the event loop for its duration.
        # Consider offloading it once chat.py's thread-safety is confirmed.
        reply = chat.get_chat_response(request.message, prior_turns)
        # Persisted only after the model answered, so a failed model call
        # stores nothing for this turn.
        for role, content in (("user", request.message), ("assistant", reply)):
            database.add_message(session_id, role, content)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
    return {"response": reply}
# Legacy, stateless chat endpoint: the client sends the full history itself.
# Kept for older frontends; it can be removed once every client uses the
# session-based /sessions/{session_id}/chat route.
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Return the model's reply to ``request.message`` given ``request.history``.

    Any exception is printed to stderr and re-raised as an HTTP 500 whose
    detail is the exception text.
    """
    try:
        reply = chat.get_chat_response(request.message, request.history)
    except Exception as e:
        traceback.print_exc()  # keep the server-side traceback for debugging
        raise HTTPException(status_code=500, detail=str(e))
    return {"response": reply}
@app.get("/voices")
async def voices_endpoint():
    """Return the text-to-speech voices reported by ``voice.get_voices``."""
    available = voice.get_voices()
    return available
@app.post("/speak")
async def speak_endpoint(request: VoiceRequest):
    """Synthesise ``request.text`` with ``request.voice`` and return it as an MP3.

    The audio file produced by ``voice.generate_audio`` is streamed back with
    media type ``audio/mpeg`` and download name ``response.mp3``.

    Raises:
        HTTPException: status 500, with the underlying error message, if audio
            generation fails.
    """
    try:
        audio_path = await voice.generate_audio(request.text, request.voice)
    except Exception as e:
        # Print the traceback server-side like the chat endpoints do; previously
        # failures here reached the client but left no trace in the server logs.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    # NOTE(review): the generated file is never removed here. If generate_audio
    # writes a fresh file per call, consider deleting it after the response is
    # sent (e.g. via a background task) — TODO confirm its file lifecycle.
    return FileResponse(audio_path, media_type="audio/mpeg", filename="response.mp3")