"""
Hybrid Chat Endpoint: RAG + Scenario FSM
Routes between scripted scenarios and knowledge retrieval
"""
from fastapi import HTTPException
from datetime import datetime
# Import scenario handlers
from scenario_handlers.price_inquiry import PriceInquiryHandler
from scenario_handlers.event_recommendation import EventRecommendationHandler
from scenario_handlers.post_event_feedback import PostEventFeedbackHandler
from scenario_handlers.exit_intent_rescue import ExitIntentRescueHandler
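# Example wiring (a minimal sketch kept as a comment; the route path, the
# ChatRequest model, and the service instances are assumptions, not defined in
# this file):
#
#   from fastapi import FastAPI
#
#   app = FastAPI()
#
#   @app.post("/chat/hybrid")
#   async def chat(request: ChatRequest):
#       return await hybrid_chat_endpoint(
#           request,
#           conversation_service, intent_classifier,
#           embedding_service, qdrant_service,
#           tools_service, advanced_rag,
#           chat_history_collection, hf_token, lead_storage,
#       )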
async def hybrid_chat_endpoint(
request, # ChatRequest
conversation_service,
intent_classifier,
embedding_service, # For handlers
qdrant_service, # For handlers
tools_service,
advanced_rag,
chat_history_collection,
hf_token,
lead_storage
):
"""
Hybrid conversational chatbot: Scenario FSM + RAG
Flow:
1. Load session & scenario state
2. Classify intent (scenario vs RAG)
3. Route:
- Scenario: Execute FSM flow with dedicated handlers
- RAG: Knowledge retrieval
- RAG+Resume: Answer question then resume scenario
4. Save state & history
"""
try:
# ===== SESSION MANAGEMENT =====
session_id = request.session_id
if not session_id:
session_id = conversation_service.create_session(
metadata={"user_agent": "api", "created_via": "hybrid_chat"},
user_id=request.user_id
)
print(f"✓ Created session: {session_id} (user: {request.user_id or 'anon'})")
else:
if not conversation_service.session_exists(session_id):
raise HTTPException(404, detail=f"Session {session_id} not found")
# ===== LOAD SCENARIO STATE =====
scenario_state = conversation_service.get_scenario_state(session_id) or {}
# ===== INTENT CLASSIFICATION =====
intent = intent_classifier.classify(request.message, scenario_state)
print(f"🎯 Intent: {intent}")
# ===== ROUTING =====
if intent.startswith("scenario:"):
# Route to dedicated scenario handler
response_data = await handle_scenario(
intent,
request.message,
session_id,
scenario_state,
embedding_service,
qdrant_service,
conversation_service,
lead_storage
)
elif intent == "rag:with_resume":
# Answer question but keep scenario active
response_data = await handle_rag_with_resume(
request,
session_id,
scenario_state,
embedding_service,
qdrant_service,
conversation_service
)
else: # rag:general
# Pure RAG query
response_data = await handle_pure_rag(
request,
session_id,
advanced_rag,
embedding_service,
qdrant_service,
tools_service,
chat_history_collection,
hf_token,
conversation_service
)
# ===== SAVE HISTORY =====
conversation_service.add_message(
session_id,
"user",
request.message,
metadata={"intent": intent}
)
conversation_service.add_message(
session_id,
"assistant",
response_data["response"],
metadata={
"mode": response_data.get("mode", "unknown"),
"context_used": response_data.get("context_used", [])[:3]
}
)
return {
"response": response_data["response"],
"session_id": session_id,
"mode": response_data.get("mode"),
"scenario_active": response_data.get("scenario_active", False),
"timestamp": datetime.utcnow().isoformat()
}
    except HTTPException:
        # Let HTTP errors (e.g., the 404 above) propagate instead of being
        # masked as 500s by the generic handler below
        raise
    except Exception as e:
        print(f"❌ Error in hybrid_chat: {str(e)}")
        raise HTTPException(500, detail=f"Chat error: {str(e)}")
async def handle_scenario(
intent,
user_message,
session_id,
scenario_state,
embedding_service,
qdrant_service,
conversation_service,
lead_storage
):
"""
Handle scenario-based conversation using dedicated handlers
Replaces old scenario_engine with per-scenario handlers
"""
# Initialize all scenario handlers
handlers = {
'price_inquiry': PriceInquiryHandler(embedding_service, qdrant_service, lead_storage),
'event_recommendation': EventRecommendationHandler(embedding_service, qdrant_service, lead_storage),
'post_event_feedback': PostEventFeedbackHandler(embedding_service, qdrant_service, lead_storage),
'exit_intent_rescue': ExitIntentRescueHandler(embedding_service, qdrant_service, lead_storage)
}
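    # Note: the handlers are re-instantiated on every request. That is cheap if
    # their constructors only store the injected services; if construction ever
    # becomes expensive, this dict could be built once at module scope instead.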
if intent == "scenario:continue":
# Continue existing scenario
scenario_id = scenario_state.get("active_scenario")
if scenario_id not in handlers:
return {
"response": f"Xin lỗi, scenario '{scenario_id}' không tồn tại.",
"mode": "error",
"scenario_active": False
}
handler = handlers[scenario_id]
result = handler.next_step(
current_step=scenario_state.get("scenario_step", 1),
user_input=user_message,
scenario_data=scenario_state.get("scenario_data", {})
)
else:
# Start new scenario
scenario_type = intent.split(":", 1)[1]
if scenario_type not in handlers:
return {
"response": f"Xin lỗi, scenario '{scenario_type}' không tồn tại.",
"mode": "error",
"scenario_active": False
}
handler = handlers[scenario_type]
# Get initial_data from scenario_state (if any)
initial_data = scenario_state.get("scenario_data", {})
result = handler.start(initial_data=initial_data)
# Update scenario state
if result.get("end_scenario") or not result.get("scenario_active", True):
conversation_service.clear_scenario(session_id)
scenario_active = False
elif result.get("new_state"):
conversation_service.set_scenario_state(session_id, result["new_state"])
scenario_active = True
else:
# new_state is None → stay at same step (e.g., validation failed)
scenario_active = True
return {
"response": result.get("message", ""),
"mode": "scenario",
"scenario_active": scenario_active,
"loading_message": result.get("loading_message") # For UI
}
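# A minimal handler sketch matching the contract handle_scenario() relies on
# (illustrative only; the real handlers live in scenario_handlers/):
#
#   class ExampleHandler:
#       def __init__(self, embedding_service, qdrant_service, lead_storage):
#           self.embedding_service = embedding_service
#           self.qdrant_service = qdrant_service
#           self.lead_storage = lead_storage
#
#       def start(self, initial_data=None):
#           return {
#               "message": "First question of the flow?",
#               "new_state": {
#                   "active_scenario": "example",
#                   "scenario_step": 1,
#                   "scenario_data": initial_data or {},
#               },
#           }
#
#       def next_step(self, current_step, user_input, scenario_data):
#           # Return new_state=None to repeat the current step (e.g., on
#           # failed validation), or end_scenario=True to finish
#           return {"message": "Done!", "end_scenario": True}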
async def handle_rag_with_resume(
request,
session_id,
scenario_state,
embedding_service,
qdrant_service,
conversation_service
):
"""
Handle RAG query mid-scenario
Answer question properly, then remind user to continue scenario
"""
# Query RAG with proper search
context_used = []
if request.use_rag:
query_embedding = embedding_service.encode_text(request.message)
results = qdrant_service.search(
query_embedding=query_embedding,
limit=request.top_k,
score_threshold=request.score_threshold,
ef=256
)
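        # ef is an HNSW search-width parameter: larger values trade latency for
        # recall (presumably forwarded to Qdrant's hnsw_ef by the wrapper)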
context_used = results
    # Build the RAG response from the retrieved context
    if context_used:
# Format top results nicely
top_result = context_used[0]
text = top_result['metadata'].get('text', '')
# Extract most relevant snippet (first 300 chars)
if text:
rag_response = text[:300].strip()
if len(text) > 300:
rag_response += "..."
else:
rag_response = "Tôi tìm thấy thông tin nhưng không thể hiển thị chi tiết."
# If multiple results, add count
if len(context_used) > 1:
rag_response += f"\n\n(Tìm thấy {len(context_used)} kết quả liên quan)"
else:
rag_response = "Xin lỗi, tôi không tìm thấy thông tin về câu hỏi này trong tài liệu."
# Add resume hint
resume_hint = "\n\n---\n💬 Vậy nha! Quay lại câu hỏi trước, bạn đã quyết định chưa?"
return {
"response": rag_response + resume_hint,
"mode": "rag_with_resume",
"scenario_active": True,
"context_used": context_used
}
async def handle_pure_rag(
request,
session_id,
advanced_rag,
embedding_service,
qdrant_service,
tools_service,
chat_history_collection,
hf_token,
conversation_service
):
"""
Handle pure RAG query (fallback to existing logic)
"""
    # Import the existing endpoint lazily; likely done to avoid a circular import
from chat_endpoint import chat_endpoint
# Call existing endpoint
result = await chat_endpoint(
request,
conversation_service,
tools_service,
advanced_rag,
embedding_service,
qdrant_service,
chat_history_collection,
hf_token
)
return {
"response": result["response"],
"mode": "rag",
"context_used": result.get("context_used", [])
}
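# For reference, hybrid_chat_endpoint() returns a payload shaped like this
# (values illustrative):
#   {
#       "response": "...",
#       "session_id": "abc123",
#       "mode": "scenario",        # "scenario" | "rag_with_resume" | "rag" | "error"
#       "scenario_active": True,
#       "timestamp": "2025-01-01T00:00:00"
#   }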