Spaces:
Running
Running
| """ | |
| 圖構建器 - 構建 OR-Tools 需要的圖結構 | |
| 完全保留原始 tsptw_solver_old.py 的圖構建邏輯 | |
| """ | |
| from typing import List, Dict, Tuple, Any | |
| import numpy as np | |
| from src.infra.logger import get_logger | |
| from src.services.local_route_estimator import LocalRouteEstimator | |
| from ..models.internal_models import _Task, _Location, _Graph | |
# Module-level logger, obtained from the project's logging helper (src.infra.logger.get_logger).
logger = get_logger(__name__)
class GraphBuilder:
    """Build the graph structure consumed by the OR-Tools solver.

    Responsibilities:
    - Collect every location (the depot plus each POI candidate).
    - Build per-node metadata (type, task_id, poi_id, time window, ...).
    - Compute the duration / distance matrices between all nodes.

    The logic mirrors the graph-building code of the original
    ``tsptw_solver_old.py``.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored (presumably for
        # call-site compatibility — TODO confirm whether any caller passes them).
        self.estimator = LocalRouteEstimator()

    def build_graph(
        self,
        start_location: _Location,
        tasks: List[_Task],
        travel_mode: str = "DRIVE",
    ) -> _Graph:
        """Build the complete graph for a depot and a list of tasks.

        Node 0 is always the depot. Every following node is one time-window
        interval of one candidate POI of one task (see
        ``_build_locations_and_meta``).

        Args:
            start_location: Depot location; its ``lat`` / ``lng`` are read.
            tasks: Tasks whose ``candidates`` become POI nodes.
            travel_mode: Currently unused. It is not forwarded to the route
                estimator.

        Returns:
            A ``_Graph`` holding ``node_meta``, ``locations``,
            ``duration_matrix`` and ``distance_matrix``.

        Per-node service times are not part of the returned graph. Build them
        separately with ``build_service_time_per_node(tasks, graph.node_meta)``.
        """
        # 1. Locations and per-node metadata (index-aligned; node 0 = depot).
        locations, node_meta = self._build_locations_and_meta(start_location, tasks)

        num_nodes = len(locations)
        logger.info(f"GraphBuilder: {num_nodes} nodes (1 depot + {num_nodes - 1} POIs)")

        # 2. Duration / distance matrices.
        if num_nodes <= 1:
            # Depot only (no POI nodes): skip the estimator call and use
            # trivial 1x1 zero matrices.
            duration_matrix = np.zeros((1, 1), dtype=int)
            distance_matrix = np.zeros((1, 1), dtype=int)
        else:
            duration_matrix, distance_matrix = self._calculate_matrices(locations)

        # 3. Assemble the graph.
        return _Graph(
            node_meta=node_meta,
            locations=locations,
            duration_matrix=duration_matrix,
            distance_matrix=distance_matrix,
        )

    # staticmethod: takes no ``self`` but is invoked through the instance in
    # build_graph; without the decorator that call raises TypeError.
    @staticmethod
    def _build_locations_and_meta(
        start_location: _Location,
        tasks: List[_Task],
    ) -> Tuple[List[Dict[str, float]], List[Dict[str, Any]]]:
        """Build the location list and the index-aligned per-node metadata.

        The depot is node 0. Each candidate of each task then contributes one
        node per time-window interval. The intervals are chosen in order:

        - ``cand.time_windows`` when it is not None;
        - otherwise ``[cand.time_window]`` when that is not None;
        - otherwise ``[None]``, which yields one node with no window.

        An empty ``time_windows`` list therefore produces no node for that
        candidate.

        Returns:
            locations: ``[{"lat": ..., "lng": ...}, ...]``, one entry per node.
            node_meta: ``[{"type": "depot"}, {"type": "poi", ...}, ...]``.
                POI entries carry ``task_id``, ``poi_id``, ``task_idx``,
                ``candidate_idx``, ``interval_idx`` and ``poi_time_window``.
        """
        locations: List[Dict[str, float]] = [
            {"lat": start_location.lat, "lng": start_location.lng}
        ]
        node_meta: List[Dict[str, Any]] = [{"type": "depot"}]

        for task_idx, task in enumerate(tasks):
            for cand_idx, cand in enumerate(task.candidates):
                lat, lng = cand.lat, cand.lng

                time_windows = cand.time_windows
                if time_windows is None:
                    # Fall back to the single-window field; [None] keeps one
                    # (unconstrained) node for a candidate without any window.
                    time_windows = (
                        [cand.time_window] if cand.time_window is not None else [None]
                    )

                # Each interval becomes its own node at the same coordinates.
                for interval_idx, tw in enumerate(time_windows):
                    locations.append({"lat": lat, "lng": lng})
                    node_meta.append(
                        {
                            "type": "poi",
                            "task_id": task.task_id,
                            "poi_id": cand.poi_id,
                            "task_idx": task_idx,
                            "candidate_idx": cand_idx,
                            "interval_idx": interval_idx,
                            "poi_time_window": tw,
                        }
                    )

        return locations, node_meta

    def _calculate_matrices(
        self,
        locations: List[Dict[str, float]],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute the all-pairs duration and distance matrices.

        Every location is passed both as an origin and as a destination to
        ``self.estimator.compute_route_matrix``. The result is presumably
        N x N for N locations — TODO confirm against LocalRouteEstimator.

        Returns:
            ``(duration_matrix, distance_matrix)``: numpy arrays built from the
            estimator result's ``"duration_matrix"`` and ``"distance_matrix"``
            entries. The dtype is whatever ``np.array`` infers from them.
        """
        # Pass plain {"lat", "lng"} copies so the estimator never sees or
        # mutates the caller's dicts.
        locations_dict = [{"lat": loc["lat"], "lng": loc["lng"]} for loc in locations]
        compute_route_result = self.estimator.compute_route_matrix(
            origins=locations_dict,
            destinations=locations_dict,
        )
        duration_matrix = np.array(compute_route_result["duration_matrix"])
        distance_matrix = np.array(compute_route_result["distance_matrix"])
        return duration_matrix, distance_matrix

    # staticmethod: takes no ``self``; the decorator lets it be called both as
    # GraphBuilder.build_service_time_per_node(...) and through an instance.
    @staticmethod
    def build_service_time_per_node(
        tasks: List[_Task],
        node_meta: List[Dict[str, Any]],
    ) -> List[int]:
        """Return the service time in seconds of every node, aligned with node_meta.

        A ``"poi"`` node gets ``service_duration_sec`` of the task referenced by
        its ``task_idx``. Every other node (e.g. the depot) gets 0.

        Returns:
            ``[0, service_sec, service_sec, ...]``
        """
        return [
            tasks[meta["task_idx"]].service_duration_sec if meta["type"] == "poi" else 0
            for meta in node_meta
        ]