# app.py - Updated with better request handling
import os
import json
import asyncio
import logging
import uuid
import re
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from fastapi import FastAPI, HTTPException, Request, BackgroundTasks
from pydantic import BaseModel
from llama_cpp import Llama

# Correctly reference the module within the 'app' package
from app.policy_vector_db import PolicyVectorDB, ensure_db_populated
# -----------------------------
# Logging Configuration
# -----------------------------
# The request ID is injected into the message text by RequestIdAdapter below,
# so the base format must not reference %(request_id)s (plain logger calls
# without that extra field would otherwise fail to format).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')


class RequestIdAdapter(logging.LoggerAdapter):
    def process(self, msg, kwargs):
        return '[%s] %s' % (self.extra['request_id'], msg), kwargs


logger = logging.getLogger("app")
# -----------------------------
# Queue Management Classes
# -----------------------------
class QueuedRequest:
    def __init__(self, request_id: str, question: str, timestamp: datetime):
        self.request_id = request_id
        self.question = question
        self.timestamp = timestamp
        self.status = "queued"  # queued, processing, completed, failed, timeout, cancelled
        self.result: Optional[Dict] = None
        self.error: Optional[str] = None
        self.cancelled = False  # Track if request was cancelled
        self.last_accessed = datetime.now()  # Track when status was last checked


class RequestQueue:
    def __init__(self, max_size: int = 15):
        self.queue: List[QueuedRequest] = []
        self.processing: Optional[QueuedRequest] = None
        self.completed_requests: Dict[str, QueuedRequest] = {}
        self.max_size = max_size
        self.lock = asyncio.Lock()
        self.cleanup_interval = 300  # 5 minutes
        self.max_completed_age = 600  # 10 minutes
    async def add_request(self, request_id: str, question: str) -> Dict:
        async with self.lock:
            # Clean up old requests periodically
            await self._cleanup_old_requests()
            if len(self.queue) >= self.max_size:
                return {
                    "status": "queue_full",
                    "message": f"Queue is full (max {self.max_size} requests). Please try again later.",
                    "queue_position": None,
                    "estimated_wait_time": None
                }
            queued_request = QueuedRequest(request_id, question, datetime.now())
            # Always enqueue; the background worker is the single executor
            self.queue.append(queued_request)
            position = len(self.queue)  # 1-based position in queue
            # Compute an ETA that reflects whether something is currently processing:
            # ahead_of_you = (1 if a job is currently processing else 0) + (position - 1 already queued ahead)
            ahead_of_you = (1 if self.processing else 0) + (position - 1)
            # Assume each slot takes roughly 2 minutes (rough heuristic for the free CPU tier)
            estimated_wait = f"{ahead_of_you * 2}-{(ahead_of_you + 1) * 2} minutes"
            message = (
                "Using free CPU tier - can only process one request at a time. "
                f"Your request is #{position} in queue and will be processed after current requests are completed."
            )
            return {
                "status": "queued",
                "message": message,
                "queue_position": position,
                "estimated_wait_time": estimated_wait
            }
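
    # Illustrative ETA calculation (hypothetical values): if one request is
    # currently processing and the new request lands at queue position 3, then
    # ahead_of_you = 1 + 2 = 3 and estimated_wait = "6-8 minutes".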
    async def get_next_request(self) -> Optional[QueuedRequest]:
        async with self.lock:
            if self.queue:
                next_request = self.queue.pop(0)
                self.processing = next_request
                next_request.status = "processing"
                return next_request
            return None

    async def complete_request(self, request_id: str, result: Dict = None, error: str = None):
        async with self.lock:
            if self.processing and self.processing.request_id == request_id:
                if self.processing.cancelled:
                    # Don't store results for cancelled requests
                    self.processing.status = "cancelled"
                    logger.info(f"Request {request_id} was cancelled, not storing result")
                elif result:
                    self.processing.result = result
                    self.processing.status = "completed"
                elif error:
                    self.processing.error = error
                    self.processing.status = "failed"
                # Store completed request for result retrieval (even cancelled ones briefly)
                self.completed_requests[request_id] = self.processing
                self.processing = None

    async def cancel_request(self, request_id: str) -> bool:
        """Cancel a request if it exists in queue or is processing"""
        async with self.lock:
            # Check if it's currently processing
            if self.processing and self.processing.request_id == request_id:
                self.processing.cancelled = True
                logger.info(f"Marked processing request {request_id} as cancelled")
                return True
            # Check if it's in queue
            for i, req in enumerate(self.queue):
                if req.request_id == request_id:
                    cancelled_req = self.queue.pop(i)
                    cancelled_req.status = "cancelled"
                    cancelled_req.cancelled = True
                    self.completed_requests[request_id] = cancelled_req
                    logger.info(f"Cancelled queued request {request_id}")
                    return True
            return False
    async def get_request_status(self, request_id: str) -> Optional[Dict]:
        async with self.lock:
            # Update last accessed time for any request we're checking
            current_time = datetime.now()
            # Check if currently processing
            if self.processing and self.processing.request_id == request_id:
                self.processing.last_accessed = current_time
                if self.processing.cancelled:
                    return {
                        "status": "cancelled",
                        "message": "Request was cancelled.",
                        "result": None,
                        "error": "Request cancelled by user"
                    }
                return {
                    "status": self.processing.status,
                    "message": "Your request is currently being processed.",
                    "result": self.processing.result
                }
            # Check completed requests
            if request_id in self.completed_requests:
                req = self.completed_requests[request_id]
                req.last_accessed = current_time
                status_messages = {
                    "completed": "Request completed.",
                    "failed": "Request failed.",
                    "cancelled": "Request was cancelled.",
                    "timeout": "Request timed out."
                }
                return {
                    "status": req.status,
                    "message": status_messages.get(req.status, "Request processed."),
                    "result": req.result,
                    "error": req.error
                }
            # Check queue
            for i, req in enumerate(self.queue):
                if req.request_id == request_id:
                    req.last_accessed = current_time
                    return {
                        "status": "queued",
                        "message": f"Your request is #{i+1} in queue.",
                        "queue_position": i + 1,
                        "estimated_wait_time": f"{(i+1) * 2}-{(i+2) * 2} minutes"
                    }
            return None
    async def _cleanup_old_requests(self):
        """Clean up old completed requests and abandoned requests"""
        current_time = datetime.now()
        cutoff_time = current_time - timedelta(seconds=self.max_completed_age)
        # Clean up old completed requests
        to_remove = []
        for request_id, req in self.completed_requests.items():
            if req.last_accessed < cutoff_time:
                to_remove.append(request_id)
        for request_id in to_remove:
            del self.completed_requests[request_id]
            logger.info(f"Cleaned up old request: {request_id}")

    async def get_queue_info(self) -> Dict:
        async with self.lock:
            return {
                "queue_length": len(self.queue),
                "currently_processing": self.processing.request_id if self.processing else None,
                "max_queue_size": self.max_size,
                "completed_requests_count": len(self.completed_requests)
            }
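
# Illustrative lifecycle of one request through RequestQueue (sketch with
# hypothetical IDs and values, not executed by the app):
#   info = await request_queue.add_request("abc-123", "Who approves works contracts?")
#   nxt = await request_queue.get_next_request()               # worker marks it "processing"
#   await request_queue.complete_request(nxt.request_id, result={...})
#   status = await request_queue.get_request_status("abc-123") # -> {"status": "completed", ...}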
# -----------------------------
# Configuration
# -----------------------------
DB_PERSIST_DIRECTORY = os.getenv("DB_PERSIST_DIRECTORY", "/app/vector_database")
CHUNKS_FILE_PATH = os.getenv("CHUNKS_FILE_PATH", "/app/granular_chunks_final.jsonl")
MODEL_PATH = os.getenv("MODEL_PATH", "/app/tinyllama_dop_q4_k_m.gguf")
LLM_TIMEOUT_SECONDS = int(os.getenv("LLM_TIMEOUT_SECONDS", "120"))
RELEVANCE_THRESHOLD = float(os.getenv("RELEVANCE_THRESHOLD", "0.3"))
TOP_K_SEARCH = int(os.getenv("TOP_K_SEARCH", "4"))
TOP_K_CONTEXT = int(os.getenv("TOP_K_CONTEXT", "2"))
MAX_QUEUE_SIZE = int(os.getenv("MAX_QUEUE_SIZE", "15"))

# -----------------------------
# Initialize FastAPI App
# -----------------------------
app = FastAPI(title="NEEPCO DoP RAG Chatbot", version="2.1.0")

# Initialize request queue
request_queue = RequestQueue(max_size=MAX_QUEUE_SIZE)
# Register as HTTP middleware (decorator assumed; the original snippet omits it)
# so every request gets an ID and request.state.request_id is set for endpoints.
@app.middleware("http")
async def add_request_id(request: Request, call_next):
    request_id = str(uuid.uuid4())
    request.state.request_id = request_id
    response = await call_next(request)
    response.headers["X-Request-ID"] = request_id
    return response
# -----------------------------
# Vector DB and Data Initialization
# -----------------------------
logger.info("Initializing vector DB...")
try:
    db = PolicyVectorDB(
        persist_directory=DB_PERSIST_DIRECTORY,
        top_k_default=TOP_K_SEARCH,
        relevance_threshold=RELEVANCE_THRESHOLD
    )
    if not ensure_db_populated(db, CHUNKS_FILE_PATH):
        logger.warning("DB not populated on startup. RAG will not function correctly.")
        db_ready = False
    else:
        logger.info("Vector DB is populated and ready.")
        db_ready = True
except Exception as e:
    logger.error(f"FATAL: Failed to initialize Vector DB: {e}", exc_info=True)
    db = None
    db_ready = False

# -----------------------------
# Load TinyLlama GGUF Model
# -----------------------------
logger.info(f"Loading GGUF model from: {MODEL_PATH}")
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=2048,
        n_threads=1,
        n_batch=512,
        use_mlock=True,
        verbose=False
    )
    logger.info("GGUF model loaded successfully.")
    model_ready = True
except Exception as e:
    logger.error(f"FATAL: Failed to load GGUF model: {e}", exc_info=True)
    llm = None
    model_ready = False
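
# Rough token budget (assuming the values above): with n_ctx=2048 and
# max_tokens=512 at generation time, the rendered prompt (system text, the top
# TOP_K_CONTEXT chunks and the question) should stay below roughly 1,500 tokens,
# or the context window will cut into the generation budget.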
# -----------------------------
# API Schemas
# -----------------------------
class Query(BaseModel):
    question: str


class Feedback(BaseModel):
    request_id: str
    question: str
    answer: str
    context_used: str
    feedback: str
    comment: str | None = None
# -----------------------------
# Background Processing
# -----------------------------
async def process_queued_requests():
    """Background task to process queued requests"""
    while True:
        try:
            next_request = await request_queue.get_next_request()
            if next_request:
                logger.info(f"Processing queued request: {next_request.request_id}")
                try:
                    # Check if request was cancelled before processing
                    if next_request.cancelled:
                        logger.info(f"Skipping cancelled request: {next_request.request_id}")
                        await request_queue.complete_request(
                            next_request.request_id,
                            error="Request was cancelled"
                        )
                        continue
                    result = await process_chat_request(next_request.question, next_request.request_id)
                    # Check again if request was cancelled during processing
                    if next_request.cancelled:
                        logger.info(f"Request was cancelled during processing: {next_request.request_id}")
                        await request_queue.complete_request(
                            next_request.request_id,
                            error="Request was cancelled during processing"
                        )
                    else:
                        await request_queue.complete_request(next_request.request_id, result=result)
                        logger.info(f"Completed request: {next_request.request_id}")
                except Exception as e:
                    error_msg = f"Error processing request: {str(e)}"
                    logger.error(f"Failed to process request {next_request.request_id}: {e}", exc_info=True)
                    await request_queue.complete_request(next_request.request_id, error=error_msg)
            else:
                # No requests in queue, wait a bit before checking again
                await asyncio.sleep(2)
        except Exception as e:
            logger.error(f"Error in background processor: {e}", exc_info=True)
            await asyncio.sleep(5)
# Start background processor on application startup (the @app.on_event hook is
# assumed; without it the worker coroutine would never be scheduled).
@app.on_event("startup")
async def startup_event():
    asyncio.create_task(process_queued_requests())
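
# Because a single worker task drains the queue and the Llama instance runs with
# n_threads=1, at most one generation is in flight at any time; this matches the
# "can only process one request at a time" message returned to clients.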
# -----------------------------
# Core Processing Function
# -----------------------------
# Re-ranking function for improving relevance
def re_rank_by_relevance(results: List[Dict], question: str) -> List[Dict]:
    """Simple heuristic re-ranking based on question keyword overlap"""
    question_terms = set(term.lower() for term in question.split() if len(term) > 3)
    for result in results:
        chunk_terms = set(term.lower() for term in result['text'].split() if len(term) > 3)
        if question_terms:
            keyword_overlap = len(question_terms & chunk_terms) / len(question_terms)
        else:
            keyword_overlap = 0
        # Boost score if chunk contains question keywords
        result['relevance_score'] *= (1 + 0.15 * keyword_overlap)
    return sorted(results, key=lambda x: x['relevance_score'], reverse=True)
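
# Worked example of the boost above (hypothetical numbers): if the question
# yields 4 terms longer than 3 characters and a chunk shares 2 of them,
# keyword_overlap = 0.5 and that chunk's relevance_score is multiplied by
# 1 + 0.15 * 0.5 = 1.075.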
def get_logger_adapter(request_id: str):
    return RequestIdAdapter(logger, {'request_id': request_id})


async def generate_llm_response(prompt: str, request_id: str):
    loop = asyncio.get_running_loop()
    # llama_cpp's call is blocking and CPU-bound; run it in the default thread
    # pool so the event loop stays responsive for status/cancel requests.
    response = await loop.run_in_executor(
        None,
        lambda: llm(
            prompt,
            max_tokens=512,  # Optimized for CPU performance
            stop=["###", "Question:", "Context:", "</s>"],
            temperature=0.1,  # Lower for factuality
            top_p=0.9,  # Nucleus sampling for consistency
            echo=False
        )
    )
    answer = response["choices"][0]["text"].strip()
    if not answer:
        raise ValueError("Empty response from LLM")
    return answer
async def process_chat_request(question: str, request_id: str) -> Dict:
    """Core chat processing logic extracted for reuse"""
    adapter = get_logger_adapter(request_id)
    question_lower = question.strip().lower()

    # --- GREETING & INTRO HANDLING ---
    greeting_keywords = ["hello", "hi", "hey", "what can you do", "who are you"]
    if question_lower in greeting_keywords:
        adapter.info(f"Handling a greeting or introductory query: '{question}'")
        intro_message = (
            "Hello! I am an AI assistant specifically trained on NEEPCO's Delegation of Powers (DoP) policy document. "
            "My purpose is to help you find accurate information and answer questions based on this specific dataset. "
            "I am currently running on a CPU-based environment. How can I assist you with the DoP policy today?"
        )
        return {
            "request_id": request_id,
            "question": question,
            "context_used": "NA - Greeting",
            "answer": intro_message
        }

    if not db_ready or not model_ready:
        adapter.error("Service unavailable due to initialization failure.")
        raise HTTPException(status_code=503, detail="Service is not ready. Please check logs.")

    adapter.info(f"Received query: '{question}'")

    # 1. Search Vector DB with query expansion
    search_results = db.search(question, top_k=TOP_K_SEARCH)
    # 2. Re-rank results by keyword overlap for better relevance
    search_results = re_rank_by_relevance(search_results, question)
    if not search_results:
        adapter.warning("No relevant context found in vector DB.")
        return {
            "request_id": request_id,
            "question": question,
            "context_used": "No relevant context found.",
            "answer": "Sorry, I could not find a relevant policy to answer that question. Please try rephrasing."
        }

    scores = [f"{result['relevance_score']:.4f}" for result in search_results]
    adapter.info(f"Found {len(search_results)} relevant chunks with scores: {scores}")

    # 3. Prepare Context
    context_chunks = [result['text'] for result in search_results[:TOP_K_CONTEXT]]
    context = "\n---\n".join(context_chunks)
    # 4. Build Enhanced Prompt
    prompt = f"""<|system|>
You are NEEPCO's Delegation of Powers (DoP) policy expert. Answer ONLY using the provided context.
- Be concise and factual
- For lists/steps, use pipe separators: `Item1|Item2|Item3`
- If information is absent, say: "The provided policy context does not contain information on this topic."
- Do not assume or infer beyond what is stated
</s>
<|user|>
### Context:
{context}
### Question:
{question}
Answer based strictly on the context above.
</s>
<|assistant|>
"""
    # 5. Generate Response
    answer = "An error occurred while processing your request."
    try:
        adapter.info("Sending prompt to LLM for generation...")
        raw_answer = await asyncio.wait_for(
            generate_llm_response(prompt, request_id),
            timeout=LLM_TIMEOUT_SECONDS
        )
        adapter.info(f"LLM generation successful. Raw response: {raw_answer[:250]}...")

        # --- POST-PROCESSING LOGIC ---
        # Check if the model used the pipe separator, indicating a list
        # (e.g. "Item1|Item2" is rendered as "* Item1" and "* Item2" on separate lines).
        if '|' in raw_answer:
            adapter.info("Pipe separator found. Formatting response as a bulleted list.")
            # Split the string into a list of items
            items = raw_answer.split('|')
            # Clean up each item and format it as a bullet point
            cleaned_items = [f"* {item.strip()}" for item in items if item.strip()]
            # Join them back together with newlines
            answer = "\n".join(cleaned_items)
        else:
            # If no separator, use the answer as is.
            answer = raw_answer
    except asyncio.TimeoutError:
        adapter.warning(f"LLM generation timed out after {LLM_TIMEOUT_SECONDS} seconds.")
        answer = "Sorry, the request took too long to process. Please try again with a simpler question."
    except Exception as e:
        adapter.error(f"An unexpected error occurred during LLM generation: {e}", exc_info=True)
        answer = "Sorry, an unexpected error occurred while generating a response."

    adapter.info("Final answer prepared. Returning result.")
    return {
        "request_id": request_id,
        "question": question,
        "context_used": context,
        "answer": answer
    }
# -----------------------------
# Endpoints
# -----------------------------
# Note: the route decorators and paths below are assumed; the original snippet
# omits them, but they are required for FastAPI to expose these handlers.
@app.get("/")
async def root():
    return {"status": "Server is running."}


@app.get("/health")
async def health_check():
    queue_info = await request_queue.get_queue_info()
    status = {
        "status": "ok",
        "database_status": "ready" if db_ready else "error",
        "model_status": "ready" if model_ready else "error",
        "queue_info": queue_info
    }
    if not db_ready or not model_ready:
        raise HTTPException(status_code=503, detail=status)
    return status
@app.post("/chat")  # path assumed
async def chat(query: Query, request: Request):
    """Add request to queue and return queue status"""
    if not db_ready or not model_ready:
        raise HTTPException(status_code=503, detail="Service is not ready. Please check logs.")
    request_id = request.state.request_id
    adapter = get_logger_adapter(request_id)
    adapter.info(f"Received chat request: '{query.question}'")
    # Add request to queue
    queue_status = await request_queue.add_request(request_id, query.question)
    return {
        "request_id": request_id,
        "question": query.question,
        **queue_status
    }
@app.get("/status/{request_id}")  # path assumed
async def get_request_status(request_id: str):
    """Check the status of a specific request"""
    try:
        status = await request_queue.get_request_status(request_id)
        if not status:
            raise HTTPException(status_code=404, detail="Request not found")
        return {
            "request_id": request_id,
            **status
        }
    except HTTPException:
        # Re-raise so the 404 above is not converted into a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error checking status for {request_id}: {e}")
        raise HTTPException(status_code=500, detail="Error checking request status")
@app.post("/cancel/{request_id}")  # path assumed
async def cancel_request(request_id: str):
    """Cancel a specific request"""
    try:
        cancelled = await request_queue.cancel_request(request_id)
        if not cancelled:
            raise HTTPException(status_code=404, detail="Request not found or cannot be cancelled")
        return {
            "status": "cancelled",
            "message": f"Request {request_id} has been cancelled",
            "request_id": request_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error cancelling request {request_id}: {e}")
        raise HTTPException(status_code=500, detail="Error cancelling request")
@app.get("/queue")  # path assumed
async def get_queue_status():
    """Get current queue information"""
    return await request_queue.get_queue_info()
@app.post("/feedback")  # path assumed
async def collect_feedback(feedback: Feedback, request: Request):
    adapter = get_logger_adapter(request.state.request_id)
    feedback_log = {
        "type": "USER_FEEDBACK",
        "request_id": feedback.request_id,
        "question": feedback.question,
        "answer": feedback.answer,
        "context_used": feedback.context_used,
        "feedback": feedback.feedback,
        "comment": feedback.comment
    }
    adapter.info(json.dumps(feedback_log))
    return {"status": "Feedback recorded. Thank you!"}