File size: 9,871 Bytes
4851501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
"""
Multi-Step Query Planner

Detects complex queries that require multiple datasets or operations,
decomposes them into atomic steps, and orchestrates execution.
"""

import json
import logging
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class StepType(Enum):
    """Kinds of atomic step that can appear in a query plan.

    DATA_QUERY   -- simple data retrieval
    AGGREGATION  -- COUNT, SUM, GROUP BY
    COMPARISON   -- comparing results from previous steps
    SPATIAL_JOIN -- joining datasets spatially
    COMBINE      -- merging/combining step results
    """

    DATA_QUERY = "data_query"
    AGGREGATION = "aggregation"
    COMPARISON = "comparison"
    SPATIAL_JOIN = "spatial_join"
    COMBINE = "combine"


@dataclass
class QueryStep:
    """One atomic unit of work inside a query plan.

    ``depends_on`` holds the ``step_id`` values of steps that must run
    before this one; ``result_name`` labels this step's intermediate output.
    """
    step_id: str
    step_type: StepType
    description: str
    tables_needed: List[str]
    sql_template: Optional[str] = None
    depends_on: List[str] = field(default_factory=list)
    result_name: str = ""  # label for the intermediate result

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, flattening ``step_type`` to its string value."""
        return dict(
            step_id=self.step_id,
            step_type=self.step_type.value,
            description=self.description,
            tables_needed=self.tables_needed,
            sql_template=self.sql_template,
            depends_on=self.depends_on,
            result_name=self.result_name,
        )


@dataclass
class QueryPlan:
    """Complete execution plan produced for a query.

    ``parallel_groups`` is an ordered list of batches of ``step_id`` values;
    the steps inside one batch can run in parallel.
    """
    original_query: str
    is_complex: bool
    steps: List[QueryStep] = field(default_factory=list)
    parallel_groups: List[List[str]] = field(default_factory=list)  # batches of step ids runnable together
    final_combination_logic: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the plan (and each of its steps) to plain dicts."""
        serialized_steps = [step.to_dict() for step in self.steps]
        return dict(
            original_query=self.original_query,
            is_complex=self.is_complex,
            steps=serialized_steps,
            parallel_groups=self.parallel_groups,
            final_combination_logic=self.final_combination_logic,
        )


class QueryPlanner:
    """
    Multi-step query planning service.

    Analyzes queries to determine complexity and decomposes
    complex queries into executable atomic steps.

    ``__new__`` always returns the same object, so every ``QueryPlanner()``
    call yields one shared instance.
    """

    _instance = None

    # Keywords that often indicate multi-step queries.
    # Matched only at the start of a word (see ``_mentions``) so that, e.g.,
    # "merge" does not fire inside "emergency" and "ratio" does not fire
    # inside "duration" or "operation".
    COMPLEXITY_INDICATORS = [
        "compare", "comparison", "versus", "vs",
        "more than", "less than", "higher than", "lower than",
        "both", "and also", "as well as",
        "ratio", "percentage", "proportion",
        "correlation", "relationship between",
        "combine", "merge", "together with",
        "relative to", "compared to",
        "difference between", "gap between"
    ]

    # Keywords indicating multiple distinct data types, grouped by domain.
    # Also matched at word starts: plurals ("hospitals", "schools") still
    # count, but e.g. "road" no longer fires inside "broad".
    MULTI_DOMAIN_KEYWORDS = {
        "health": ["hospital", "clinic", "healthcare", "health", "medical"],
        "education": ["school", "university", "education", "college", "student"],
        "infrastructure": ["road", "bridge", "infrastructure", "building"],
        "environment": ["forest", "water", "environment", "park", "protected"],
        "population": ["population", "demographic", "census", "people", "resident"]
    }

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.initialized = False
        return cls._instance

    def __init__(self):
        # Python runs __init__ on every QueryPlanner() call, even when __new__
        # returned the existing instance; the flag makes repeat calls a no-op.
        if self.initialized:
            return
        self.initialized = True

    @staticmethod
    def _mentions(text: str, term: str) -> bool:
        """
        Return True if *term* occurs in *text* starting at a word boundary.

        Only the leading edge is anchored, so simple inflections still match
        ("hospitals" for "hospital", "compared" for "compare") while mid-word
        hits ("emergency" for "merge", "duration" for "ratio") do not.
        """
        return re.search(r"\b" + re.escape(term), text) is not None

    def detect_complexity(self, query: str) -> Dict[str, Any]:
        """
        Analyze a query to determine if it requires multi-step planning.

        A query is complex when it touches at least three data domains, or
        at least two domains together with a comparison/aggregation keyword.

        Returns:
            {
                "is_complex": bool,
                "reason": str,
                "detected_domains": List[str],
                "complexity_indicators": List[str]
            }
        """
        query_lower = query.lower()

        # Indicators are reported in COMPLEXITY_INDICATORS order.
        found_indicators = [
            ind for ind in self.COMPLEXITY_INDICATORS
            if self._mentions(query_lower, ind)
        ]

        # A domain counts once if any of its keywords appears.
        found_domains = [
            domain
            for domain, keywords in self.MULTI_DOMAIN_KEYWORDS.items()
            if any(self._mentions(query_lower, kw) for kw in keywords)
        ]

        # "compare"/"ratio"/"correlation"/"versus"/"vs" are all indicators,
        # so they are covered by the indicator + two-domain rule.
        is_complex = len(found_domains) >= 3 or (
            bool(found_indicators) and len(found_domains) >= 2
        )

        reason = ""
        if is_complex:
            # Both complexity rules require >= 2 domains, so this sentence is always present.
            reason = f"Query involves multiple data domains: {', '.join(found_domains)}"
            if found_indicators:
                reason += f". Contains comparison/aggregation keywords: {', '.join(found_indicators[:3])}"

        return {
            "is_complex": is_complex,
            "reason": reason,
            "detected_domains": found_domains,
            "complexity_indicators": found_indicators
        }

    async def plan_query(
        self,
        query: str,
        available_tables: List[str],
        llm_gateway
    ) -> QueryPlan:
        """
        Create an execution plan for a complex query.

        Uses the LLM gateway to decompose the query into atomic steps.

        Args:
            query: The user's natural-language query.
            available_tables: Table names listed in the planning prompt.
            llm_gateway: Object exposing an awaitable
                ``generate_response(prompt, history)`` that returns text.

        Returns:
            A QueryPlan with ``is_complex=True`` built from the LLM's JSON
            answer. If the LLM call, JSON parsing or plan construction fails,
            the error is logged and an empty plan with ``is_complex=False``
            is returned instead. (The prompt import runs before the ``try``,
            so an ImportError there still propagates.)
        """
        from backend.core.prompts import QUERY_PLANNING_PROMPT

        # Build table context: one "- <table>" line per table.
        table_list = "\n".join(f"- {t}" for t in available_tables)

        prompt = QUERY_PLANNING_PROMPT.format(
            user_query=query,
            available_tables=table_list
        )

        try:
            response = await llm_gateway.generate_response(prompt, [])
            plan_data = json.loads(self._strip_code_fence(response))

            # Step ids are assigned positionally ("step_1", "step_2", ...).
            # NOTE(review): depends_on is copied verbatim from the LLM, so it
            # presumably must use this same numbering — confirm against
            # QUERY_PLANNING_PROMPT.
            steps = [
                self._build_step(index, step_data)
                for index, step_data in enumerate(plan_data.get("steps", []))
            ]

            return QueryPlan(
                original_query=query,
                is_complex=True,
                steps=steps,
                parallel_groups=self._compute_parallel_groups(steps),
                final_combination_logic=plan_data.get("combination_logic", "")
            )

        except Exception as e:
            # Deliberate best-effort boundary: any gateway, JSON or schema
            # error degrades to an empty non-complex plan instead of raising.
            logger.exception("Query planning failed: %s", e)
            return QueryPlan(
                original_query=query,
                is_complex=False,
                steps=[],
                parallel_groups=[],
                final_combination_logic=""
            )

    @staticmethod
    def _strip_code_fence(response: str) -> str:
        """Remove a surrounding Markdown code fence (```json ... ``` or ``` ... ```)."""
        text = response.strip()
        if text[:7].lower() == "```json":
            text = text[7:]
        elif text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    @staticmethod
    def _parse_step_type(raw: Any) -> StepType:
        """Map an LLM-supplied step type to StepType, defaulting to DATA_QUERY if unknown."""
        try:
            return StepType(raw)
        except ValueError:
            # One mislabelled step should not discard an otherwise usable plan.
            logger.warning("Unknown step type %r in query plan; treating as data_query", raw)
            return StepType.DATA_QUERY

    def _build_step(self, index: int, step_data: Dict[str, Any]) -> QueryStep:
        """Convert one LLM step dict (0-based *index*) into a QueryStep."""
        return QueryStep(
            step_id=f"step_{index + 1}",
            step_type=self._parse_step_type(step_data.get("type", "data_query")),
            description=step_data.get("description", ""),
            tables_needed=step_data.get("tables", []),
            sql_template=step_data.get("sql_hint", None),
            depends_on=step_data.get("depends_on", []),
            result_name=step_data.get("result_name", f"result_{index + 1}")
        )

    def _compute_parallel_groups(self, steps: List[QueryStep]) -> List[List[str]]:
        """
        Compute which steps can be executed in parallel.

        Groups are built in rounds: each round collects every remaining step
        whose ``depends_on`` ids all appeared in earlier groups. If no step is
        ready (a dependency cycle, or a dependency on an id not in the plan),
        the first remaining step is emitted alone so the loop always
        progresses.

        Returns:
            Ordered list of groups of step ids; empty when *steps* is empty.
        """
        groups: List[List[str]] = []
        executed = set()
        remaining = {s.step_id: s for s in steps}

        while remaining:
            ready = [
                step_id for step_id, step in remaining.items()
                if all(dep in executed for dep in step.depends_on)
            ]

            if not ready:
                # Avoid an infinite loop: fall back to sequential execution.
                ready = [next(iter(remaining))]

            groups.append(ready)

            for step_id in ready:
                executed.add(step_id)
                del remaining[step_id]

        return groups

    def create_simple_plan(self, query: str) -> QueryPlan:
        """Create a single-step plan ("step_1", DATA_QUERY) for non-complex queries."""
        return QueryPlan(
            original_query=query,
            is_complex=False,
            steps=[
                QueryStep(
                    step_id="step_1",
                    step_type=StepType.DATA_QUERY,
                    description="Execute query directly",
                    tables_needed=[],
                    depends_on=[]
                )
            ],
            parallel_groups=[["step_1"]]
        )


# Module-level cache for the shared planner; filled lazily on first access.
_query_planner: Optional[QueryPlanner] = None


def get_query_planner() -> QueryPlanner:
    """Return the shared QueryPlanner, constructing it on the first call."""
    global _query_planner
    if _query_planner is not None:
        return _query_planner
    _query_planner = QueryPlanner()
    return _query_planner