|
|
"""
Multi-Step Query Planner

Detects complex queries that require multiple datasets or operations,
decomposes them into atomic steps, and orchestrates execution.
"""
|
|
|
|
|
import json
import logging
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Any, Optional
|
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class StepType(Enum):
    """Kinds of atomic step that can appear in a query plan.

    Each member's value is the lowercase string used when a step is
    serialised (``QueryStep.to_dict``) and when a step type is parsed
    from planner output via ``StepType(<string>)``.
    """

    DATA_QUERY = "data_query"
    AGGREGATION = "aggregation"
    COMPARISON = "comparison"
    SPATIAL_JOIN = "spatial_join"
    COMBINE = "combine"
|
|
|
|
|
|
|
|
@dataclass
class QueryStep:
    """One atomic unit of work inside a query plan."""

    # Identifier of this step; other steps refer to it in ``depends_on``.
    step_id: str
    # Kind of operation the step performs.
    step_type: StepType
    # Free-text description of the step.
    description: str
    # Tables the step needs.
    tables_needed: List[str]
    # Optional SQL hint/template for the step; None when not supplied.
    sql_template: Optional[str] = None
    # step_ids that must have run before this step is considered ready.
    depends_on: List[str] = field(default_factory=list)
    # Name under which the step's result is referred to; empty by default.
    result_name: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the step.

        ``step_type`` is emitted as the enum's string value; the list
        fields are returned as-is (not copied).
        """
        payload: Dict[str, Any] = dict(
            step_id=self.step_id,
            step_type=self.step_type.value,
            description=self.description,
            tables_needed=self.tables_needed,
            sql_template=self.sql_template,
            depends_on=self.depends_on,
            result_name=self.result_name,
        )
        return payload
|
|
|
|
|
|
|
|
@dataclass
class QueryPlan:
    """Full execution plan produced for a user query."""

    # The user's query text, unchanged.
    original_query: str
    # True when the query was planned as a multi-step (complex) query.
    is_complex: bool
    # Steps of the plan, in plan order.
    steps: List[QueryStep] = field(default_factory=list)
    # Groups of step_ids; the steps within one group can run together.
    parallel_groups: List[List[str]] = field(default_factory=list)
    # Text describing how the step results are to be combined.
    final_combination_logic: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the plan with every step serialised."""
        serialised_steps = [step.to_dict() for step in self.steps]
        payload: Dict[str, Any] = dict(
            original_query=self.original_query,
            is_complex=self.is_complex,
            steps=serialised_steps,
            parallel_groups=self.parallel_groups,
            final_combination_logic=self.final_combination_logic,
        )
        return payload
|
|
|
|
|
|
|
|
class QueryPlanner:
    """
    Multi-step query planning service.

    Analyzes queries to determine complexity and decomposes
    complex queries into executable atomic steps.

    The class is a singleton: every ``QueryPlanner()`` call returns the
    same instance (see ``__new__``).
    """

    _instance = None

    # Words/phrases hinting that a query compares or combines datasets.
    # They are matched against the lower-cased query and must begin at the
    # start of a word, so "merge" does not fire inside "emergency".
    COMPLEXITY_INDICATORS = [
        "compare", "comparison", "versus", "vs",
        "more than", "less than", "higher than", "lower than",
        "both", "and also", "as well as",
        "ratio", "percentage", "proportion",
        "correlation", "relationship between",
        "combine", "merge", "together with",
        "relative to", "compared to",
        "difference between", "gap between",
    ]

    # Domain name -> keywords whose presence marks the query as touching
    # that domain. Matched as plain substrings of the lower-cased query,
    # so plurals and compound words also count.
    MULTI_DOMAIN_KEYWORDS = {
        "health": ["hospital", "clinic", "healthcare", "health", "medical"],
        "education": ["school", "university", "education", "college", "student"],
        "infrastructure": ["road", "bridge", "infrastructure", "building"],
        "environment": ["forest", "water", "environment", "park", "protected"],
        "population": ["population", "demographic", "census", "people", "resident"],
    }

    def __new__(cls):
        """Create the shared instance on first call; return it afterwards."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Flag read by __init__ so one-time setup is not repeated.
            cls._instance.initialized = False
        return cls._instance

    def __init__(self):
        """Mark the shared instance as initialised (no-op on later calls)."""
        if self.initialized:
            return
        self.initialized = True

    @staticmethod
    def _mentions_term(text: str, term: str) -> bool:
        """Return True if *term* occurs in *text* beginning at a word start.

        The character before the match must not be a word character;
        anything may follow, so inflected forms ("compares", "ratios")
        still match while mid-word hits ("operation" for "ratio") do not.
        """
        return re.search(r"(?<!\w)" + re.escape(term), text) is not None

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove a surrounding Markdown code fence from *text*.

        Strips whitespace, a leading "```json" or "```", and a trailing
        "```", then strips whitespace again.
        """
        cleaned = text.strip()
        if cleaned.startswith("```json"):
            cleaned = cleaned[7:]
        if cleaned.startswith("```"):
            cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        return cleaned.strip()

    def detect_complexity(self, query: str) -> Dict[str, Any]:
        """
        Analyze a query to determine if it requires multi-step planning.

        A query is complex when it touches at least two data domains and
        contains a complexity indicator, or when it touches three or more
        domains regardless of indicators.

        Returns:
            {
                "is_complex": bool,
                "reason": str,
                "detected_domains": List[str],
                "complexity_indicators": List[str]
            }
        """
        query_lower = query.lower()

        found_indicators = [
            ind for ind in self.COMPLEXITY_INDICATORS
            if self._mentions_term(query_lower, ind)
        ]

        found_domains = [
            domain
            for domain, keywords in self.MULTI_DOMAIN_KEYWORDS.items()
            if any(kw in query_lower for kw in keywords)
        ]

        domain_count = len(found_domains)
        is_complex = (domain_count >= 2 and bool(found_indicators)) or domain_count >= 3

        reason = ""
        if is_complex:
            # Every complex query has at least two domains, so they are
            # always listed; indicators are appended only when present.
            reason = f"Query involves multiple data domains: {', '.join(found_domains)}"
            if found_indicators:
                shown = ", ".join(found_indicators[:3])
                reason += f". Contains comparison/aggregation keywords: {shown}"

        return {
            "is_complex": is_complex,
            "reason": reason,
            "detected_domains": found_domains,
            "complexity_indicators": found_indicators,
        }

    async def plan_query(
        self,
        query: str,
        available_tables: List[str],
        llm_gateway
    ) -> "QueryPlan":
        """
        Create an execution plan for a complex query.

        Uses LLM to decompose the query into atomic steps.

        Args:
            query: The user's query text.
            available_tables: Table names listed in the planning prompt.
            llm_gateway: Object whose awaitable
                ``generate_response(prompt, [])`` returns the plan text.

        Returns:
            A ``QueryPlan`` with ``is_complex=True`` built from the JSON
            the LLM returned. If the LLM call or parsing of its response
            fails, the error is logged and an empty plan with
            ``is_complex=False`` is returned instead.
        """
        # Imported here rather than at module level; a failure of this
        # import is not caught by the handler below.
        from backend.core.prompts import QUERY_PLANNING_PROMPT

        table_list = "\n".join(f"- {t}" for t in available_tables)
        prompt = QUERY_PLANNING_PROMPT.format(
            user_query=query,
            available_tables=table_list
        )

        try:
            response = await llm_gateway.generate_response(prompt, [])
            plan_data = json.loads(self._strip_code_fences(response))

            steps = []
            # ``or []`` treats an explicit JSON null like a missing key so
            # one null field does not discard the whole plan.
            for i, step_data in enumerate(plan_data.get("steps") or [], start=1):
                steps.append(
                    QueryStep(
                        step_id=f"step_{i}",
                        step_type=StepType(step_data.get("type", "data_query")),
                        description=step_data.get("description", ""),
                        tables_needed=step_data.get("tables") or [],
                        sql_template=step_data.get("sql_hint", None),
                        depends_on=step_data.get("depends_on") or [],
                        result_name=step_data.get("result_name", f"result_{i}"),
                    )
                )

            return QueryPlan(
                original_query=query,
                is_complex=True,
                steps=steps,
                parallel_groups=self._compute_parallel_groups(steps),
                final_combination_logic=plan_data.get("combination_logic", "")
            )

        except Exception as e:
            # Best-effort: any failure degrades to an empty, non-complex
            # plan. logger.exception keeps the traceback in the log.
            logger.exception("Query planning failed: %s", e)
            return QueryPlan(
                original_query=query,
                is_complex=False,
                steps=[],
                parallel_groups=[],
                final_combination_logic=""
            )

    def _compute_parallel_groups(self, steps: List["QueryStep"]) -> List[List[str]]:
        """
        Compute which steps can be executed in parallel.

        Steps are grouped in rounds: a step joins the current round once
        all of its ``depends_on`` ids belong to earlier rounds. If no
        remaining step is ready (for example a dependency cycle or a
        dependency on an unknown id), the first remaining step in plan
        order is scheduled alone so the loop always makes progress.

        Returns:
            A list of rounds, each a list of step_ids; empty for no steps.
        """
        groups: List[List[str]] = []
        done = set()
        pending = {s.step_id: s for s in steps}

        while pending:
            ready = [
                step_id for step_id, step in pending.items()
                if all(dep in done for dep in step.depends_on)
            ]
            if not ready:
                # Nothing is runnable: force the first pending step.
                ready = [next(iter(pending))]

            groups.append(ready)
            done.update(ready)
            for step_id in ready:
                del pending[step_id]

        return groups

    def create_simple_plan(self, query: str) -> "QueryPlan":
        """Create a simple single-step plan for non-complex queries."""
        direct_step = QueryStep(
            step_id="step_1",
            step_type=StepType.DATA_QUERY,
            description="Execute query directly",
            tables_needed=[],
            depends_on=[]
        )
        return QueryPlan(
            original_query=query,
            is_complex=False,
            steps=[direct_step],
            parallel_groups=[["step_1"]]
        )
|
|
|
|
|
|
|
|
|
|
|
# Planner handed out by get_query_planner(); created on first request.
_query_planner: Optional[QueryPlanner] = None


def get_query_planner() -> QueryPlanner:
    """Return the shared QueryPlanner, creating it on the first call."""
    global _query_planner
    if _query_planner is not None:
        return _query_planner
    _query_planner = QueryPlanner()
    return _query_planner
|
|
|