import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional

from vanna import Agent, AgentConfig
from vanna.core.registry import ToolRegistry, ToolSchema
from vanna.core.system_prompt import SystemPromptBuilder
from vanna.core.user import RequestContext, User, UserResolver
from vanna.integrations.local.agent_memory import DemoAgentMemory
from vanna.integrations.postgres import PostgresRunner
from vanna.tools import RunSqlTool
from vanna.tools.agent_memory import (
    SaveQuestionToolArgsTool,
    SearchSavedCorrectToolUsesTool,
)

from .vanna_huggingface_llm_service import VannaHuggingFaceLlmService


class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
    """Complete system prompt builder for Vanna SQL assistant v2."""

    VERSION = "2.2.0"

    def __init__(
        self,
        company_name: str = "CoJournalist",
        sql_runner: Optional[PostgresRunner] = None,
    ):
        """Store the company name shown in the prompt and an optional SQL runner.

        Args:
            company_name: Name interpolated into the opening line of the prompt.
            sql_runner: Kept on the instance; not used while building the prompt.
        """
        self.company_name = company_name
        self.sql_runner = sql_runner

    async def build_system_prompt(
        self,
        user: User,
        tool_schemas: List[ToolSchema],
        context: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Assemble the full system prompt text for one request.

        Args:
            user: Current user; ``username`` (falling back to ``id``) and
                ``group_memberships`` are written into the prompt header.
            tool_schemas: Tools to list in the prompt. May be empty or ``None``,
                in which case the tool and memory sections are omitted.
            context: Optional mapping; when it contains ``"database_schema"``
                that value replaces the built-in schema description.

        Returns:
            The prompt as a single string.
        """
        today = datetime.now().strftime("%Y-%m-%d")
        username = getattr(user, "username", user.id)

        # Sections are collected in order and joined once at the end.
        sections: List[str] = []

        # ======================
        # BASE PROMPT
        # ======================
        sections.append(f"[System Prompt v{self.VERSION}]\n\n")
        sections.append(
            f"You are an expert SQL assistant for the company {self.company_name}.\n"
        )
        sections.append(
            f"Date: {today}\nUser: {username}\nGroups: {', '.join(user.group_memberships)}\n\n"
        )
        sections.append(
            "Your role: generate correct and efficient SQL queries from natural language.\n"
            "You always respond in **raw CSV format**, with no explanation or extra text.\n"
            "You have full access to all tables and relationships described in the schema.\n"
        )

        # ======================
        # SQL DIRECTIVES
        # ======================
        sections.append(
            "\n## SQL Directives\n"
            "- Always use table aliases in JOINs\n"
            "- Never use SELECT *\n"
            "- Prefer window functions over subqueries when possible\n"
            "- Always include a LIMIT for exploratory queries\n"
            "- Exclude posts where provider = 'SND'\n"
            "- Exclude posts where type = 'resource'\n"
            "- Exclude posts where type = 'insight'\n"
            "- Format dates and numbers for readability\n"
        )

        # ======================
        # DATABASE SCHEMA
        # ======================
        if context and "database_schema" in context:
            sections.append("\n## Database Schema\n")
            sections.append(context["database_schema"])
        else:
            sections.append(
                "\n## Database Schema\n"
                "Tables:\n"
                "- posts (id, title, source_url, author, published_date, image_url, type, provider_id, created_at, updated_at)\n"
                "- providers (id, name)\n"
                "- provider_attributes (id, provider_id, type, name)\n"
                "- post_provider_attributes (post_id, attribute_id)\n"
                "- tags (id, name)\n"
                "- post_tags (post_id, tag_id, weight)\n"
                "\nRelationships:\n"
                " - posts.provider_id → providers.id\n"
                " - post_provider_attributes.post_id → posts.id\n"
                " - post_provider_attributes.attribute_id → provider_attributes.id\n"
                " - provider_attributes.provider_id → providers.id\n"
                " - post_tags.post_id → posts.id\n"
                " - post_tags.tag_id → tags.id\n"
            )

        # ======================
        # SEMANTIC INFORMATION
        # ======================
        sections.append(
            "\n## Semantic Information\n"
            "- `posts.title`: title of the content (often descriptive, may contain keywords).\n"
            "- `posts.source_url`: external link to the article or resource.\n"
            "- `posts.author`: author, journalist, or organization name (e.g., 'The New York Times').\n"
            "- `posts.published_date`: publication date.\n"
            "- `posts.type`: content type ENUM ('spotlight', 'resource', 'insight').\n"
            "- `providers.name`: name of the publishing organization (e.g., 'Nuanced', 'SND').\n"
            "- `tags.name`: thematic keyword or topic (e.g., '3D', 'AI', 'Design').\n"
            "- `post_tags.weight`: relevance score between a post and a tag.\n"
        )

        # ======================
        # BUSINESS LOGIC
        # ======================
        sections.append(
            "\n## Business Logic\n"
            "- Providers named 'SND' must always be excluded.\n"
            "- A query mentioning an organization (e.g., 'New York Times') should search both `posts.author` and `providers.name`.\n"
            "- By default, only posts with `type = 'spotlight'` are returned.\n"
            "- Posts of type `resource` or `insight` are excluded unless explicitly requested.\n"
            "- Tags link posts to specific themes or disciplines.\n"
            "- A single post may have multiple tags, awards, or categories.\n"
            "- If the user mentions a year (e.g., 'in 2021'), filter with `EXTRACT(YEAR FROM published_date) = 2021`.\n"
            "- If the user says 'recently', filter posts from the last 90 days.\n"
            "- Always limit exploratory results to 9 rows.\n"
        )

        # ======================
        # AVAILABLE TOOLS
        # ======================
        if tool_schemas:
            sections.append("\n## Available Tools\n")
            for tool in tool_schemas:
                sections.append(
                    f"- {tool.name}: {getattr(tool, 'description', 'No description')}\n"
                )
                sections.append(
                    f" Parameters: {getattr(tool, 'parameters', 'N/A')}\n"
                )

        # ======================
        # MEMORY SYSTEM
        # ======================
        # `or ()` keeps this section safe when no tool schemas were supplied,
        # matching the falsy guard used for the tool listing above.
        tool_names = {schema.name for schema in (tool_schemas or ())}
        has_search = "search_saved_correct_tool_uses" in tool_names
        has_save = "save_question_tool_args" in tool_names
        if has_search or has_save:
            sections.append("\n## Memory System\n")
            if has_search:
                sections.append(
                    "- Use `search_saved_correct_tool_uses` to detect past patterns.\n"
                )
            if has_save:
                sections.append(
                    "- Use `save_question_tool_args` to store successful pairs.\n"
                )

        # ======================
        # EXAMPLES
        # ======================
        # In the second example the OR pair is parenthesised: SQL evaluates AND
        # before OR, so without the parentheses the 'SND' exclusion and the
        # 'spotlight' filter would only constrain the providers.name branch.
        sections.append(
            "\n## Example Interactions\n"
            "User: 'Show me posts related to 3D'\n"
            "Assistant: [call run_sql with \"SELECT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
            "FROM posts p "
            "JOIN post_tags pt ON p.id = pt.post_id "
            "JOIN tags t ON pt.tag_id = t.id "
            "JOIN providers pr ON p.provider_id = pr.id "
            "WHERE t.name ILIKE '%3D%' AND pr.name != 'SND' AND p.type = 'spotlight' "
            "LIMIT 9;\"]\n"
            "\nUser: 'Show me posts from The New York Times'\n"
            "Assistant: [call run_sql with \"SELECT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
            "FROM posts p "
            "LEFT JOIN providers pr ON pr.id = p.provider_id "
            "WHERE (LOWER(p.author) LIKE '%new york times%' OR LOWER(pr.name) LIKE '%new york times%') "
            "AND pr.name != 'SND' AND p.type = 'spotlight' "
            "LIMIT 9;\"]\n"
        )

        # ======================
        # FINAL INSTRUCTIONS
        # ======================
        sections.append(
            "\nIMPORTANT:\n"
            "- Always exclude posts with provider = 'SND'.\n"
            "- Always exclude posts with type = 'resource' or 'insight'.\n"
            "- Always return **only the raw CSV result** — no explanations, no JSON, no commentary.\n"
            "- Stop tool execution once the query result is obtained.\n"
        )

        return "".join(sections)


class SimpleUserResolver(UserResolver):
    """Resolve the current user from the ``vanna_email`` cookie."""

    async def resolve_user(self, request_context: RequestContext) -> User:
        """Build a ``User`` from the request cookie.

        A missing or empty cookie yields ``guest@example.com``. Only
        ``admin@example.com`` is placed in the ``admin`` group; every other
        address gets the ``user`` group.
        """
        user_email = request_context.get_cookie('vanna_email') or 'guest@example.com'
        group = 'admin' if user_email == 'admin@example.com' else 'user'
        return User(id=user_email, email=user_email, group_memberships=[group])


class VannaComponent:
    """Wire an LLM service, a Postgres SQL tool and agent memory into one agent."""

    def __init__(
        self,
        hf_model: str,
        hf_token: str,
        hf_provider: str,
        connection_string: str,
        *,
        results_folder: str = "513935c4d2db2d2d",
    ):
        """Create the agent and register its tools.

        Args:
            hf_model: Model identifier passed to the Hugging Face LLM service.
            hf_token: Token passed to the Hugging Face LLM service.
            hf_provider: Provider name passed to the Hugging Face LLM service.
            connection_string: Connection string handed to ``PostgresRunner``.
            results_folder: Directory in which ``ask`` looks for
                ``query_results_*.csv`` files. The default is the value that
                was previously hard-coded.
                NOTE(review): presumably the directory the SQL tool writes its
                result files to for the default user - confirm.
        """
        llm = VannaHuggingFaceLlmService(
            model=hf_model, token=hf_token, provider=hf_provider
        )
        self.sql_runner = PostgresRunner(connection_string=connection_string)
        self.results_folder = results_folder

        db_tool = RunSqlTool(sql_runner=self.sql_runner)
        agent_memory = DemoAgentMemory(max_items=1000)
        save_memory_tool = SaveQuestionToolArgsTool(agent_memory)
        search_memory_tool = SearchSavedCorrectToolUsesTool(agent_memory)
        self.user_resolver = SimpleUserResolver()

        # Saving to memory is registered for admins only; running SQL and
        # searching memory are registered for both groups.
        tools = ToolRegistry()
        tools.register_local_tool(db_tool, access_groups=['admin', 'user'])
        tools.register_local_tool(save_memory_tool, access_groups=['admin'])
        tools.register_local_tool(search_memory_tool, access_groups=['admin', 'user'])

        self.agent = Agent(
            llm_service=llm,
            tool_registry=tools,
            user_resolver=self.user_resolver,
            system_prompt_builder=CustomSQLSystemPromptBuilder(
                "CoJournalist", self.sql_runner
            ),
            config=AgentConfig(stream_responses=False, max_tool_iterations=1),
        )

    async def ask(self, prompt_for_llm: str):
        """Send a message to the agent and return the CSV result or raw text.

        Distinct text fragments from the streamed components are collected in
        order. If the collected text mentions a ``query_results_*.csv`` file
        and that file exists under ``results_folder``, its stripped contents
        are returned. Otherwise the collected text is returned, one fragment
        per line.

        Args:
            prompt_for_llm: The message forwarded to the agent.

        Returns:
            The CSV file contents when available, else the collected text.
        """
        ctx = RequestContext()
        print(f"🙋 Prompt sent to LLM: {prompt_for_llm}")

        fragments: List[str] = []
        seen_texts = set()

        async for component in self.agent.send_message(
            request_context=ctx, message=prompt_for_llm
        ):
            simple = getattr(component, "simple_component", None)
            text = getattr(simple, "text", "") if simple else ""
            # Identical fragments are kept only once.
            if text and text not in seen_texts:
                print(f"💬 LLM says (part): {text[:200]}...")
                fragments.append(text)
                seen_texts.add(text)

            sql_query = getattr(component, "sql", None)
            if sql_query:
                print(f"🧾 SQL Query Generated: {sql_query}")

            metadata = getattr(component, "metadata", None)
            if metadata:
                print(f"📋 Metadata: {metadata}")

            component_type = getattr(component, "type", None)
            if component_type:
                print(f"🔖 Component Type: {component_type}")

        final_text = "".join(f"{fragment}\n" for fragment in fragments)

        match = re.search(r"query_results_[\w-]+\.csv", final_text)
        if match:
            filename = match.group(0)
            full_path = os.path.join(self.results_folder, filename)
            if os.path.exists(full_path):
                print(f"📂 Reading result file: {full_path}")
                with open(full_path, "r", encoding="utf-8") as f:
                    csv_data = f.read().strip()
                print("🤖 Response sent to user (from file):", csv_data[:300])
                return csv_data
            else:
                print(f"⚠️ File not found: {full_path}")

        # No usable result file: fall back to the text gathered from the stream.
        return final_text