"""HTTP API for the chat application: sessions, chat turns and text-to-speech."""

import os
import traceback
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

import chat
import voice
import database

app = FastAPI()

# NOTE(review): wildcard origins are combined with allow_credentials=True.
# The FastAPI/Starlette CORS documentation says origins, methods and headers
# should be listed explicitly when credentials are allowed -- confirm whether
# this permissive setup is intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Initialize database on startup
@app.on_event("startup")
async def startup_event() -> None:
    """Run ``database.init_db()`` once when the application starts."""
    database.init_db()


class ChatRequest(BaseModel):
    """Request body for the chat endpoints."""

    message: str
    # history is now optional/deprecated as we use session_id, but keeping
    # for backward compatibility if needed
    history: List[dict] = []


class SessionCreateRequest(BaseModel):
    """Request body for creating a chat session."""

    name: str = "New Chat"
    user_id: str


class VoiceRequest(BaseModel):
    """Request body for speech synthesis: the text and the voice to use."""

    text: str
    voice: str


def _as_http_500(exc: Exception) -> HTTPException:
    """Print the active traceback and wrap *exc* in a 500 ``HTTPException``.

    Must be called from inside an ``except`` block so that
    ``traceback.print_exc()`` can see the exception being handled.
    """
    traceback.print_exc()  # Print error to console
    return HTTPException(status_code=500, detail=str(exc))


@app.get("/")
async def root():
    """Return a static message identifying the API."""
    return {"message": "AI Girlfriend API"}


@app.get("/sessions/{user_id}")
async def get_sessions(user_id: str):
    """Return the result of ``database.get_sessions`` for ``user_id``."""
    return database.get_sessions(user_id)


@app.post("/sessions")
async def create_session(request: SessionCreateRequest):
    """Create a session via ``database.create_session`` and return its result."""
    return database.create_session(request.user_id, request.name)


@app.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Delete the session ``session_id`` and return a confirmation message."""
    database.delete_session(session_id)
    return {"message": "Session deleted"}


@app.get("/sessions/{session_id}/messages")
async def get_session_messages(session_id: str):
    """Return the result of ``database.get_messages`` for ``session_id``."""
    return database.get_messages(session_id)


@app.post("/sessions/{session_id}/chat")
async def chat_session_endpoint(session_id: str, request: ChatRequest):
    """Answer ``request.message`` using the stored history of ``session_id``.

    The history is read from the database, passed to
    ``chat.get_chat_response`` and, once a reply is available, both the user
    message and the reply are stored.

    Raises:
        HTTPException: re-raised unchanged when a helper raises one;
            otherwise any failure is reported as a 500 whose detail is the
            error text.
    """
    try:
        # Get existing history from DB
        db_history = database.get_messages(session_id)

        # Convert to format expected by chat.py (list of dicts with role/content)
        history_for_model = [
            {"role": msg["role"], "content": msg["content"]}
            for msg in db_history
        ]

        # Get response from AI
        response_text = chat.get_chat_response(request.message, history_for_model)

        # Save user message and AI response to DB. This happens only after
        # the model call returned, so a failed call stores neither message.
        database.add_message(session_id, "user", request.message)
        database.add_message(session_id, "assistant", response_text)

        return {"response": response_text}
    except HTTPException:
        # Keep the status code chosen by the helper instead of masking it
        # as a 500.
        raise
    except Exception as e:
        raise _as_http_500(e) from e


# Legacy endpoint - keeping for now or can be removed if frontend is fully updated
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Answer ``request.message`` using the history supplied in the request.

    Raises:
        HTTPException: re-raised unchanged when a helper raises one;
            otherwise any failure is reported as a 500 whose detail is the
            error text.
    """
    try:
        response = chat.get_chat_response(request.message, request.history)
        return {"response": response}
    except HTTPException:
        raise
    except Exception as e:
        raise _as_http_500(e) from e


@app.get("/voices")
async def voices_endpoint():
    """Return the result of ``voice.get_voices()``."""
    return voice.get_voices()


@app.post("/speak")
async def speak_endpoint(request: VoiceRequest):
    """Synthesize ``request.text`` with ``request.voice`` and return the audio.

    The path returned by ``voice.generate_audio`` is served as an
    ``audio/mpeg`` file named ``response.mp3``.

    Raises:
        HTTPException: re-raised unchanged when a helper raises one;
            otherwise any failure is reported as a 500 whose detail is the
            error text.
    """
    try:
        audio_path = await voice.generate_audio(request.text, request.voice)
        return FileResponse(
            audio_path,
            media_type="audio/mpeg",
            filename="response.mp3",
        )
    except HTTPException:
        raise
    except Exception as e:
        # Log the traceback here too, consistent with the chat endpoints.
        raise _as_http_500(e) from e