ChatbotDemo / app /app.py
Kalpokoch's picture
improvements dec
0194a83
raw
history blame
23.5 kB
# app.py - Updated with better request handling
import os
import json
import asyncio
import logging
import uuid
import re
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from fastapi import FastAPI, HTTPException, Request, BackgroundTasks
from pydantic import BaseModel
from llama_cpp import Llama
# Correctly reference the module within the 'app' package
from app.policy_vector_db import PolicyVectorDB, ensure_db_populated
# -----------------------------
# ✅ Logging Configuration
# -----------------------------
# Layout of every log line. The request id is deliberately NOT a format field:
# records emitted through the plain module logger carry no ``request_id``
# attribute, and RequestIdAdapter writes the id into the message text rather
# than onto the record. A ``%(request_id)s`` field therefore made the formatter
# fail for every record, so nothing was ever logged.
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
class RequestIdAdapter(logging.LoggerAdapter):
    """Logger adapter that tags each message with the request id held in ``extra``."""

    def process(self, msg, kwargs):
        """Return ``msg`` prefixed with ``[<request_id>]``.

        ``kwargs`` is handed back unchanged, so ``extra`` is not attached to
        the resulting log record; the id travels in the message text only.
        """
        request_id = self.extra['request_id']
        return f'[{request_id!s}] {msg!s}', kwargs
# Module-wide logger. Request-scoped code wraps it in RequestIdAdapter
# (see get_logger_adapter) so messages carry the request id.
logger = logging.getLogger("app")
# -----------------------------
# ✅ Queue Management Classes
# -----------------------------
class QueuedRequest:
    """Mutable record describing one chat question as it moves through the queue."""

    def __init__(self, request_id: str, question: str, timestamp: datetime):
        """Create a request in the ``queued`` state.

        Args:
            request_id: Identifier used to look the request up later.
            question: The user's question text.
            timestamp: Time the request was enqueued (supplied by the caller).
        """
        # Identity and payload, exactly as supplied by the caller.
        self.request_id, self.question, self.timestamp = request_id, question, timestamp

        # Lifecycle state: queued, processing, completed, failed, timeout, cancelled
        self.status = "queued"
        # Set when a cancellation was requested for this request.
        self.cancelled = False

        # Outcome; both stay None until the request finishes.
        self.result: Optional[Dict] = None
        self.error: Optional[str] = None

        # Time the request's status was last looked at; starts at creation time.
        self.last_accessed = datetime.now()
class RequestQueue:
    """In-memory FIFO of chat requests with a single "processing" slot.

    A request lives in exactly one of three places: ``queue`` (waiting),
    ``processing`` (the one request currently being worked on) or
    ``completed_requests`` (finished, failed or cancelled, kept for later
    status look-ups). All public coroutines take ``lock`` before touching
    that state.
    """

    def __init__(self, max_size: int = 15):
        """Create an empty queue that holds at most ``max_size`` waiting requests."""
        self.queue: List["QueuedRequest"] = []
        self.processing: Optional["QueuedRequest"] = None
        self.completed_requests: Dict[str, "QueuedRequest"] = {}
        self.max_size = max_size
        self.lock = asyncio.Lock()
        # Stored for configuration purposes; no method of this class reads it.
        self.cleanup_interval = 300  # 5 minutes
        # Finished requests not looked at for this many seconds are dropped.
        self.max_completed_age = 600  # 10 minutes

    def _estimate_wait(self, queue_index: int) -> str:
        """Return the wait-time range for the request at 0-based ``queue_index``.

        Requests ahead of it are the ones queued before it plus the one being
        processed, if any. Each is budgeted at roughly two minutes. Sharing
        this helper keeps the estimate given at enqueue time and the one given
        by later status checks consistent.
        """
        ahead_of_you = (1 if self.processing else 0) + queue_index
        return f"{ahead_of_you * 2}-{(ahead_of_you + 1) * 2} minutes"

    async def add_request(self, request_id: str, question: str) -> Dict:
        """Enqueue a question and report its queue position.

        Returns:
            A dict with ``status`` (``"queued"`` or ``"queue_full"``),
            ``message``, ``queue_position`` and ``estimated_wait_time``; the
            last two are ``None`` when the queue is full.
        """
        async with self.lock:
            # Stale finished requests are purged on every enqueue attempt.
            await self._cleanup_old_requests()
            if len(self.queue) >= self.max_size:
                return {
                    "status": "queue_full",
                    "message": f"Queue is full (max {self.max_size} requests). Please try again later.",
                    "queue_position": None,
                    "estimated_wait_time": None
                }
            # Always enqueue; the background worker is the single executor.
            self.queue.append(QueuedRequest(request_id, question, datetime.now()))
            position = len(self.queue)  # 1-based position in queue
            message = (
                "Using free CPU tier - can only process one request at a time. "
                f"Your request is #{position} in queue and will be processed after current requests are completed."
            )
            return {
                "status": "queued",
                "message": message,
                "queue_position": position,
                "estimated_wait_time": self._estimate_wait(position - 1)
            }

    async def get_next_request(self) -> "Optional[QueuedRequest]":
        """Pop the oldest waiting request, mark it ``processing`` and return it.

        Returns ``None`` when nothing is waiting.
        """
        async with self.lock:
            if self.queue:
                next_request = self.queue.pop(0)
                self.processing = next_request
                next_request.status = "processing"
                return next_request
            return None

    async def complete_request(self, request_id: str, result: Optional[Dict] = None,
                               error: Optional[str] = None):
        """Record the outcome of the request currently being processed.

        Does nothing unless ``request_id`` matches the processing slot. A
        cancelled request is stored as ``cancelled`` and its result discarded.
        Otherwise ``result`` wins over ``error``; both are compared against
        ``None`` so that an empty dict or empty string still finalises the
        request instead of leaving it stuck in ``processing``. When neither is
        given the request is stored as ``failed``.
        """
        async with self.lock:
            if not (self.processing and self.processing.request_id == request_id):
                return
            finished = self.processing
            if finished.cancelled:
                # Don't store results for cancelled requests
                finished.status = "cancelled"
                logger.info(f"Request {request_id} was cancelled, not storing result")
            elif result is not None:
                finished.result = result
                finished.status = "completed"
            elif error is not None:
                finished.error = error
                finished.status = "failed"
            else:
                # Neither outcome supplied: never leave a finished request
                # reporting "processing" to clients that poll its status.
                finished.error = "Request finished without a result"
                finished.status = "failed"
            # Store completed request for result retrieval (even cancelled ones briefly)
            self.completed_requests[request_id] = finished
            self.processing = None

    async def cancel_request(self, request_id: str) -> bool:
        """Cancel a request if it exists in queue or is processing.

        A processing request is only flagged (``cancelled = True``); the work
        already in flight is not interrupted here. A waiting request is
        removed from the queue and stored as ``cancelled``.

        Returns:
            ``True`` if the request was found, ``False`` otherwise.
        """
        async with self.lock:
            # Check if it's currently processing
            if self.processing and self.processing.request_id == request_id:
                self.processing.cancelled = True
                logger.info(f"Marked processing request {request_id} as cancelled")
                return True
            # Check if it's in queue
            for i, req in enumerate(self.queue):
                if req.request_id == request_id:
                    cancelled_req = self.queue.pop(i)
                    cancelled_req.status = "cancelled"
                    cancelled_req.cancelled = True
                    self.completed_requests[request_id] = cancelled_req
                    logger.info(f"Cancelled queued request {request_id}")
                    return True
            return False

    async def get_request_status(self, request_id: str) -> Optional[Dict]:
        """Describe where ``request_id`` currently is, or return ``None`` if unknown.

        Every successful look-up refreshes the request's ``last_accessed``
        time, which postpones its removal by ``_cleanup_old_requests``.
        """
        async with self.lock:
            current_time = datetime.now()
            # Check if currently processing
            if self.processing and self.processing.request_id == request_id:
                self.processing.last_accessed = current_time
                if self.processing.cancelled:
                    return {
                        "status": "cancelled",
                        "message": "Request was cancelled.",
                        "result": None,
                        "error": "Request cancelled by user"
                    }
                return {
                    "status": self.processing.status,
                    "message": "Your request is currently being processed.",
                    "result": self.processing.result
                }
            # Check completed requests
            if request_id in self.completed_requests:
                req = self.completed_requests[request_id]
                req.last_accessed = current_time
                status_messages = {
                    "completed": "Request completed.",
                    "failed": "Request failed.",
                    "cancelled": "Request was cancelled.",
                    "timeout": "Request timed out."
                }
                return {
                    "status": req.status,
                    "message": status_messages.get(req.status, "Request processed."),
                    "result": req.result,
                    "error": req.error
                }
            # Check queue
            for i, req in enumerate(self.queue):
                if req.request_id == request_id:
                    req.last_accessed = current_time
                    return {
                        "status": "queued",
                        "message": f"Your request is #{i+1} in queue.",
                        "queue_position": i + 1,
                        "estimated_wait_time": self._estimate_wait(i)
                    }
            return None

    async def _cleanup_old_requests(self):
        """Drop finished requests whose status has not been checked recently.

        Must be called with ``lock`` already held. Only ``completed_requests``
        is pruned; waiting and processing requests are left alone.
        """
        cutoff_time = datetime.now() - timedelta(seconds=self.max_completed_age)
        # Collect ids first: the dict must not change size while iterated.
        to_remove = [
            request_id
            for request_id, req in self.completed_requests.items()
            if req.last_accessed < cutoff_time
        ]
        for request_id in to_remove:
            del self.completed_requests[request_id]
            logger.info(f"Cleaned up old request: {request_id}")

    async def get_queue_info(self) -> Dict:
        """Return a snapshot of queue length, processing id, capacity and stored results."""
        async with self.lock:
            return {
                "queue_length": len(self.queue),
                "currently_processing": self.processing.request_id if self.processing else None,
                "max_queue_size": self.max_size,
                "completed_requests_count": len(self.completed_requests)
            }
# -----------------------------
# ✅ Configuration
# -----------------------------
# Runtime settings. Each value comes from the environment variable of the same
# name when it is set, otherwise from the default written here.

# Filesystem locations: vector store directory, chunk file, GGUF model file.
DB_PERSIST_DIRECTORY = os.environ.get("DB_PERSIST_DIRECTORY", "/app/vector_database")
CHUNKS_FILE_PATH = os.environ.get("CHUNKS_FILE_PATH", "/app/granular_chunks_final.jsonl")
MODEL_PATH = os.environ.get("MODEL_PATH", "/app/tinyllama_dop_q4_k_m.gguf")

# Upper bound, in seconds, on a single LLM generation.
LLM_TIMEOUT_SECONDS = int(os.environ.get("LLM_TIMEOUT_SECONDS", "120"))

# Retrieval tuning: threshold handed to the vector DB, number of chunks
# searched, and number of chunks placed into the prompt.
RELEVANCE_THRESHOLD = float(os.environ.get("RELEVANCE_THRESHOLD", "0.3"))
TOP_K_SEARCH = int(os.environ.get("TOP_K_SEARCH", "4"))
TOP_K_CONTEXT = int(os.environ.get("TOP_K_CONTEXT", "2"))

# Maximum number of requests allowed to wait in the queue.
MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", "15"))
# -----------------------------
# ✅ Initialize FastAPI App
# -----------------------------
# ASGI application object; the middleware and route decorators in this module
# register themselves on it.
app = FastAPI(title="NEEPCO DoP RAG Chatbot", version="2.1.0")
# Initialize request queue
# One shared queue for the whole process, capped at MAX_QUEUE_SIZE waiting requests.
request_queue = RequestQueue(max_size=MAX_QUEUE_SIZE)
@app.middleware("http")
async def add_request_id(request: Request, call_next):
    """Give every HTTP request a fresh UUID4 and echo it in ``X-Request-ID``."""
    rid = str(uuid.uuid4())
    # Kept on request.state so route handlers can read the same id.
    request.state.request_id = rid

    outgoing = await call_next(request)
    outgoing.headers["X-Request-ID"] = rid
    return outgoing
# -----------------------------
# ✅ Vector DB and Data Initialization
# -----------------------------
# Build the vector store at import time. On any failure the module still
# imports; `db` is None and `db_ready` is False so endpoints can answer 503.
logger.info("Initializing vector DB...")
try:
    db = PolicyVectorDB(
        persist_directory=DB_PERSIST_DIRECTORY,
        top_k_default=TOP_K_SEARCH,
        relevance_threshold=RELEVANCE_THRESHOLD,
    )
    # ensure_db_populated reports whether the store holds data after the call.
    db_ready = bool(ensure_db_populated(db, CHUNKS_FILE_PATH))
    if db_ready:
        logger.info("Vector DB is populated and ready.")
    else:
        logger.warning("DB not populated on startup. RAG will not function correctly.")
except Exception as e:
    # logger.exception logs at ERROR level with the traceback attached.
    logger.exception(f"FATAL: Failed to initialize Vector DB: {e}")
    db = None
    db_ready = False
# -----------------------------
# ✅ Load TinyLlama GGUF Model
# -----------------------------
# Load the GGUF model at import time. On failure the module still imports;
# `llm` is None and `model_ready` is False so endpoints can answer 503.
logger.info(f"Loading GGUF model from: {MODEL_PATH}")
try:
    llama_options = dict(
        model_path=MODEL_PATH,
        n_ctx=2048,
        n_threads=1,
        n_batch=512,
        use_mlock=True,
        verbose=False,
    )
    llm = Llama(**llama_options)
    logger.info("GGUF model loaded successfully.")
    model_ready = True
except Exception as e:
    # logger.exception logs at ERROR level with the traceback attached.
    logger.exception(f"FATAL: Failed to load GGUF model: {e}")
    llm = None
    model_ready = False
# -----------------------------
# ✅ API Schemas
# -----------------------------
# Request body of POST /chat.
class Query(BaseModel):
    # The user's question, queued verbatim for processing.
    question: str
# Request body of POST /feedback; every field is written to the log as-is.
class Feedback(BaseModel):
    # Id of the chat request the feedback refers to.
    request_id: str
    # Question, answer and context echoed back by the client.
    question: str
    answer: str
    context_used: str
    # The feedback value itself; its allowed values are not constrained here.
    feedback: str
    # Optional free-text remark.
    comment: str | None = None
# -----------------------------
# ✅ Background Processing
# -----------------------------
async def process_queued_requests():
    """Worker loop: take queued requests one at a time and answer them.

    Runs forever. An idle queue is polled every 2 seconds. A failure while
    answering one request is recorded on that request; a failure in the loop
    itself is logged and followed by a 5 second pause.
    """
    while True:
        try:
            job = await request_queue.get_next_request()
            if not job:
                # No requests in queue, wait a bit before checking again
                await asyncio.sleep(2)
                continue

            job_id = job.request_id
            logger.info(f"Processing queued request: {job_id}")
            try:
                # Check if request was cancelled before processing
                if job.cancelled:
                    logger.info(f"Skipping cancelled request: {job_id}")
                    await request_queue.complete_request(job_id, error="Request was cancelled")
                    continue

                result = await process_chat_request(job.question, job_id)

                # Check again: the flag may have been set while the answer was generated
                if not job.cancelled:
                    await request_queue.complete_request(job_id, result=result)
                    logger.info(f"Completed request: {job_id}")
                else:
                    logger.info(f"Request was cancelled during processing: {job_id}")
                    await request_queue.complete_request(
                        job_id,
                        error="Request was cancelled during processing"
                    )
            except Exception as exc:
                error_msg = f"Error processing request: {str(exc)}"
                logger.error(f"Failed to process request {job_id}: {exc}", exc_info=True)
                await request_queue.complete_request(job_id, error=error_msg)
        except Exception as exc:
            logger.error(f"Error in background processor: {exc}", exc_info=True)
            await asyncio.sleep(5)
# Start background processor
@app.on_event("startup")
async def startup_event():
    """Launch the queue worker when the application starts.

    The task object is stored on ``app.state`` on purpose: the event loop only
    keeps a weak reference to tasks, so a task nobody references can be
    garbage-collected before it finishes, which would silently stop all
    request processing.
    """
    app.state.queue_worker_task = asyncio.create_task(process_queued_requests())
# -----------------------------
# ✅ Core Processing Function
# ✅ Re-ranking function for improving relevance
def re_rank_by_relevance(results: List[Dict], question: str) -> List[Dict]:
    """Re-rank search hits by keyword overlap with the question.

    Each hit's ``relevance_score`` is multiplied by ``1 + 0.15 * overlap``,
    where ``overlap`` is the fraction of the question's keywords that also
    occur in the hit's ``text``. Keywords are lower-cased word tokens longer
    than three characters; punctuation is not part of a token, so "powers?"
    in the question matches "powers" in a chunk.

    Args:
        results: Hits, each a dict with at least ``text`` and ``relevance_score``.
        question: The user's question.

    Returns:
        A new list of new dicts sorted by boosted score, highest first. The
        input list and its dicts are left unmodified, so calling this twice
        on the same hits does not compound the boost.
    """
    def keywords(text: str) -> set:
        return {word for word in re.findall(r"\w+", text.lower()) if len(word) > 3}

    question_terms = keywords(question)
    ranked = []
    for result in results:
        if question_terms:
            keyword_overlap = len(question_terms & keywords(result['text'])) / len(question_terms)
        else:
            # No usable keywords in the question: leave scores as they are.
            keyword_overlap = 0
        # Boost score if chunk contains question keywords
        boosted = result['relevance_score'] * (1 + 0.15 * keyword_overlap)
        ranked.append({**result, 'relevance_score': boosted})
    ranked.sort(key=lambda hit: hit['relevance_score'], reverse=True)
    return ranked
def get_logger_adapter(request_id: str):
    """Wrap the module logger so every message is tagged with ``request_id``."""
    extra = {'request_id': request_id}
    return RequestIdAdapter(logger, extra)
async def generate_llm_response(prompt: str, request_id: str):
    """Run one completion on the global ``llm`` without blocking the event loop.

    The blocking model call is pushed to the loop's default executor.

    Args:
        prompt: Full prompt text handed to the model.
        request_id: Accepted for the caller's convenience; not used in the body.

    Returns:
        The generated text with surrounding whitespace removed.

    Raises:
        ValueError: If the model produced only whitespace or nothing at all.
    """
    def _run_inference():
        return llm(
            prompt,
            max_tokens=512,  # Optimized for CPU performance
            stop=["###", "Question:", "Context:", "</s>"],
            temperature=0.1,  # Lower for factuality
            top_p=0.9,  # Nucleus sampling for consistency
            echo=False,
        )

    completion = await asyncio.get_running_loop().run_in_executor(None, _run_inference)
    text = completion["choices"][0]["text"].strip()
    if text:
        return text
    raise ValueError("Empty response from LLM")
def _is_greeting(question: str) -> bool:
    """Tell whether the question is one of the known greeting/intro phrases.

    Case, surrounding whitespace and trailing punctuation are ignored, so
    "Who are you?" and "hello!" are recognised as well as the bare phrases.
    """
    greeting_keywords = ["hello", "hi", "hey", "what can you do", "who are you"]
    normalized = question.strip().lower().rstrip("!?., ")
    return normalized in greeting_keywords


def _build_prompt(context: str, question: str) -> str:
    """Render the chat prompt with system rules, retrieved context and the question."""
    return f"""<|system|>
You are NEEPCO's Delegation of Powers (DoP) policy expert. Answer ONLY using the provided context.
- Be concise and factual
- For lists/steps, use pipe separators: `Item1|Item2|Item3`
- If information is absent, say: "The provided policy context does not contain information on this topic."
- Do not assume or infer beyond what is stated
</s>
<|user|>
### Context:
{context}
### Question:
{question}
Answer based strictly on the context above.
</s>
<|assistant|>
"""


def _format_answer(raw_answer: str) -> Optional[str]:
    """Turn a pipe-separated answer into a bulleted list.

    Returns ``None`` when the answer contains no pipe, meaning it should be
    used unchanged. Empty items between pipes are dropped.
    """
    if '|' not in raw_answer:
        return None
    return "\n".join(f"* {item.strip()}" for item in raw_answer.split('|') if item.strip())


async def process_chat_request(question: str, request_id: str) -> Dict:
    """Answer one question: greeting shortcut, retrieval, prompt, generation.

    Args:
        question: The user's question.
        request_id: Id used to tag log lines and echoed in the result.

    Returns:
        A dict with ``request_id``, ``question``, ``context_used`` and
        ``answer``. LLM timeouts and generation errors do not raise; they are
        reported through an apologetic ``answer`` instead.

    Raises:
        HTTPException: 503 when the vector DB or the model failed to initialise.
    """
    adapter = get_logger_adapter(request_id)

    # --- GREETING & INTRO HANDLING ---
    # Answered without touching the vector DB or the model.
    if _is_greeting(question):
        adapter.info(f"Handling a greeting or introductory query: '{question}'")
        intro_message = (
            "Hello! I am an AI assistant specifically trained on NEEPCO's Delegation of Powers (DoP) policy document. "
            "My purpose is to help you find accurate information and answer questions based on this specific dataset. "
            "I am currently running on a CPU-based environment. How can I assist you with the DoP policy today?"
        )
        return {
            "request_id": request_id,
            "question": question,
            "context_used": "NA - Greeting",
            "answer": intro_message
        }

    if not db_ready or not model_ready:
        adapter.error("Service unavailable due to initialization failure.")
        raise HTTPException(status_code=503, detail="Service is not ready. Please check logs.")

    adapter.info(f"Received query: '{question}'")

    # 1. Search Vector DB
    search_results = db.search(question, top_k=TOP_K_SEARCH)
    # 2. Re-rank results by keyword overlap for better relevance
    search_results = re_rank_by_relevance(search_results, question)
    if not search_results:
        adapter.warning("No relevant context found in vector DB.")
        return {
            "request_id": request_id,
            "question": question,
            "context_used": "No relevant context found.",
            "answer": "Sorry, I could not find a relevant policy to answer that question. Please try rephrasing."
        }
    scores = [f"{result['relevance_score']:.4f}" for result in search_results]
    adapter.info(f"Found {len(search_results)} relevant chunks with scores: {scores}")

    # 3. Prepare Context: only the best TOP_K_CONTEXT chunks go into the prompt.
    context_chunks = [result['text'] for result in search_results[:TOP_K_CONTEXT]]
    context = "\n---\n".join(context_chunks)

    # 4. Build Enhanced Prompt
    prompt = _build_prompt(context, question)

    # 5. Generate Response
    try:
        adapter.info("Sending prompt to LLM for generation...")
        raw_answer = await asyncio.wait_for(
            generate_llm_response(prompt, request_id),
            timeout=LLM_TIMEOUT_SECONDS
        )
        adapter.info(f"LLM generation successful. Raw response: {raw_answer[:250]}...")
        # --- POST-PROCESSING LOGIC ---
        # The prompt asks for pipe separators in lists; render those as bullets.
        formatted = _format_answer(raw_answer)
        if formatted is not None:
            adapter.info("Pipe separator found. Formatting response as a bulleted list.")
            answer = formatted
        else:
            # If no separator, use the answer as is.
            answer = raw_answer
    except asyncio.TimeoutError:
        adapter.warning(f"LLM generation timed out after {LLM_TIMEOUT_SECONDS} seconds.")
        answer = "Sorry, the request took too long to process. Please try again with a simpler question."
    except Exception as e:
        adapter.error(f"An unexpected error occurred during LLM generation: {e}", exc_info=True)
        answer = "Sorry, an unexpected error occurred while generating a response."

    adapter.info("Final answer prepared. Returning result.")
    return {
        "request_id": request_id,
        "question": question,
        "context_used": context,
        "answer": answer
    }
# -----------------------------
# ✅ Endpoints
# -----------------------------
@app.get("/")
async def root():
    """Liveness probe: confirms the server process is answering requests."""
    # The check-mark emoji is written as the real character so clients do not
    # receive a mis-decoded byte sequence.
    return {"status": "✅ Server is running."}
@app.get("/health")
async def health_check():
    """Report readiness of the vector DB and the model plus queue statistics.

    Returns the status dict with HTTP 200 when both components are ready.

    Raises:
        HTTPException: 503 carrying the same dict as ``detail`` when either
            component failed to initialise. The top-level ``status`` is then
            ``"error"`` so the body no longer contradicts the status code.
    """
    queue_info = await request_queue.get_queue_info()
    healthy = db_ready and model_ready
    status = {
        "status": "ok" if healthy else "error",
        "database_status": "ready" if db_ready else "error",
        "model_status": "ready" if model_ready else "error",
        "queue_info": queue_info
    }
    if not healthy:
        raise HTTPException(status_code=503, detail=status)
    return status
@app.post("/chat")
async def chat(query: Query, request: Request):
    """Add request to queue and return queue status.

    The answer itself is not produced here; the response carries the request
    id together with the queue's report (status, message, position, ETA).

    Raises:
        HTTPException: 503 when the vector DB or the model is not ready.
    """
    if not (db_ready and model_ready):
        raise HTTPException(status_code=503, detail="Service is not ready. Please check logs.")

    # The id was assigned by the request-id middleware.
    request_id = request.state.request_id
    get_logger_adapter(request_id).info(f"Received chat request: '{query.question}'")

    # Add request to queue
    queue_status = await request_queue.add_request(request_id, query.question)

    response = {"request_id": request_id, "question": query.question}
    response.update(queue_status)
    return response
@app.get("/status/{request_id}")
async def get_request_status(request_id: str):
    """Check the status of a specific request.

    Raises:
        HTTPException: 404 when the id is unknown to the queue; 500 when the
            queue look-up itself fails.
    """
    # Only the queue look-up is guarded. The 404 is raised outside the ``try``
    # because HTTPException is an Exception subclass: raised inside, it was
    # caught by the generic handler and reported as a 500.
    try:
        status = await request_queue.get_request_status(request_id)
    except Exception as e:
        logger.error(f"Error checking status for {request_id}: {e}")
        raise HTTPException(status_code=500, detail="Error checking request status")
    if not status:
        raise HTTPException(status_code=404, detail="Request not found")
    return {
        "request_id": request_id,
        **status
    }
@app.delete("/cancel/{request_id}")
async def cancel_request(request_id: str):
    """Cancel a specific request.

    Raises:
        HTTPException: 404 when the queue does not know the id; 500 when the
            cancellation attempt itself fails.
    """
    try:
        was_cancelled = await request_queue.cancel_request(request_id)
    except HTTPException:
        # Pass HTTP errors through untouched.
        raise
    except Exception as e:
        logger.error(f"Error cancelling request {request_id}: {e}")
        raise HTTPException(status_code=500, detail="Error cancelling request")

    if not was_cancelled:
        raise HTTPException(status_code=404, detail="Request not found or cannot be cancelled")

    return {
        "status": "cancelled",
        "message": f"Request {request_id} has been cancelled",
        "request_id": request_id
    }
@app.get("/queue")
async def get_queue_status():
    """Get current queue information.

    The payload is the queue's own snapshot: waiting count, id being
    processed, capacity and number of stored finished requests.
    """
    snapshot = await request_queue.get_queue_info()
    return snapshot
@app.post("/feedback")
async def collect_feedback(feedback: Feedback, request: Request):
    """Record user feedback by writing it to the log as one JSON line.

    Nothing is stored besides the log entry. The entry is tagged with the id
    of this HTTP request, while the ``request_id`` inside the JSON is the one
    the client supplied for the chat request being rated.
    """
    adapter = get_logger_adapter(request.state.request_id)
    feedback_log = {
        "type": "USER_FEEDBACK",
        "request_id": feedback.request_id,
        "question": feedback.question,
        "answer": feedback.answer,
        "context_used": feedback.context_used,
        "feedback": feedback.feedback,
        "comment": feedback.comment
    }
    adapter.info(json.dumps(feedback_log))
    # The check-mark emoji is written as the real character so clients do not
    # receive a mis-decoded byte sequence.
    return {"status": "✅ Feedback recorded. Thank you!"}