# ChatbotRAG / agent_service.py
"""
Agent Service - Central Brain for Sales & Feedback Agents
Manages LLM conversation loop with native tool calling
"""
from typing import Dict, Any, List, Optional
import os
import json
from tools_service import ToolsService
class AgentService:
"""
Manages the conversation loop between User -> LLM -> Tools -> Response
Uses native tool calling via HuggingFace Inference API
"""
def __init__(
self,
tools_service: ToolsService,
embedding_service,
qdrant_service,
advanced_rag,
hf_token: str,
feedback_tracking=None # Optional feedback tracking
):
self.tools_service = tools_service
self.embedding_service = embedding_service
self.qdrant_service = qdrant_service
self.advanced_rag = advanced_rag
self.hf_token = hf_token
self.feedback_tracking = feedback_tracking
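        # Per-request auth context (populated by chat() before tool execution)
        self.current_user_id = None
        self.current_access_token = None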
# Load system prompts
self.prompts = self._load_prompts()
def _load_prompts(self) -> Dict[str, str]:
"""Load system prompts from files"""
prompts = {}
prompts_dir = "prompts"
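        # Expected layout (by this module's convention): prompts/sales_agent.txt
        # and prompts/feedback_agent.txt, one plain-text system prompt per mode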
for mode in ["sales_agent", "feedback_agent"]:
filepath = os.path.join(prompts_dir, f"{mode}.txt")
try:
with open(filepath, 'r', encoding='utf-8') as f:
prompts[mode] = f.read()
print(f"✓ Loaded prompt: {mode}")
except Exception as e:
print(f"⚠️ Error loading {mode} prompt: {e}")
prompts[mode] = ""
return prompts
def _get_native_tools(self, mode: str = "sales") -> List[Dict]:
"""
Get tools formatted for native tool calling API.
Returns OpenAI-compatible tool definitions.
"""
common_tools = [
{
"type": "function",
"function": {
"name": "search_events",
"description": "Tìm kiếm sự kiện phù hợp theo từ khóa, vibe, hoặc thời gian.",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Từ khóa tìm kiếm (VD: 'nhạc rock', 'hài kịch')"},
"vibe": {"type": "string", "description": "Vibe/Mood (VD: 'chill', 'sôi động', 'hẹn hò')"},
"time": {"type": "string", "description": "Thời gian (VD: 'cuối tuần này', 'tối nay')"}
}
}
}
},
{
"type": "function",
"function": {
"name": "get_event_details",
"description": "Lấy thông tin chi tiết (giá, địa điểm, thời gian) của sự kiện.",
"parameters": {
"type": "object",
"properties": {
"event_id": {"type": "string", "description": "ID của sự kiện (MongoDB ID)"}
},
"required": ["event_id"]
}
}
}
]
sales_tools = [
{
"type": "function",
"function": {
"name": "save_lead",
"description": "Lưu thông tin khách hàng quan tâm (Lead).",
"parameters": {
"type": "object",
"properties": {
"email": {"type": "string", "description": "Email address"},
"phone": {"type": "string", "description": "Phone number"},
"interest": {"type": "string", "description": "What they're interested in"}
}
}
}
}
]
feedback_tools = [
{
"type": "function",
"function": {
"name": "get_purchased_events",
"description": "Kiểm tra lịch sử các sự kiện user đã mua vé hoặc tham gia.",
"parameters": {
"type": "object",
"properties": {
"user_id": {"type": "string", "description": "ID của user"}
},
"required": ["user_id"]
}
}
},
{
"type": "function",
"function": {
"name": "save_feedback",
"description": "Lưu đánh giá/feedback của user về sự kiện.",
"parameters": {
"type": "object",
"properties": {
"event_id": {"type": "string", "description": "ID sự kiện"},
"rating": {"type": "integer", "description": "Số sao đánh giá (1-5)"},
"comment": {"type": "string", "description": "Nội dung nhận xét"}
},
"required": ["event_id", "rating"]
}
}
}
]
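        # A model-issued call against these schemas arrives roughly as, e.g.:
        #   {"id": "...", "function": {"name": "save_feedback",
        #    "arguments": '{"event_id": "...", "rating": 5}'}}
        # (see the normalized dicts returned by _call_llm_with_tools below)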
if mode == "feedback":
return common_tools + feedback_tools
else:
return common_tools + sales_tools
async def chat(
self,
user_message: str,
conversation_history: List[Dict],
mode: str = "sales", # "sales" or "feedback"
user_id: Optional[str] = None,
access_token: Optional[str] = None, # For authenticated API calls
max_iterations: int = 3
) -> Dict[str, Any]:
"""
Main conversation loop with native tool calling
Args:
user_message: User's input
conversation_history: Previous messages [{"role": "user", "content": ...}, ...]
mode: "sales" or "feedback"
user_id: User ID (for feedback mode to check purchase history)
access_token: JWT token for authenticated API calls
max_iterations: Maximum tool call iterations to prevent infinite loops
Returns:
{
"message": "Bot response",
"tool_calls": [...], # List of tools called (for debugging)
"mode": mode
}
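
        Example (illustrative, assuming an initialized AgentService instance):
            result = await agent.chat("Tìm concert cuối tuần này", [], mode="sales")
            print(result["message"])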
"""
print(f"\n🤖 Agent Mode: {mode}")
print(f"👤 User Message: {user_message}")
print(f"🔑 Auth Info:")
print(f" - User ID: {user_id}")
print(f" - Access Token: {'✅ Received' if access_token else '❌ None'}")
# Store user_id and access_token for tool calls
self.current_user_id = user_id
self.current_access_token = access_token
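        # NOTE: keeping per-request auth on the instance lets the tool layer
        # read it, but it is racy if one AgentService is shared across
        # concurrent requests; request-scoped state would be safer.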
if access_token:
print(f" - Stored access_token for tools: {access_token[:20]}...")
if user_id:
print(f" - Stored user_id for tools: {user_id}")
# Select system prompt (without tool instructions - native tools handle this)
system_prompt = self._get_system_prompt(mode)
# Get native tools for this mode
tools = self._get_native_tools(mode)
# Build conversation context
messages = self._build_messages(system_prompt, conversation_history, user_message)
# Agentic loop: LLM may call tools multiple times
tool_calls_made = []
current_response = None
for iteration in range(max_iterations):
print(f"\n🔄 Iteration {iteration + 1}")
# Call LLM with native tools
llm_result = await self._call_llm_with_tools(messages, tools)
# Check if this is a final text response or a tool call
if llm_result["type"] == "text":
current_response = llm_result["content"]
print(f"🧠 LLM Final Response: {current_response[:200]}...")
break
            elif llm_result["type"] == "tool_calls":
                # Normalize every call up front: some huggingface_hub versions
                # return "arguments" as a dict rather than a JSON string.
                parsed_calls = []
                for idx, tool_call in enumerate(llm_result["tool_calls"]):
                    raw_args = tool_call["function"]["arguments"]
                    arguments = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
                    tool_name = tool_call["function"]["name"]

                    # Auto-inject real user_id for get_purchased_events
                    if tool_name == 'get_purchased_events' and self.current_user_id:
                        print(f"🔄 Auto-injecting real user_id: {self.current_user_id}")
                        arguments['user_id'] = self.current_user_id

                    call_id = tool_call.get("id") or f"call_{iteration}_{idx}"
                    parsed_calls.append((call_id, tool_name, arguments))

                # OpenAI-format conversations expect a single assistant message
                # listing all tool calls, followed by one "tool" message per
                # call with the matching tool_call_id.
                messages.append({
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": call_id,
                            "type": "function",
                            "function": {
                                "name": tool_name,
                                "arguments": json.dumps(arguments)
                            }
                        }
                        for call_id, tool_name, arguments in parsed_calls
                    ]
                })

                # Execute each tool and feed its result back to the LLM
                for call_id, tool_name, arguments in parsed_calls:
                    print(f"🔧 Tool Called: {tool_name}")
                    print(f" Arguments: {arguments}")

                    tool_result = await self.tools_service.execute_tool(
                        tool_name,
                        arguments,
                        access_token=self.current_access_token
                    )

                    # Record tool call (for debugging)
                    tool_calls_made.append({
                        "function": tool_name,
                        "arguments": arguments,
                        "result": tool_result
                    })

                    # Handle RAG search specially: ToolsService returns an action
                    # marker so the agent runs the vector search itself
                    if isinstance(tool_result, dict) and tool_result.get("action") == "run_rag_search":
                        tool_result = await self._execute_rag_search(tool_result["query"])

                    # Add tool result to messages
                    messages.append({
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": self._format_tool_result({"result": tool_result})
                    })
elif llm_result["type"] == "error":
print(f"⚠️ LLM Error: {llm_result['content']}")
current_response = "Xin lỗi, tôi đang gặp chút vấn đề kỹ thuật. Bạn thử lại sau nhé!"
break
        # Fall back to a default message if the loop ended without a text response
final_response = current_response or "Tôi cần thêm thông tin để hỗ trợ bạn."
return {
"message": final_response,
"tool_calls": tool_calls_made,
"mode": mode
}
def _get_system_prompt(self, mode: str) -> str:
"""Get system prompt for selected mode (without tool instructions)"""
prompt_key = f"{mode}_agent" if mode in ["sales", "feedback"] else "sales_agent"
return self.prompts.get(prompt_key, "")
def _build_messages(
self,
system_prompt: str,
history: List[Dict],
user_message: str
) -> List[Dict]:
"""Build messages array for LLM"""
messages = [{"role": "system", "content": system_prompt}]
# Add conversation history
messages.extend(history)
# Add current user message
messages.append({"role": "user", "content": user_message})
return messages
async def _call_llm_with_tools(self, messages: List[Dict], tools: List[Dict]) -> Dict:
"""
Call HuggingFace LLM with native tool calling support
Returns:
{"type": "text", "content": "..."} for text responses
{"type": "tool_calls", "tool_calls": [...]} for tool call requests
{"type": "error", "content": "..."} for errors
"""
try:
from huggingface_hub import AsyncInferenceClient
            # Create the async client (Qwen2.5 is served by the default HF Inference API)
client = AsyncInferenceClient(token=self.hf_token)
            # Call the HF chat-completion endpoint with native tool definitions.
            # Qwen2.5-72B-Instruct is chosen for its strong Vietnamese and tool-calling support.
response = await client.chat_completion(
messages=messages,
model="Qwen/Qwen2.5-72B-Instruct", # Best for Vietnamese + tool calling
max_tokens=1024, # Increased to prevent truncation
temperature=0.7,
tools=tools,
tool_choice="auto" # Let model decide when to use tools
)
# Check if the model made tool calls
message = response.choices[0].message
if message.tool_calls:
print(f"🔧 Native tool calls detected: {len(message.tool_calls)}")
return {
"type": "tool_calls",
"tool_calls": [
{
"id": tc.id,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments
}
}
for tc in message.tool_calls
]
}
else:
# Regular text response
return {
"type": "text",
"content": message.content or ""
}
except Exception as e:
print(f"⚠️ LLM Call Error: {e}")
return {
"type": "error",
"content": str(e)
}
def _format_tool_result(self, tool_result: Dict) -> str:
"""Format tool result for feeding back to LLM"""
result = tool_result.get("result", {})
# Special handling for purchased events list
if isinstance(result, list):
print(f"\n🔍 Formatting {len(result)} items for LLM")
if not result:
return "Không tìm thấy dữ liệu nào phù hợp."
# Format each event clearly
formatted_events = []
for i, event in enumerate(result, 1):
# Handle both object/dict and string results
if isinstance(event, str):
formatted_events.append(f"{i}. {event}")
continue
event_info = []
event_info.append(f"Event {i}:")
# Extract key fields
if 'eventName' in event:
event_info.append(f" Name: {event['eventName']}")
if 'eventCode' in event:
event_info.append(f" Code: {event['eventCode']}")
if '_id' in event:
event_info.append(f" ID: {event['_id']}")
if 'startTimeEventTime' in event:
event_info.append(f" Date: {event['startTimeEventTime']}")
                # Handle RAG result payload structure ("texts" may be a list of chunks)
                if 'texts' in event:
                    texts = event['texts']
                    if isinstance(texts, list):
                        texts = " ".join(str(t) for t in texts)
                    event_info.append(f" Content: {texts}")
if 'id_use' in event:
event_info.append(f" ID: {event['id_use']}")
formatted_events.append("\n".join(event_info))
formatted = "Tool Results:\n\n" + "\n\n".join(formatted_events)
# print(f"📤 Sending to LLM:\n{formatted}") # Reduce noise
return formatted
# Default formatting for other results
if isinstance(result, dict):
# Pretty print key info
formatted = []
for key, value in result.items():
if key not in ["success", "error"]:
formatted.append(f"{key}: {value}")
return "\n".join(formatted) if formatted else json.dumps(result)
return str(result)
async def _execute_rag_search(self, query_params: Dict) -> str:
"""
Execute RAG search for event discovery
        Called when the LLM invokes the search_events tool
"""
query = query_params.get("query", "")
vibe = query_params.get("vibe", "")
time = query_params.get("time", "")
# Build search query
search_text = f"{query} {vibe} {time}".strip()
print(f"🔍 RAG Search Query: '{search_text}'")
if not search_text:
return "Vui lòng cung cấp từ khóa tìm kiếm."
# Use embedding + qdrant
embedding = self.embedding_service.encode_text(search_text)
results = self.qdrant_service.search(
query_embedding=embedding,
limit=5
)
print(f"📊 RAG Results Count: {len(results)}")
# Fallback if no results and query was complex
if not results and (query and vibe):
print(f"⚠️ No results for combined query. Retrying with just 'vibe': {vibe}")
search_text = vibe
embedding = self.embedding_service.encode_text(search_text)
results = self.qdrant_service.search(
query_embedding=embedding,
limit=5
)
print(f"📊 Retry Results Count: {len(results)}")
# Format results
formatted = []
for i, result in enumerate(results, 1):
# Result is a dict with keys: id, score, payload
payload = result.get("payload", {})
texts = payload.get("texts", [])
text = texts[0] if texts else ""
event_id = payload.get("id_use", "")
if not text:
continue
# Clean and truncate text for context window
clean_text = text.replace("\n", " ").strip()
formatted.append(f"Event Found: {clean_text[:300]}... (ID: {event_id})")
if not formatted:
print("❌ RAG Search returned 0 usable results")
return "SYSTEM_MESSAGE: Không tìm thấy sự kiện nào trong cơ sở dữ liệu phù hợp với yêu cầu. Hãy báo lại cho khách hàng: 'Hiện tại mình chưa tìm thấy sự kiện nào phù hợp với yêu cầu này, bạn thử đổi tiêu chí xem sao nhé?'"
print(f"✅ Returning {len(formatted)} events to LLM")
return "\n\n".join(formatted)