from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import logging
import time
import utils
import config
from langsmith import Client
from langsmith.run_helpers import traceable
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
# Initialize logging
logger = logging.getLogger("swayam-chatbot")

# Initialize LangSmith client if tracing is enabled
langsmith_tracing = os.environ.get("LANGSMITH_TRACING", "false").lower() == "true"
langsmith_client = None
if langsmith_tracing:
    try:
        langsmith_client = Client()
        logger.info("LangSmith client initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize LangSmith client: {e}")
        langsmith_tracing = False
# Initialize FastAPI app
app = FastAPI(title="Swayam's Personal Chatbot API")

# Add CORS middleware to allow requests from the portfolio website
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, replace with the actual domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
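# A locked-down production configuration would list the portfolio's origin
# explicitly instead of the wildcard, for example (hypothetical domain):
#   allow_origins=["https://swayam-portfolio.example.com"]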
class ChatRequest(BaseModel):
    query: str


class ChatResponse(BaseModel):
    response: str
    type: str  # "personal" or "general"
    processing_time: float
    format: str = "markdown"  # Format field indicating how the response is encoded
@app.on_event("startup")
async def startup_event():
    """Load resources on startup"""
    success = utils.load_resources()
    if not success:
        logger.error("Failed to load embedding resources. RAG functionality will not work.")
@traceable
def process_chat_request(query: str):
    """Process a chat request with LangSmith tracing"""
    # Determine if query is personal
    is_personal = utils.is_personal_query(query)
    query_type = "personal" if is_personal else "general"

    # Get relevant context if personal query
    context = ""
    if is_personal:
        context = utils.get_relevant_context(query)

    # Generate response
    result = utils.generate_response(query)

    return {
        "response": result["response"],
        "type": query_type,
        "context_used": context[:200] + "..." if context else "None",
    }
@app.post("/chat", response_model=ChatResponse)  # route path assumed
async def chat_endpoint(request: ChatRequest):
    """Endpoint to handle chat requests"""
    start_time = time.time()

    # Log the incoming request
    logger.info(f"Received query: {request.query}")

    if not request.query or request.query.strip() == "":
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    # Generate response with or without LangSmith tracing
    if langsmith_tracing:
        try:
            trace_result = process_chat_request(request.query)
            result = {"response": trace_result["response"], "type": trace_result["type"]}
        except Exception as e:
            logger.error(f"Error in LangSmith traced processing: {e}")
            # Fall back to non-traced processing
            result = utils.generate_response(request.query)
    else:
        # Standard processing without tracing
        result = utils.generate_response(request.query)

    # Calculate processing time
    processing_time = time.time() - start_time

    # Log the response
    logger.info(f"Response type: {result['type']}, Processing time: {processing_time:.2f}s")

    return ChatResponse(
        response=result["response"],
        type=result["type"],
        processing_time=processing_time,
        format="markdown",  # Always return markdown format
    )
@app.get("/health")  # route path assumed
async def health_check():
    """Health check endpoint"""
    return {"status": "ok", "embeddings_loaded": utils.embeddings is not None}
# Add this below your existing health check endpoint
@app.head("/health")  # route path assumed
async def head_request():
    """HEAD request endpoint to check if the server is running"""
    # A HEAD response returns only headers, no body content
    return {}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
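
# Example request once the server is running locally (the /chat path matches the
# route assumed for chat_endpoint above; the query text is illustrative):
#
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What projects has Swayam worked on?"}'
#
# The body of the reply follows the ChatResponse model, e.g.:
#   {"response": "...", "type": "personal", "processing_time": 0.87, "format": "markdown"}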