""" Multi-Step Query Planner Detects complex queries that require multiple datasets or operations, decomposes them into atomic steps, and orchestrates execution. """ import json import logging from dataclasses import dataclass, field from typing import List, Dict, Any, Optional from enum import Enum logger = logging.getLogger(__name__) class StepType(Enum): """Types of query steps.""" DATA_QUERY = "data_query" # Simple data retrieval AGGREGATION = "aggregation" # COUNT, SUM, GROUP BY COMPARISON = "comparison" # Comparing results from previous steps SPATIAL_JOIN = "spatial_join" # Joining datasets spatially COMBINE = "combine" # Merge/combine step results @dataclass class QueryStep: """A single atomic step in a query plan.""" step_id: str step_type: StepType description: str tables_needed: List[str] sql_template: Optional[str] = None depends_on: List[str] = field(default_factory=list) result_name: str = "" # Name for intermediate result def to_dict(self) -> Dict[str, Any]: return { "step_id": self.step_id, "step_type": self.step_type.value, "description": self.description, "tables_needed": self.tables_needed, "sql_template": self.sql_template, "depends_on": self.depends_on, "result_name": self.result_name } @dataclass class QueryPlan: """Complete execution plan for a complex query.""" original_query: str is_complex: bool steps: List[QueryStep] = field(default_factory=list) parallel_groups: List[List[str]] = field(default_factory=list) # Steps that can run in parallel final_combination_logic: str = "" def to_dict(self) -> Dict[str, Any]: return { "original_query": self.original_query, "is_complex": self.is_complex, "steps": [s.to_dict() for s in self.steps], "parallel_groups": self.parallel_groups, "final_combination_logic": self.final_combination_logic } class QueryPlanner: """ Multi-step query planning service. Analyzes queries to determine complexity and decomposes complex queries into executable atomic steps. """ _instance = None # Keywords that often indicate multi-step queries COMPLEXITY_INDICATORS = [ "compare", "comparison", "versus", "vs", "more than", "less than", "higher than", "lower than", "both", "and also", "as well as", "ratio", "percentage", "proportion", "correlation", "relationship between", "combine", "merge", "together with", "relative to", "compared to", "difference between", "gap between" ] # Keywords indicating multiple distinct data types MULTI_DOMAIN_KEYWORDS = { "health": ["hospital", "clinic", "healthcare", "health", "medical"], "education": ["school", "university", "education", "college", "student"], "infrastructure": ["road", "bridge", "infrastructure", "building"], "environment": ["forest", "water", "environment", "park", "protected"], "population": ["population", "demographic", "census", "people", "resident"] } def __new__(cls): if cls._instance is None: cls._instance = super(QueryPlanner, cls).__new__(cls) cls._instance.initialized = False return cls._instance def __init__(self): if self.initialized: return self.initialized = True def detect_complexity(self, query: str) -> Dict[str, Any]: """ Analyze a query to determine if it requires multi-step planning. Returns: { "is_complex": bool, "reason": str, "detected_domains": List[str], "complexity_indicators": List[str] } """ query_lower = query.lower() # Check for complexity indicators found_indicators = [ ind for ind in self.COMPLEXITY_INDICATORS if ind in query_lower ] # Check for multiple data domains found_domains = [] for domain, keywords in self.MULTI_DOMAIN_KEYWORDS.items(): if any(kw in query_lower for kw in keywords): found_domains.append(domain) # Determine complexity is_complex = ( len(found_indicators) > 0 and len(found_domains) >= 2 ) or ( len(found_domains) >= 3 ) or ( any(x in query_lower for x in ["compare", "ratio", "correlation", "versus", " vs "]) and len(found_domains) >= 2 ) reason = "" if is_complex: if len(found_domains) >= 2: reason = f"Query involves multiple data domains: {', '.join(found_domains)}" if found_indicators: reason += f". Contains comparison/aggregation keywords: {', '.join(found_indicators[:3])}" return { "is_complex": is_complex, "reason": reason, "detected_domains": found_domains, "complexity_indicators": found_indicators } async def plan_query( self, query: str, available_tables: List[str], llm_gateway ) -> QueryPlan: """ Create an execution plan for a complex query. Uses LLM to decompose the query into atomic steps. """ from backend.core.prompts import QUERY_PLANNING_PROMPT # Build table context table_list = "\n".join(f"- {t}" for t in available_tables) prompt = QUERY_PLANNING_PROMPT.format( user_query=query, available_tables=table_list ) try: response = await llm_gateway.generate_response(prompt, []) # Parse JSON response response_clean = response.strip() if response_clean.startswith("```json"): response_clean = response_clean[7:] if response_clean.startswith("```"): response_clean = response_clean[3:] if response_clean.endswith("```"): response_clean = response_clean[:-3] plan_data = json.loads(response_clean.strip()) # Convert to QueryPlan steps = [] for i, step_data in enumerate(plan_data.get("steps", [])): step = QueryStep( step_id=f"step_{i+1}", step_type=StepType(step_data.get("type", "data_query")), description=step_data.get("description", ""), tables_needed=step_data.get("tables", []), sql_template=step_data.get("sql_hint", None), depends_on=step_data.get("depends_on", []), result_name=step_data.get("result_name", f"result_{i+1}") ) steps.append(step) # Determine parallel groups (steps with no dependencies can run together) parallel_groups = self._compute_parallel_groups(steps) return QueryPlan( original_query=query, is_complex=True, steps=steps, parallel_groups=parallel_groups, final_combination_logic=plan_data.get("combination_logic", "") ) except Exception as e: logger.error(f"Query planning failed: {e}") # Return single-step fallback return QueryPlan( original_query=query, is_complex=False, steps=[], parallel_groups=[], final_combination_logic="" ) def _compute_parallel_groups(self, steps: List[QueryStep]) -> List[List[str]]: """ Compute which steps can be executed in parallel. Steps with no dependencies (or only completed dependencies) can run together. """ if not steps: return [] groups = [] executed = set() remaining = {s.step_id: s for s in steps} while remaining: # Find steps whose dependencies are all satisfied ready = [ step_id for step_id, step in remaining.items() if all(dep in executed for dep in step.depends_on) ] if not ready: # Avoid infinite loop - add remaining as sequential ready = list(remaining.keys())[:1] groups.append(ready) for step_id in ready: executed.add(step_id) del remaining[step_id] return groups def create_simple_plan(self, query: str) -> QueryPlan: """Create a simple single-step plan for non-complex queries.""" return QueryPlan( original_query=query, is_complex=False, steps=[ QueryStep( step_id="step_1", step_type=StepType.DATA_QUERY, description="Execute query directly", tables_needed=[], depends_on=[] ) ], parallel_groups=[["step_1"]] ) # Singleton accessor _query_planner: Optional[QueryPlanner] = None def get_query_planner() -> QueryPlanner: """Get the singleton query planner instance.""" global _query_planner if _query_planner is None: _query_planner = QueryPlanner() return _query_planner