Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import logging | |
| from dataclasses import asdict, dataclass | |
| from typing import Dict, List, Tuple | |
| from prompt import mission_planner_system_prompt, mission_planner_user_prompt | |
| from utils.openai_client import get_openai_client | |
# Closed vocabulary of 80 detector labels. The tuple is handed to the mission
# planner prompt, used to validate class names returned by the LLM, and its
# first `top_k` entries serve as the fallback plan — so the ORDER is significant
# and must not be changed casually.
# NOTE(review): presumably this is the COCO label set in YOLO class-index
# order — confirm against the detector's own label map.
YOLO_CLASSES: Tuple[str, ...] = (
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
)

# Chat model used by MissionReasoner when the caller does not pass `model_name`.
DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
@dataclass
class MissionClass:
    """One detector class selected for a mission.

    Declared as a dataclass so it can be built with keyword arguments and
    serialised with ``dataclasses.asdict``; a plain class with bare
    annotations supports neither.
    """

    # Detector label for this entry.
    name: str
    # Relevance score assigned to the class for the mission.
    score: float
    # Short human-readable justification for including the class.
    rationale: str
@dataclass
class MissionPlan:
    """A mission statement together with the classes chosen for it.

    Declared as a dataclass so that ``MissionPlan(mission=..., relevant_classes=...)``
    works; without the decorator the class has no generated ``__init__`` and
    keyword construction raises ``TypeError``.
    """

    # Mission text the plan was produced for.
    mission: str
    # Selected classes, in the order they should be reported.
    # (Quoted so the annotation never needs MissionClass at class-creation time.)
    relevant_classes: List["MissionClass"]

    def queries(self) -> List[str]:
        """Return the class names of the plan, in plan order."""
        return [entry.name for entry in self.relevant_classes]

    def to_dict(self) -> dict:
        """Return a plain-dict form with keys ``"mission"`` and ``"classes"``.

        Each class entry is converted with ``dataclasses.asdict``.
        """
        return {
            "mission": self.mission,
            "classes": [asdict(entry) for entry in self.relevant_classes],
        }

    def to_json(self) -> str:
        """Return ``to_dict()`` serialised as a JSON string."""
        return json.dumps(self.to_dict())
class MissionReasoner:
    """Ask an OpenAI chat model which YOLO classes matter for a mission.

    The model is prompted with the mission text and the full ``YOLO_CLASSES``
    vocabulary and is expected to answer with a JSON object. The answer is
    validated defensively: anything malformed is dropped, and if nothing
    usable remains a deterministic fallback plan is produced instead.
    """

    def __init__(
        self,
        *,
        model_name: str = DEFAULT_OPENAI_MODEL,
        top_k: int = 10,
    ) -> None:
        """Remember the chat model name and the maximum number of classes per plan."""
        self._model_name = model_name
        self._top_k = top_k

    def plan(self, mission: str) -> MissionPlan:
        """Build a ``MissionPlan`` for ``mission``.

        Args:
            mission: Free-text mission description; surrounding whitespace is
                ignored.

        Returns:
            A plan holding at most ``top_k`` classes.

        Raises:
            ValueError: If ``mission`` is empty or only whitespace.
        """
        mission = (mission or "").strip()
        if not mission:
            raise ValueError("Mission prompt cannot be empty.")
        response_payload = self._query_llm(mission)
        relevant = self._parse_plan(response_payload, fallback_mission=mission)
        return MissionPlan(
            # Resolved the same way as inside _parse_plan so that a null, empty
            # or non-string "mission" in the reply never leaks into the plan.
            mission=self._resolve_mission(response_payload, mission),
            relevant_classes=relevant[: self._top_k],
        )

    def _query_llm(self, mission: str) -> Dict[str, object]:
        """Send the mission to the chat model and return its reply as a dict.

        Always returns a dict: if the reply is not valid JSON, or is valid JSON
        that is not an object, an empty payload for ``mission`` is returned so
        that callers can rely on ``.get``.
        """
        client = get_openai_client()
        system_prompt = mission_planner_system_prompt()
        user_prompt = mission_planner_user_prompt(mission, YOLO_CLASSES, self._top_k)
        completion = client.chat.completions.create(
            model=self._model_name,
            temperature=0.2,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        content = completion.choices[0].message.content or "{}"
        empty_payload: Dict[str, object] = {"mission": mission, "classes": []}
        try:
            payload = json.loads(content)
        except json.JSONDecodeError:
            logging.exception("LLM returned non-JSON content: %s", content)
            return empty_payload
        if not isinstance(payload, dict):
            # e.g. a bare list or string: valid JSON, but unusable downstream.
            logging.warning("LLM returned JSON that is not an object: %s", content)
            return empty_payload
        return payload

    @staticmethod
    def _resolve_mission(payload: Dict[str, object], fallback: str) -> str:
        """Return the payload's mission if it is a non-blank string, else ``fallback``."""
        candidate = payload.get("mission")
        if isinstance(candidate, str) and candidate.strip():
            return candidate.strip()
        return fallback

    @staticmethod
    def _coerce_score(raw: object) -> float:
        """Convert a raw score to a float in ``[0.0, 1.0]``; use 0.5 if unusable."""
        try:
            score = float(raw)  # type: ignore[arg-type]
        except (TypeError, ValueError, OverflowError):
            return 0.5
        if score != score:
            # NaN compares unequal to itself; clamping it would yield 1.0.
            return 0.5
        return max(0.0, min(1.0, score))

    def _parse_plan(self, payload: Dict[str, object], fallback_mission: str) -> List[MissionClass]:
        """Turn the LLM payload into validated ``MissionClass`` entries.

        Entries are read from ``payload["classes"]`` (or ``"relevant_classes"``).
        An entry is kept only if it is a dict whose name, after trimming and
        lower-casing, is a known YOLO class that has not been seen yet. Scores
        are clamped to ``[0.0, 1.0]`` and default to 0.5 when missing or
        invalid. If no entry survives, the first ``top_k`` YOLO classes are
        returned with descending scores.
        """
        entries = payload.get("classes") or payload.get("relevant_classes") or []
        if not isinstance(entries, list):
            # A scalar or object here cannot be iterated as a list of entries.
            entries = []
        mission = self._resolve_mission(payload, fallback_mission)
        parsed: List[MissionClass] = []
        seen = set()
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            # The vocabulary is all lower-case; normalise so "Person" still matches.
            name = str(entry.get("name") or "").strip().lower()
            if not name or name not in YOLO_CLASSES or name in seen:
                continue
            seen.add(name)
            rationale = str(entry.get("rationale") or f"Track '{name}' for mission '{mission}'.")
            parsed.append(
                MissionClass(
                    name=name,
                    score=self._coerce_score(entry.get("score")),
                    rationale=rationale,
                )
            )
        if not parsed:
            logging.warning("LLM returned no usable classes. Falling back to default YOLO list.")
            parsed = [
                MissionClass(
                    name=label,
                    # Floor at 0.0 so a large top_k cannot produce negative scores.
                    score=max(0.0, 1.0 - (idx * 0.05)),
                    rationale=f"Fallback selection for mission '{mission}'.",
                )
                for idx, label in enumerate(YOLO_CLASSES[: self._top_k])
            ]
        return parsed
# Process-wide reasoner instance, created on first use by get_mission_plan().
_REASONER: MissionReasoner | None = None


def get_mission_plan(mission: str) -> MissionPlan:
    """Plan ``mission`` with the shared ``MissionReasoner``, creating it on first call."""
    global _REASONER
    reasoner = _REASONER
    if reasoner is None:
        # First call: build the reasoner with its defaults and cache it.
        reasoner = MissionReasoner()
        _REASONER = reasoner
    return reasoner.plan(mission)