Spaces:
Running
Running
| import re | |
| import os | |
| from vanna import Agent, AgentConfig | |
| from vanna.core.registry import ToolRegistry | |
| from vanna.core.user import UserResolver, User, RequestContext | |
| from vanna.tools import RunSqlTool | |
| from vanna.tools.agent_memory import SaveQuestionToolArgsTool, SearchSavedCorrectToolUsesTool | |
| from vanna.integrations.postgres import PostgresRunner | |
| from vanna.integrations.local.agent_memory import DemoAgentMemory | |
| from .vanna_huggingface_llm_service import VannaHuggingFaceLlmService | |
| from typing import List, Dict, Any, Optional | |
| from vanna.core.system_prompt import SystemPromptBuilder | |
| from vanna.core.registry import ToolSchema | |
| from datetime import datetime | |
class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
    """Complete system prompt builder for Vanna SQL assistant v2.

    The prompt is assembled from fixed Markdown-style sections (the class
    constants below) plus parts computed per call: today's date, the
    requesting user, the tools exposed to that user, and an optional
    schema override supplied through ``context``.
    """

    VERSION = "2.2.0"

    # Role statement that follows the date/user header.
    _ROLE = (
        "Your role: generate correct and efficient SQL queries from natural language.\n"
        "You always respond in **raw CSV format**, with no explanation or extra text.\n"
        "You have full access to all tables and relationships described in the schema.\n"
    )

    # NOTE(review): these directives and _FINAL_INSTRUCTIONS exclude
    # 'resource'/'insight' posts unconditionally, while _BUSINESS_LOGIC says
    # they are excluded "unless explicitly requested". Confirm which rule is
    # intended; the LLM receives both.
    _SQL_DIRECTIVES = (
        "\n## SQL Directives\n"
        "- Always use table aliases in JOINs\n"
        "- Never use SELECT *\n"
        "- Prefer window functions over subqueries when possible\n"
        "- Always include a LIMIT for exploratory queries\n"
        "- Exclude posts where provider = 'SND'\n"
        "- Exclude posts where type = 'resource'\n"
        "- Exclude posts where type = 'insight'\n"
        "- Format dates and numbers for readability\n"
    )

    # Built-in schema description, used when the caller's context has no
    # "database_schema" entry. Arrows are plain ASCII to avoid encoding issues.
    _DEFAULT_SCHEMA = (
        "\n## Database Schema\n"
        "Tables:\n"
        "- posts (id, title, source_url, author, published_date, image_url, type, provider_id, created_at, updated_at)\n"
        "- providers (id, name)\n"
        "- provider_attributes (id, provider_id, type, name)\n"
        "- post_provider_attributes (post_id, attribute_id)\n"
        "- tags (id, name)\n"
        "- post_tags (post_id, tag_id, weight)\n"
        "\nRelationships:\n"
        " - posts.provider_id -> providers.id\n"
        " - post_provider_attributes.post_id -> posts.id\n"
        " - post_provider_attributes.attribute_id -> provider_attributes.id\n"
        " - provider_attributes.provider_id -> providers.id\n"
        " - post_tags.post_id -> posts.id\n"
        " - post_tags.tag_id -> tags.id\n"
    )

    _SEMANTIC_INFO = (
        "\n## Semantic Information\n"
        "- `posts.title`: title of the content (often descriptive, may contain keywords).\n"
        "- `posts.source_url`: external link to the article or resource.\n"
        "- `posts.author`: author, journalist, or organization name (e.g., 'The New York Times').\n"
        "- `posts.published_date`: publication date.\n"
        "- `posts.type`: content type ENUM ('spotlight', 'resource', 'insight').\n"
        "- `providers.name`: name of the publishing organization (e.g., 'Nuanced', 'SND').\n"
        "- `tags.name`: thematic keyword or topic (e.g., '3D', 'AI', 'Design').\n"
        "- `post_tags.weight`: relevance score between a post and a tag.\n"
    )

    _BUSINESS_LOGIC = (
        "\n## Business Logic\n"
        "- Providers named 'SND' must always be excluded.\n"
        "- A query mentioning an organization (e.g., 'New York Times') should search both `posts.author` and `providers.name`.\n"
        "- By default, only posts with `type = 'spotlight'` are returned.\n"
        "- Posts of type `resource` or `insight` are excluded unless explicitly requested.\n"
        "- Tags link posts to specific themes or disciplines.\n"
        "- A single post may have multiple tags, awards, or categories.\n"
        "- If the user mentions a year (e.g., 'in 2021'), filter with `EXTRACT(YEAR FROM published_date) = 2021`.\n"
        "- If the user says 'recently', filter posts from the last 90 days.\n"
        "- Always limit exploratory results to 9 rows.\n"
    )

    # Few-shot examples the LLM imitates, so the SQL must be correct.
    # In the second example, AND binds tighter than OR, so the two name
    # matches are parenthesized to make the SND/type filters apply to both.
    # The provider check is NULL-safe so LEFT JOIN rows without a provider
    # are not silently dropped.
    _EXAMPLES = (
        "\n## Example Interactions\n"
        "User: 'Show me posts related to 3D'\n"
        "Assistant: [call run_sql with \"SELECT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
        "FROM posts p "
        "JOIN post_tags pt ON p.id = pt.post_id "
        "JOIN tags t ON pt.tag_id = t.id "
        "JOIN providers pr ON p.provider_id = pr.id "
        "WHERE t.name ILIKE '%3D%' AND pr.name != 'SND' AND p.type = 'spotlight' "
        "LIMIT 9;\"]\n"
        "\nUser: 'Show me posts from The New York Times'\n"
        "Assistant: [call run_sql with \"SELECT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
        "FROM posts p "
        "LEFT JOIN providers pr ON pr.id = p.provider_id "
        "WHERE (LOWER(p.author) LIKE '%new york times%' OR LOWER(pr.name) LIKE '%new york times%') "
        "AND (pr.name IS NULL OR pr.name != 'SND') AND p.type = 'spotlight' "
        "LIMIT 9;\"]\n"
    )

    _FINAL_INSTRUCTIONS = (
        "\nIMPORTANT:\n"
        "- Always exclude posts with provider = 'SND'.\n"
        "- Always exclude posts with type = 'resource' or 'insight'.\n"
        "- Always return **only the raw CSV result** - no explanations, no JSON, no commentary.\n"
        "- Stop tool execution once the query result is obtained.\n"
    )

    # Tool names whose presence enables the "Memory System" section.
    _SEARCH_TOOL = "search_saved_correct_tool_uses"
    _SAVE_TOOL = "save_question_tool_args"

    def __init__(self, company_name: str = "CoJournalist", sql_runner: Optional[PostgresRunner] = None):
        """Store prompt configuration.

        Args:
            company_name: Name interpolated into the assistant's role line.
            sql_runner: Stored on the instance; ``build_system_prompt`` does
                not read it.
        """
        self.company_name = company_name
        self.sql_runner = sql_runner

    async def build_system_prompt(
        self,
        user: User,
        tool_schemas: List[ToolSchema],
        context: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Assemble the system prompt for ``user``.

        Args:
            user: Requesting user. ``username`` (falling back to ``id`` when
                absent or empty) and ``group_memberships`` appear in the header.
            tool_schemas: Tools available to this user. Each tool's ``name``,
                ``description`` and ``parameters`` are listed. The memory
                section is emitted only when a memory tool is present.
                ``None`` is treated as an empty list.
            context: Optional extra data. A ``"database_schema"`` entry
                replaces the built-in schema description.

        Returns:
            The full prompt text.
        """
        schemas = tool_schemas or []
        today = datetime.now().strftime("%Y-%m-%d")
        # ``username`` may be missing or None; fall back to the id either way.
        username = getattr(user, "username", None) or user.id
        groups = ", ".join(getattr(user, "group_memberships", None) or [])

        parts: List[str] = [
            f"[System Prompt v{self.VERSION}]\n\n",
            f"You are an expert SQL assistant for the company {self.company_name}.\n",
            f"Date: {today}\nUser: {username}\nGroups: {groups}\n\n",
            self._ROLE,
            self._SQL_DIRECTIVES,
        ]

        if context and "database_schema" in context:
            parts.append("\n## Database Schema\n")
            parts.append(context["database_schema"])
        else:
            parts.append(self._DEFAULT_SCHEMA)

        parts.append(self._SEMANTIC_INFO)
        parts.append(self._BUSINESS_LOGIC)

        if schemas:
            parts.append("\n## Available Tools\n")
            for tool in schemas:
                parts.append(f"- {tool.name}: {getattr(tool, 'description', 'No description')}\n")
                parts.append(f" Parameters: {getattr(tool, 'parameters', 'N/A')}\n")

        tool_names = {tool.name for tool in schemas}
        has_search = self._SEARCH_TOOL in tool_names
        has_save = self._SAVE_TOOL in tool_names
        if has_search or has_save:
            parts.append("\n## Memory System\n")
            if has_search:
                parts.append("- Use `search_saved_correct_tool_uses` to detect past patterns.\n")
            if has_save:
                parts.append("- Use `save_question_tool_args` to store successful pairs.\n")

        parts.append(self._EXAMPLES)
        parts.append(self._FINAL_INSTRUCTIONS)
        return "".join(parts)
class SimpleUserResolver(UserResolver):
    """Map each request to a ``User`` using the ``vanna_email`` cookie.

    NOTE(review): the cookie value is trusted as-is. A client that sends
    ``vanna_email=admin@example.com`` is placed in the ``admin`` group and
    gets every tool restricted to that group. Put real authentication in
    front of this before exposing it beyond a demo.
    """

    async def resolve_user(self, request_context: RequestContext) -> User:
        """Return the user named by the ``vanna_email`` cookie.

        A falsy cookie value (e.g. None or an empty string) falls back to
        ``guest@example.com``. The address is used as both ``id`` and
        ``email``. Only the exact string ``admin@example.com`` gets the
        ``admin`` group; every other address gets ``user``.
        """
        email = request_context.get_cookie('vanna_email')
        if not email:
            email = 'guest@example.com'

        is_admin = email == 'admin@example.com'
        memberships = ['admin'] if is_admin else ['user']
        return User(id=email, email=email, group_memberships=memberships)
class VannaComponent:
    """Wire a Vanna ``Agent`` to a Hugging Face LLM and a Postgres database.

    The agent is given three tools:

    * ``RunSqlTool`` for groups ``admin`` and ``user``.
    * ``SaveQuestionToolArgsTool`` for ``admin`` only.
    * ``SearchSavedCorrectToolUsesTool`` for ``admin`` and ``user``.

    The two memory tools share one ``DemoAgentMemory(max_items=1000)``.
    """

    # File name that the agent's text output may mention for query results.
    _RESULT_FILE_RE = re.compile(r"query_results_[\w-]+\.csv")

    def __init__(
        self,
        hf_model: str,
        hf_token: str,
        hf_provider: str,
        connection_string: str,
        results_dir: str = "513935c4d2db2d2d",
    ):
        """Build the LLM service, SQL runner, tool registry and agent.

        Args:
            hf_model: Model id passed to ``VannaHuggingFaceLlmService``.
            hf_token: API token passed to ``VannaHuggingFaceLlmService``.
            hf_provider: Provider name passed to ``VannaHuggingFaceLlmService``.
            connection_string: Connection string for ``PostgresRunner``.
            results_dir: Folder in which ``ask`` looks for the
                ``query_results_*.csv`` file named in the agent's reply.
                The default is the previously hard-coded folder. It looks
                like a per-user directory written by the SQL tool (TODO
                confirm how it is derived before relying on it for other users).
        """
        self.results_dir = results_dir
        llm = VannaHuggingFaceLlmService(model=hf_model, token=hf_token, provider=hf_provider)
        self.sql_runner = PostgresRunner(connection_string=connection_string)
        db_tool = RunSqlTool(sql_runner=self.sql_runner)
        agent_memory = DemoAgentMemory(max_items=1000)
        save_memory_tool = SaveQuestionToolArgsTool(agent_memory)
        search_memory_tool = SearchSavedCorrectToolUsesTool(agent_memory)
        self.user_resolver = SimpleUserResolver()

        tools = ToolRegistry()
        tools.register_local_tool(db_tool, access_groups=['admin', 'user'])
        # Writing to memory is admin-only; searching it is open to all users.
        tools.register_local_tool(save_memory_tool, access_groups=['admin'])
        tools.register_local_tool(search_memory_tool, access_groups=['admin', 'user'])

        self.agent = Agent(
            llm_service=llm,
            tool_registry=tools,
            user_resolver=self.user_resolver,
            system_prompt_builder=CustomSQLSystemPromptBuilder("CoJournalist", self.sql_runner),
            # Non-streaming; max_tool_iterations=1 presumably limits the agent
            # to a single tool round per question.
            config=AgentConfig(stream_responses=False, max_tool_iterations=1),
        )

    async def ask(self, prompt_for_llm: str):
        """Send ``prompt_for_llm`` to the agent and return its answer.

        Distinct text chunks from the agent's output are concatenated, each
        followed by a newline. If that text names a ``query_results_*.csv``
        file that exists in ``self.results_dir``, the file's stripped
        contents are returned. Otherwise the concatenated text is returned.
        """
        # NOTE(review): a fresh RequestContext with no arguments is built per
        # call, so SimpleUserResolver presumably sees no cookie and resolves
        # the guest user every time. Confirm this is intended.
        ctx = RequestContext()
        print(f"Prompt sent to LLM: {prompt_for_llm}")

        text_parts: List[str] = []
        seen_texts = set()
        async for component in self.agent.send_message(request_context=ctx, message=prompt_for_llm):
            simple = getattr(component, "simple_component", None)
            text = getattr(simple, "text", "") if simple else ""
            # Components can repeat the same text; keep only the first copy.
            if text and text not in seen_texts:
                print(f"LLM says (part): {text[:200]}...")
                text_parts.append(text)
                seen_texts.add(text)

            sql_query = getattr(component, "sql", None)
            if sql_query:
                print(f"SQL Query Generated: {sql_query}")
            metadata = getattr(component, "metadata", None)
            if metadata:
                print(f"Metadata: {metadata}")
            component_type = getattr(component, "type", None)
            if component_type:
                print(f"Component Type: {component_type}")

        final_text = "".join(f"{part}\n" for part in text_parts)

        match = self._RESULT_FILE_RE.search(final_text)
        if not match:
            return final_text

        # The regex only admits word characters and '-', so the matched name
        # cannot contain path separators.
        full_path = os.path.join(self.results_dir, match.group(0))
        try:
            with open(full_path, "r", encoding="utf-8") as f:
                print(f"Reading result file: {full_path}")
                csv_data = f.read().strip()
        except FileNotFoundError:
            print(f"Warning: File not found: {full_path}")
            return final_text

        print("Response sent to user (from file):", csv_data[:300])
        return csv_data