Spaces:
Running
Running
| import json | |
| import ast | |
| from agno.tools import Toolkit | |
| from src.services.googlemap_api_service import GoogleMapAPIService | |
| from src.infra.poi_repository import poi_repo | |
| from src.infra.logger import get_logger | |
# Global search-size knob for the Scout. search_and_offload derives its
# per-run node (candidate POI) budget from the square root of this value.
MAX_SEARCH: int = 1000

# Module logger shared by ScoutToolkit.
logger = get_logger(__name__)
class ScoutToolkit(Toolkit):
    """Agno toolkit that turns an LLM-produced task list into candidate POIs.

    Exposes one tool, ``search_and_offload``. For every task it runs a Google
    Maps text search biased towards the user's start location. It then stores
    the enriched result with ``poi_repo.save`` and returns a reference id,
    which downstream agents use instead of the full payload.
    """

    # Upper bound on candidates requested for any single task.
    HARD_CAP_PER_TASK = 15
    # Every task gets at least this many candidates, even once the budget is spent.
    MIN_LIMIT = 3
    # How much of the raw input to echo back in a parse-error message.
    _RAW_PREVIEW_CHARS = 200

    def __init__(self, google_maps_api_key: str):
        super().__init__(name="scout_toolkit")
        self.gmaps = GoogleMapAPIService(api_key=google_maps_api_key)
        # Only the public tool is registered; the underscore helpers stay private.
        self.register(self.search_and_offload)

    # ------------------------------------------------------------------
    # JSON parsing helpers
    # ------------------------------------------------------------------
    def _extract_first_json_object(self, text: str) -> str:
        """Return the first balanced ``{...}`` object found in *text*.

        Braces inside single- or double-quoted string literals are ignored, so
        values such as ``"desc": "use {x}"`` do not confuse the scan.

        If *text* has no ``{``, the stripped text is returned unchanged so the
        caller can still try to parse it. If the object is never closed
        (truncated output), everything from the first ``{`` onward is returned.
        """
        text = text.strip()
        start_idx = text.find("{")
        if start_idx == -1:
            return text  # nothing to extract; let the parsers try the raw text

        depth = 0
        quote = None  # quote character of the string literal we are inside, if any
        escaped = False  # previous character was a backslash inside a string
        for i in range(start_idx, len(text)):
            char = text[i]
            if quote is not None:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == quote:
                    quote = None
                continue
            if char in ('"', "'"):
                quote = char
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return text[start_idx:i + 1]
        # The loop ended without closing the object: the JSON was truncated.
        return text[start_idx:]

    def _robust_parse_json(self, text: str) -> dict:
        """Parse an LLM-produced JSON object, tolerating common formatting noise.

        Steps:
          1. If the text contains Markdown code fences, keep only the lines
             inside them. Fall back to the full text if nothing was fenced.
          2. Extract the first balanced ``{...}`` object.
          3. Try, in order:
             - strict ``json.loads``;
             - ``ast.literal_eval``, for Python-style literals (single quotes,
               True/False/None);
             - ``json.loads`` after rewriting bare True/False/None words to
               true/false/null, which handles JSON/Python mixtures.

        Returns:
            The parsed object. It is always a ``dict``.

        Raises:
            ValueError: if no strategy yields a dict.
        """
        import re  # local import: only the last-resort rewrite needs it

        if isinstance(text, dict):
            # Defensive: in case the caller hands over an already-decoded object.
            return text

        if "```" in text:
            fenced_lines = []
            in_code = False
            for line in text.split("\n"):
                if "```" in line:
                    in_code = not in_code
                    continue
                if in_code:  # keep only content inside the code block
                    fenced_lines.append(line)
            if fenced_lines:
                text = "\n".join(fenced_lines)

        json_str = self._extract_first_json_object(text)

        py_to_json = {"True": "true", "False": "false", "None": "null"}

        def json_after_literal_rewrite(s):
            # Word boundaries keep words like "Trueblood" intact. A standalone
            # "None" inside a string value is still rewritten, so this stays a
            # last resort.
            return json.loads(
                re.sub(r"\b(True|False|None)\b", lambda m: py_to_json[m.group(1)], s)
            )

        last_error = None
        for parse in (json.loads, ast.literal_eval, json_after_literal_rewrite):
            try:
                parsed = parse(json_str)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
                # json.JSONDecodeError is a ValueError subclass. ast.literal_eval
                # documents all of these exception types for malformed input.
                last_error = e
                continue
            if isinstance(parsed, dict):
                return parsed
            last_error = ValueError(f"Expected a JSON object, got {type(parsed).__name__}")

        raise ValueError(
            f"Failed to parse JSON via all methods. Raw: {text[:self._RAW_PREVIEW_CHARS]}..."
        ) from last_error

    # ------------------------------------------------------------------
    # Location / search helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_lat_lng(d) -> tuple:
        """Return ``(lat, lng)`` from a dict keyed ``lat``/``lng`` or ``latitude``/``longitude``.

        Missing values come back as ``None``. Explicit ``None`` checks (rather
        than ``or``) keep valid zero coordinates such as ``0.0``.
        """
        if not isinstance(d, dict):
            return None, None
        lat = d.get("lat")
        if lat is None:
            lat = d.get("latitude")
        lng = d.get("lng")
        if lng is None:
            lng = d.get("longitude")
        return lat, lng

    def _resolve_start_location(self, start_loc) -> tuple:
        """Turn the user's start location into a search anchor.

        Accepts a dict with coordinates, a dict with only a ``name``, or a
        plain place-name string. Names are looked up with
        ``self.gmaps.text_search``.

        Returns:
            ``(anchor_point, resolved_start)``.
            - ``anchor_point`` is ``{"lat", "lng"}``, or ``None`` when no
              coordinates could be determined.
            - ``resolved_start`` is the normalized ``{"name", "lat", "lng"}``
              entry to write back into ``global_info``, or ``None`` to leave
              it untouched.

        Raises:
            Whatever ``self.gmaps.text_search`` raises when a name lookup fails.
        """
        if isinstance(start_loc, dict):
            lat, lng = self._extract_lat_lng(start_loc)
            if lat is not None and lng is not None:
                anchor_point = {"lat": lat, "lng": lng}
                logger.info(f" ✅ Anchor Point set from input: {anchor_point}")
                return anchor_point, {"name": "User Location", "lat": lat, "lng": lng}
            query_name = start_loc.get("name") or "Unknown Start"
            logger.info(f"🕵️ Scout: Resolving Start Location Dict '{query_name}'...")
        elif isinstance(start_loc, str):
            query_name = start_loc
            logger.debug(f"🕵️ Scout: Resolving Start Location Name '{start_loc}'...")
        else:
            # Missing or unsupported start location: search without an anchor.
            return None, None

        results = self.gmaps.text_search(query=query_name, limit=1)
        if not results:
            logger.warning(f" ⚠️ No match found for start location '{query_name}'")
            return None, None

        top = results[0]
        lat, lng = self._extract_lat_lng(top.get("location") or {})
        if lat is None or lng is None:
            logger.warning(f" ⚠️ Start location '{query_name}' has no coordinates")
            return None, None

        resolved_start = {"name": top.get("name") or query_name, "lat": lat, "lng": lng}
        logger.info(f" ✅ Resolved Start: {resolved_start}")
        return {"lat": lat, "lng": lng}, resolved_start

    @classmethod
    def _allocate_limit(cls, remaining_budget: int, tasks_remaining: int) -> int:
        """Split the remaining node budget evenly over the tasks still to search.

        The result is clamped to ``[MIN_LIMIT, HARD_CAP_PER_TASK]``. It is
        always an ``int``, so it can be passed straight to the search API.
        """
        if remaining_budget <= 0:
            return cls.MIN_LIMIT
        share = remaining_budget // tasks_remaining
        return max(cls.MIN_LIMIT, min(cls.HARD_CAP_PER_TASK, share))

    def _search_candidates(self, query: str, limit: int, anchor_point) -> list:
        """Search Google Maps for *query* near *anchor_point* and normalize the hits.

        A failed search is logged and yields an empty list, so one bad task
        does not abort the whole batch. Places without usable coordinates are
        dropped.
        """
        try:
            places = self.gmaps.text_search(query=query, limit=limit, location=anchor_point)
        except Exception as e:
            logger.warning(f"⚠️ Search failed for {query}: {e}")
            return []

        candidates = []
        for place in places or []:
            lat, lng = self._extract_lat_lng(place.get("location") or {})
            if lat is None or lng is None:
                continue
            candidates.append({
                "poi_id": place.get("place_id") or place.get("id"),
                "name": place.get("name") or (place.get("displayName") or {}).get("text"),
                "lat": lat,
                "lng": lng,
                "rating": place.get("rating"),
                "time_window": None,
            })
        return candidates

    @staticmethod
    def _resolve_task_id(task: dict, index: int) -> str:
        """Return the task's own id (``task_id``, else ``id``) or ``task_<index+1>``.

        Placeholder values such as ``None``, ``""``, ``"none"`` and ``"null"``
        count as missing. A numeric ``0`` is kept as a real id.
        """
        for key in ("task_id", "id"):
            raw_id = task.get(key)
            if raw_id is not None and str(raw_id).strip().lower() not in ("none", "", "null"):
                return str(raw_id)
        return f"task_{index + 1}"

    # ------------------------------------------------------------------
    # Public tool
    # ------------------------------------------------------------------
    def search_and_offload(self, task_list_json: str) -> str:
        """
        Performs a proximity search for POIs based on the provided tasks and global context, then offloads results to the DB.

        CRITICAL: The input JSON **MUST** include the 'global_info' section containing 'start_location' (lat, lng) to ensure searches are performed nearby the user's starting point, not in a random location.

        Args:
            task_list_json (str): A JSON formatted string. The structure must be:
                {
                    "global_info": {
                        "language": str,
                        "plan_type": str,
                        "return_to_start": bool,
                        "start_location": {"lat": number, "lng": number} or {"name": str} or str,
                        "departure_time": str,
                        "deadline": str or null
                    },
                    "tasks": [
                        {
                            "task_id": 1,
                            "category": "MEAL" | "LEISURE" | "ERRAND" | "SHOPPING",
                            "description": "Short description",
                            "location_hint": "Clean Keyword for Google Maps",
                            "priority": "HIGH" | "MEDIUM" | "LOW",
                            "service_duration_min": 30,
                            "time_window": {
                                "earliest_time": "ISO 8601" or null,
                                "latest_time": "ISO 8601" or null
                            }
                        },
                        ...
                    ]
                }

        Returns:
            str: A JSON string.
                - On success: {"status": "SUCCESS", "message_task": <number of tasks>, "scout_ref": <DB ref id>, "note": ...}.
                - If the start location lookup fails: {"status": "ERROR", "scout_ref": null, "note": ...}.
                - If the input cannot be parsed: a plain-text message starting with "❌ Error:" instead.
        """
        try:
            data = self._robust_parse_json(task_list_json)
            # `or` also covers explicit nulls in the payload.
            tasks = data.get("tasks") or []
            global_info = data.get("global_info") or {}
            if not isinstance(tasks, list) or not all(isinstance(t, dict) for t in tasks):
                raise ValueError("'tasks' must be a list of objects")
            if not isinstance(global_info, dict):
                raise ValueError("'global_info' must be an object")
        except Exception as e:
            logger.warning(f"❌ JSON Parse Error: {e}")
            # Tell the agent the format was wrong so it can self-correct and retry.
            return f"❌ Error: Invalid JSON format. Please output RAW JSON only. Details: {e}"

        logger.debug(f"🕵️ Scout: Processing Global Info & {len(tasks)} tasks...")

        # 1. Resolve the start location into an anchor point for proximity search.
        start_loc = global_info.get("start_location")
        logger.info(f"🕵️ Scout: Start Location - {start_loc}")
        try:
            anchor_point, resolved_start = self._resolve_start_location(start_loc)
        except Exception as e:
            logger.warning(f" ❌ Error searching start location: {e}")
            return json.dumps({
                "status": "ERROR",
                "scout_ref": None,
                "note": "Failed to resolve start location., Please check the input and try again."
            })
        if resolved_start is not None:
            global_info["start_location"] = resolved_start

        # 2. Adaptive search: spread an integer node budget over the tasks and
        #    subtract what each task actually consumed.
        remaining_budget = int(MAX_SEARCH ** 0.5)
        logger.info(f"🕵️ Scout: Starting Adaptive Search (Budget: {remaining_budget} nodes)")

        enriched_tasks = []
        total_tasks_count = len(tasks)
        for i, task in enumerate(tasks):
            limit = self._allocate_limit(remaining_budget, total_tasks_count - i)
            query = task.get("location_hint") or task.get("description") or "Unknown Location"
            candidates = self._search_candidates(query, limit, anchor_point)
            remaining_budget -= len(candidates)

            task_id = self._resolve_task_id(task, i)
            enriched_tasks.append({
                "task_id": task_id,
                "priority": task.get("priority", "MEDIUM"),
                "service_duration_min": task.get("service_duration_min", 60),
                "time_window": task.get("time_window"),
                "candidates": candidates,
            })
            logger.info(f" - Task {task_id}: Found {len(candidates)} POIs")

        # 3. Persist the enriched payload and hand back only its reference.
        full_payload = {"global_info": global_info, "tasks": enriched_tasks}
        ref_id = poi_repo.save(full_payload, data_type="scout_result")
        return json.dumps({
            "status": "SUCCESS",
            "message_task": len(enriched_tasks),
            "scout_ref": ref_id,
            "note": "Please pass this scout_ref to the Optimizer immediately."
        })