mission detection with summary
Changed files:
- .gitignore +1 -0
- README.md +8 -0
- app.py +5 -2
- inference.py +45 -29
- mission_planner.py +211 -0
- mission_summarizer.py +89 -0
- models/detectors/base.py +18 -0
- models/detectors/owlv2.py +56 -0
- models/model_loader.py +32 -15
- requirements.txt +1 -0
- utils/openai_client.py +67 -0
.gitignore CHANGED
@@ -4,3 +4,4 @@ __pycache__/
 *.log
 *.tmp
 .DS_Store
+.env
README.md CHANGED
@@ -9,3 +9,11 @@ license: mit
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## Mission-guided detections
+
+1. Send a `POST /process_video` request with fields `video` (file) and `prompt` (mission text).
+2. The backend feeds the mission text into an OpenAI (`gpt-4o-mini`) reasoning step that scores and ranks every YOLO/COCO class. Place your API key in `.env` as either `OPENAI_API_KEY=...` or `OpenAI-API: ...`; the server loads it automatically on startup.
+3. The top-scoring classes become the text queries for the existing OWLv2 detector, so the detections align with the mission.
+4. After object detection finishes, another OpenAI call ingests the detection log plus the first/middle/last frame context and produces a natural-language summary of the mission outcome.
+5. The HTTP response still streams the processed video, and it now embeds the structured mission plan (`x-mission-plan`) and text summary (`x-mission-summary`) in the response headers; see the client sketch after this diff.
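A minimal client sketch for step 5. Assumptions not stated in this commit: the Space serves on localhost:7860, the `requests` package is installed, and a local `drone.mp4` exists.

import json

import requests  # assumed available; not part of this repo's requirements

# Hypothetical endpoint; substitute the real Space URL.
url = "http://localhost:7860/process_video"
with open("drone.mp4", "rb") as video:
    resp = requests.post(
        url,
        files={"video": video},
        data={"prompt": "look for unattended backpacks near benches"},
        timeout=600,
    )
resp.raise_for_status()

# The body is the processed MP4; mission metadata rides in custom headers.
with open("processed.mp4", "wb") as out:
    out.write(resp.content)
plan = json.loads(resp.headers["x-mission-plan"])
print("summary:", resp.headers["x-mission-summary"])
print("queries:", [entry["name"] for entry in plan["classes"]])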
app.py CHANGED
@@ -63,7 +63,7 @@ async def process_video(
     os.close(fd)
 
     try:
-        run_inference(input_path, output_path, prompt, max_frames=10)
+        output_path, mission_plan, mission_summary = run_inference(input_path, output_path, prompt, max_frames=10)
     except ValueError as exc:
         logging.exception("Video decoding failed.")
         _safe_delete(input_path)
@@ -78,11 +78,14 @@ async def process_video(
     _schedule_cleanup(background_tasks, input_path)
     _schedule_cleanup(background_tasks, output_path)
 
-    return FileResponse(
+    response = FileResponse(
         path=output_path,
         media_type="video/mp4",
         filename="processed.mp4",
     )
+    response.headers["x-mission-plan"] = mission_plan.to_json()
+    response.headers["x-mission-summary"] = mission_summary.replace("\n", " ").strip()
+    return response
 
 
 if __name__ == "__main__":
inference.py CHANGED
@@ -1,11 +1,11 @@
 import logging
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import cv2
 import numpy as np
-import …
-
-from …
+from models.model_loader import load_detector
+from mission_planner import MissionPlan, get_mission_plan
+from mission_summarizer import summarize_results
 from utils.video import extract_frames, write_video
@@ -19,51 +19,67 @@ def draw_boxes(frame: np.ndarray, boxes: np.ndarray) -> np.ndarray:
     return output
 
 
-def infer_frame(frame, …):
-    …
-    except Exception:
-        logging.exception("Inference failed for …
-        raise
-    return draw_boxes(frame, …
+def _build_detection_records(
+    boxes: np.ndarray,
+    scores: Sequence[float],
+    labels: Sequence[int],
+    queries: Sequence[str],
+) -> List[Dict[str, Any]]:
+    detections: List[Dict[str, Any]] = []
+    for idx, box in enumerate(boxes):
+        label_idx = int(labels[idx]) if idx < len(labels) else -1
+        if 0 <= label_idx < len(queries):
+            label = queries[label_idx]
+        else:
+            label = f"label_{label_idx}"
+        detections.append(
+            {
+                "label": label,
+                "score": float(scores[idx]) if idx < len(scores) else 0.0,
+                "bbox": [int(coord) for coord in box.tolist()],
+            }
+        )
+    return detections
+
+
+def infer_frame(frame: np.ndarray, queries: Sequence[str]) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
+    detector = load_detector()
+    text_queries = list(queries) or ["object"]
+    try:
+        result = detector.predict(frame, text_queries)
+        detections = _build_detection_records(result.boxes, result.scores, result.labels, text_queries)
+    except Exception:
+        logging.exception("Inference failed for queries %s", text_queries)
+        raise
+    return draw_boxes(frame, result.boxes), detections
 
 
 def run_inference(
     input_video_path: str,
     output_video_path: str,
+    mission_prompt: str,
     max_frames: Optional[int] = None,
-) -> str:
+) -> Tuple[str, MissionPlan, str]:
     try:
         frames, fps, width, height = extract_frames(input_video_path)
     except ValueError as exc:
         logging.exception("Failed to decode video at %s", input_video_path)
         raise
 
+    mission_plan = get_mission_plan(mission_prompt)
+    logging.info("Mission plan: %s", mission_plan.to_json())
+    queries = mission_plan.queries()
+
     processed_frames: List[np.ndarray] = []
+    detection_log: List[Dict[str, Any]] = []
     for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
             break
         logging.debug("Processing frame %d", idx)
-        processed_frame = infer_frame(frame, …
+        processed_frame, detections = infer_frame(frame, queries)
+        detection_log.append({"frame_index": idx, "detections": detections})
         processed_frames.append(processed_frame)
 
     write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
-    return …
+    mission_summary = summarize_results(mission_prompt, mission_plan, detection_log)
+    return output_video_path, mission_plan, mission_summary
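For orientation, this is the shape of one `detection_log` entry that `run_inference` accumulates and later hands to `summarize_results`; the values are illustrative only.

# One entry per processed frame; bbox is [x1, y1, x2, y2] in pixels.
entry = {
    "frame_index": 0,
    "detections": [
        {"label": "person", "score": 0.74, "bbox": [112, 80, 240, 310]},
        {"label": "backpack", "score": 0.41, "bbox": [150, 210, 205, 290]},
    ],
}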
mission_planner.py ADDED
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+import json
+import logging
+from dataclasses import asdict, dataclass
+from typing import Dict, List, Tuple
+
+from utils.openai_client import get_openai_client
+
+
+YOLO_CLASSES: Tuple[str, ...] = (
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic light",
+    "fire hydrant",
+    "stop sign",
+    "parking meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "backpack",
+    "umbrella",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports ball",
+    "kite",
+    "baseball bat",
+    "baseball glove",
+    "skateboard",
+    "surfboard",
+    "tennis racket",
+    "bottle",
+    "wine glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted plant",
+    "bed",
+    "dining table",
+    "toilet",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy bear",
+    "hair drier",
+    "toothbrush",
+)
+
+
+DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
+
+
+@dataclass
+class MissionClass:
+    name: str
+    score: float
+    rationale: str
+
+
+@dataclass
+class MissionPlan:
+    mission: str
+    relevant_classes: List[MissionClass]
+
+    def queries(self) -> List[str]:
+        return [entry.name for entry in self.relevant_classes]
+
+    def to_dict(self) -> dict:
+        return {
+            "mission": self.mission,
+            "classes": [asdict(entry) for entry in self.relevant_classes],
+        }
+
+    def to_json(self) -> str:
+        return json.dumps(self.to_dict())
+
+
+class MissionReasoner:
+    def __init__(
+        self,
+        *,
+        model_name: str = DEFAULT_OPENAI_MODEL,
+        top_k: int = 10,
+    ) -> None:
+        self._model_name = model_name
+        self._top_k = top_k
+
+    def plan(self, mission: str) -> MissionPlan:
+        mission = (mission or "").strip()
+        if not mission:
+            raise ValueError("Mission prompt cannot be empty.")
+        response_payload = self._query_llm(mission)
+        relevant = self._parse_plan(response_payload, fallback_mission=mission)
+        return MissionPlan(mission=response_payload.get("mission", mission), relevant_classes=relevant[: self._top_k])
+
+    def _query_llm(self, mission: str) -> Dict[str, object]:
+        client = get_openai_client()
+        system_prompt = (
+            "You are a mission-planning assistant helping a vision system select which YOLO object classes to detect. "
+            "You must only reference the provided list of YOLO classes."
+        )
+        classes_blob = ", ".join(YOLO_CLASSES)
+        user_prompt = (
+            f"Mission: {mission}\n"
+            f"Available YOLO classes: {classes_blob}\n"
+            f"Return JSON with: mission (string) and classes (array). "
+            f"Each entry needs name, score (0-1 float), rationale. "
+            f"Limit to at most {self._top_k} classes. Only choose names from the list."
+        )
+        completion = client.chat.completions.create(
+            model=self._model_name,
+            temperature=0.2,
+            response_format={"type": "json_object"},
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
+        )
+        content = completion.choices[0].message.content or "{}"
+        try:
+            return json.loads(content)
+        except json.JSONDecodeError:
+            logging.exception("LLM returned non-JSON content: %s", content)
+            return {"mission": mission, "classes": []}
+
+    def _parse_plan(self, payload: Dict[str, object], fallback_mission: str) -> List[MissionClass]:
+        entries = payload.get("classes") or payload.get("relevant_classes") or []
+        mission = payload.get("mission") or fallback_mission
+        parsed: List[MissionClass] = []
+        seen = set()
+        for entry in entries:
+            if not isinstance(entry, dict):
+                continue
+            name = str(entry.get("name") or "").strip()
+            if not name or name not in YOLO_CLASSES or name in seen:
+                continue
+            seen.add(name)
+            score_raw = entry.get("score")
+            try:
+                score = float(score_raw)
+            except (TypeError, ValueError):
+                score = 0.5
+            rationale = str(entry.get("rationale") or f"Track '{name}' for mission '{mission}'.")
+            parsed.append(MissionClass(name=name, score=max(0.0, min(1.0, score)), rationale=rationale))
+
+        if not parsed:
+            logging.warning("LLM returned no usable classes. Falling back to default YOLO list.")
+            parsed = [
+                MissionClass(
+                    name=label,
+                    score=1.0 - (idx * 0.05),
+                    rationale=f"Fallback selection for mission '{mission}'.",
+                )
+                for idx, label in enumerate(YOLO_CLASSES[: self._top_k])
+            ]
+        return parsed
+
+
+_REASONER: MissionReasoner | None = None
+
+
+def get_mission_plan(mission: str) -> MissionPlan:
+    global _REASONER
+    if _REASONER is None:
+        _REASONER = MissionReasoner()
+    return _REASONER.plan(mission)
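A sketch of one planner round-trip, assuming a valid OpenAI key is configured; the mission text and the example output in the comments are illustrative only.

from mission_planner import get_mission_plan

plan = get_mission_plan("patrol the parking lot for abandoned luggage")
print(plan.queries())  # e.g. ["suitcase", "backpack", "handbag", "person", ...]
print(plan.to_json())  # {"mission": "...", "classes": [{"name": ..., "score": ..., "rationale": ...}, ...]}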
mission_summarizer.py ADDED
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import json
+import logging
+from typing import Any, Dict, List
+
+from mission_planner import MissionPlan
+from utils.openai_client import get_openai_client
+
+SUMMARY_MODEL = "gpt-4o-mini"
+
+
+def _trim_detections(detections: List[Dict[str, Any]], max_boxes: int = 5) -> List[Dict[str, Any]]:
+    if len(detections) <= max_boxes:
+        return detections
+    return detections[:max_boxes]
+
+
+def _build_context_snapshot(records: List[Dict[str, Any]]) -> Dict[str, Any]:
+    if not records:
+        return {}
+    first = records[0]
+    middle = records[len(records) // 2]
+    last = records[-1]
+    return {
+        "first_frame": {
+            "frame_index": first["frame_index"],
+            "detections": _trim_detections(first.get("detections", [])),
+        },
+        "middle_frame": {
+            "frame_index": middle["frame_index"],
+            "detections": _trim_detections(middle.get("detections", [])),
+        },
+        "last_frame": {
+            "frame_index": last["frame_index"],
+            "detections": _trim_detections(last.get("detections", [])),
+        },
+    }
+
+
+def summarize_results(
+    mission_prompt: str,
+    mission_plan: MissionPlan,
+    detection_log: List[Dict[str, Any]],
+) -> str:
+    if not detection_log:
+        return "No detections were produced, so no summary is available."
+
+    context_snapshot = _build_context_snapshot(detection_log)
+    payload = {
+        "mission_prompt": mission_prompt,
+        "mission_plan": mission_plan.to_dict(),
+        "global_context": context_snapshot,
+        "detection_log": [
+            {
+                "frame_index": entry["frame_index"],
+                "detections": _trim_detections(entry.get("detections", []), max_boxes=8),
+            }
+            for entry in detection_log
+        ],
+    }
+
+    system_prompt = (
+        "You are a surveillance analyst. Review structured detections aligned to a mission and summarize actionable "
+        "insights, highlighting objects of interest, temporal trends, and any security concerns. "
+        "Base conclusions solely on the provided data; if nothing is detected, explicitly state that."
+    )
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {
+            "role": "user",
+            "content": (
+                "Use this JSON to summarize the mission outcome:\n"
+                f"{json.dumps(payload, ensure_ascii=False)}"
+            ),
+        },
+    ]
+
+    try:
+        client = get_openai_client()
+        completion = client.chat.completions.create(
+            model=SUMMARY_MODEL,
+            temperature=0.2,
+            messages=messages,
+        )
+        return (completion.choices[0].message.content or "").strip()
+    except Exception:
+        logging.exception("Failed to generate mission summary.")
+        return "Mission summary generation failed."
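The context snapshot keeps only the first, middle, and last frame records, each trimmed to five boxes. A quick check with synthetic records (no OpenAI call involved):

from mission_summarizer import _build_context_snapshot

records = [{"frame_index": i, "detections": []} for i in range(10)]
snapshot = _build_context_snapshot(records)
print(sorted(snapshot))                         # ['first_frame', 'last_frame', 'middle_frame']
print(snapshot["middle_frame"]["frame_index"])  # 5, i.e. len(records) // 2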
models/detectors/base.py ADDED
@@ -0,0 +1,18 @@
+from typing import NamedTuple, Sequence
+
+import numpy as np
+
+
+class DetectionResult(NamedTuple):
+    boxes: np.ndarray
+    scores: Sequence[float]
+    labels: Sequence[int]
+
+
+class ObjectDetector:
+    """Detector interface to keep inference agnostic to model details."""
+
+    name: str
+
+    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
+        raise NotImplementedError
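Any backend that returns a `DetectionResult` can slot into the registry. A do-nothing stub (hypothetical, not part of this commit) illustrates the contract:

from typing import Sequence

import numpy as np

from models.detectors.base import DetectionResult, ObjectDetector


class NullDetector(ObjectDetector):
    """Hypothetical stand-in that reports no detections; handy for wiring tests."""

    name = "null"

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        # An empty (0, 4) box array keeps downstream draw_boxes and
        # detection-record building working without special cases.
        return DetectionResult(boxes=np.empty((0, 4)), scores=[], labels=[])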
models/detectors/owlv2.py ADDED
@@ -0,0 +1,56 @@
+import logging
+from typing import Sequence
+
+import numpy as np
+import torch
+from transformers import Owlv2ForObjectDetection, Owlv2Processor
+
+from models.detectors.base import DetectionResult, ObjectDetector
+
+
+class Owlv2Detector(ObjectDetector):
+    MODEL_NAME = "google/owlv2-large-patch14"
+
+    def __init__(self) -> None:
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
+        self.processor = Owlv2Processor.from_pretrained(self.MODEL_NAME)
+        torch_dtype = torch.float16 if self.device.type == "cuda" else torch.float32
+        self.model = Owlv2ForObjectDetection.from_pretrained(
+            self.MODEL_NAME, torch_dtype=torch_dtype
+        )
+        self.model.to(self.device)
+        self.model.eval()
+        self.name = "owlv2"
+
+    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
+        inputs = self.processor(text=queries, images=frame, return_tensors="pt")
+        if hasattr(inputs, "to"):
+            inputs = inputs.to(self.device)
+        else:
+            inputs = {
+                key: value.to(self.device) if hasattr(value, "to") else value
+                for key, value in inputs.items()
+            }
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        processed = self.processor.post_process_object_detection(
+            outputs, threshold=0.3, target_sizes=[frame.shape[:2]]
+        )[0]
+        boxes = processed["boxes"]
+        scores = processed.get("scores", [])
+        labels = processed.get("labels", [])
+        boxes_np = boxes.cpu().numpy() if hasattr(boxes, "cpu") else np.asarray(boxes)
+        if hasattr(scores, "cpu"):
+            scores_seq = scores.cpu().numpy().tolist()
+        elif isinstance(scores, np.ndarray):
+            scores_seq = scores.tolist()
+        else:
+            scores_seq = list(scores)
+        if hasattr(labels, "cpu"):
+            labels_seq = labels.cpu().numpy().tolist()
+        elif isinstance(labels, np.ndarray):
+            labels_seq = labels.tolist()
+        else:
+            labels_seq = list(labels)
+        return DetectionResult(boxes=boxes_np, scores=scores_seq, labels=labels_seq)
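A smoke-test sketch; the blank frame is synthetic, and instantiating the detector downloads the google/owlv2-large-patch14 checkpoint on first run.

import numpy as np

from models.detectors.owlv2 import Owlv2Detector

detector = Owlv2Detector()

# Synthetic black frame; real callers pass decoded (H, W, 3) video frames.
frame = np.zeros((480, 640, 3), dtype=np.uint8)
result = detector.predict(frame, ["person", "backpack"])
print(result.boxes.shape)  # (N, 4) boxes above the 0.3 score threshold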
models/model_loader.py CHANGED
@@ -1,20 +1,37 @@
-import …
-from …
-
-import …
-from …
-
-_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-…
-_MODEL = Owlv2ForObjectDetection.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype)
-_MODEL.to(_DEVICE)
-_MODEL.eval()
-
-
-def …
+import os
+from functools import lru_cache
+from typing import Callable, Dict, Optional
+
+from models.detectors.base import ObjectDetector
+from models.detectors.owlv2 import Owlv2Detector
+
+DEFAULT_DETECTOR = "owlv2"
+
+_REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
+    "owlv2": Owlv2Detector,
+}
+
+
+def _create_detector(name: str) -> ObjectDetector:
+    try:
+        factory = _REGISTRY[name]
+    except KeyError as exc:
+        available = ", ".join(sorted(_REGISTRY))
+        raise ValueError(f"Unknown detector '{name}'. Available: {available}") from exc
+    return factory()
+
+
+@lru_cache(maxsize=None)
+def _get_cached_detector(name: str) -> ObjectDetector:
+    return _create_detector(name)
+
+
+def load_detector(name: Optional[str] = None) -> ObjectDetector:
+    """Return a cached detector instance selected via arg or OBJECT_DETECTOR env."""
+    detector_name = name or os.getenv("OBJECT_DETECTOR", DEFAULT_DETECTOR)
+    return _get_cached_detector(detector_name)
+
+
+# Backwards compatibility for existing callers.
+def load_model():
+    return load_detector()
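Selection works either explicitly or via the environment; both paths hit the same `lru_cache`, so repeated calls reuse one model instance (note this instantiates OWLv2, so weights download on first use):

import os

from models.model_loader import load_detector

os.environ["OBJECT_DETECTOR"] = "owlv2"  # the only registered backend so far
a = load_detector()          # resolved via the env var
b = load_detector("owlv2")   # explicit name
assert a is b                # same cached instance either way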
requirements.txt CHANGED
@@ -7,3 +7,4 @@ python-multipart
 accelerate
 pillow
 scipy
+openai
utils/openai_client.py ADDED
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Dict
+
+from openai import OpenAI
+
+ENV_FILE_NAME = ".env"
+_ENV_KEY_CANDIDATES = (
+    "OPENAI_API_KEY",
+    "OpenAI_API_KEY",
+    "OpenAI-API",
+    "OpenAI_API",
+    "OPENAIKEY",
+)
+
+_OPENAI_CLIENT: OpenAI | None = None
+
+
+def _read_env_file(path: Path) -> Dict[str, str]:
+    entries: Dict[str, str] = {}
+    if not path.exists():
+        return entries
+    for raw_line in path.read_text().splitlines():
+        line = raw_line.strip()
+        if not line or line.startswith("#"):
+            continue
+        if "=" in line:
+            key, value = line.split("=", 1)
+        elif ":" in line:
+            key, value = line.split(":", 1)
+        else:
+            continue
+        entries[key.strip()] = value.strip().strip('"').strip("'")
+    return entries
+
+
+def ensure_openai_api_key() -> str:
+    key = os.getenv("OPENAI_API_KEY")
+    if key:
+        return key
+
+    env_path = Path(__file__).resolve().parent.parent / ENV_FILE_NAME
+    env_entries = _read_env_file(env_path)
+    for candidate in _ENV_KEY_CANDIDATES:
+        if env_entries.get(candidate):
+            key = env_entries[candidate]
+            break
+    else:
+        key = None
+
+    if not key:
+        raise RuntimeError(
+            "OpenAI API key is not configured. Set OPENAI_API_KEY or add it to .env (e.g., 'OpenAI-API: sk-...')."
+        )
+
+    os.environ["OPENAI_API_KEY"] = key
+    return key
+
+
+def get_openai_client() -> OpenAI:
+    global _OPENAI_CLIENT
+    if _OPENAI_CLIENT is None:
+        api_key = ensure_openai_api_key()
+        _OPENAI_CLIENT = OpenAI(api_key=api_key)
+    return _OPENAI_CLIENT
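How the two entry points fit together; the `.env` lines in the comments show the formats the parser accepts (key values are placeholders):

from utils.openai_client import ensure_openai_api_key, get_openai_client

# .env lines the loader accepts (either separator; surrounding quotes stripped):
#   OPENAI_API_KEY=sk-...
#   OpenAI-API: sk-...
key = ensure_openai_api_key()  # env var first, then repo-root .env; raises RuntimeError if neither has a key
client = get_openai_client()   # cached OpenAI client singleton reused by planner and summarizer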