zye0616 committed
Commit a8d3381 · 1 Parent(s): 788bb37

mission detection with summary

.gitignore CHANGED
@@ -4,3 +4,4 @@ __pycache__/
 *.log
 *.tmp
 .DS_Store
+.env
README.md CHANGED
@@ -9,3 +9,11 @@ license: mit
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## Mission-guided detections
+
+1. Send a `POST /process_video` request with the fields `video` (file) and `prompt` (mission text); a request sketch follows this diff.
+2. The backend feeds the mission text into an OpenAI (`gpt-4o-mini`) reasoning step that scores and ranks every YOLO/COCO class. Put your API key in `.env` as either `OPENAI_API_KEY=...` or `OpenAI-API: ...`; the server loads it automatically on startup.
+3. The top-scoring classes become the text queries for the existing OWLv2 detector, so the detections align with the mission.
+4. After object detection finishes, another OpenAI call ingests the detection log plus first/middle/last-frame context and produces a natural-language summary of the mission outcome.
+5. The HTTP response still streams the processed video and now embeds the structured mission plan (`x-mission-plan`) and a text summary (`x-mission-summary`) in the response headers.
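A minimal client sketch for the steps above, assuming the `requests` library and a locally running server; the URL, file name, and mission text are placeholders, not part of this commit:

```python
import requests

url = "http://localhost:7860/process_video"  # adjust for your Space / deployment
with open("sample.mp4", "rb") as video_file:
    response = requests.post(
        url,
        files={"video": ("sample.mp4", video_file, "video/mp4")},
        data={"prompt": "Report any people or vehicles near the loading dock"},
        timeout=600,
    )
response.raise_for_status()

# The body is the processed MP4; the plan and summary arrive in custom headers.
with open("processed.mp4", "wb") as out:
    out.write(response.content)
print(response.headers.get("x-mission-plan"))
print(response.headers.get("x-mission-summary"))
```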
app.py CHANGED
@@ -63,7 +63,7 @@ async def process_video(
     os.close(fd)

     try:
-        run_inference(input_path, output_path, prompt, max_frames=10)
+        output_path, mission_plan, mission_summary = run_inference(input_path, output_path, prompt, max_frames=10)
     except ValueError as exc:
         logging.exception("Video decoding failed.")
         _safe_delete(input_path)
@@ -78,11 +78,14 @@
     _schedule_cleanup(background_tasks, input_path)
     _schedule_cleanup(background_tasks, output_path)

-    return FileResponse(
+    response = FileResponse(
         path=output_path,
         media_type="video/mp4",
         filename="processed.mp4",
     )
+    response.headers["x-mission-plan"] = mission_plan.to_json()
+    response.headers["x-mission-summary"] = mission_summary.replace("\n", " ").strip()
+    return response


 if __name__ == "__main__":
inference.py CHANGED
@@ -1,11 +1,11 @@
 import logging
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Sequence, Tuple

 import cv2
 import numpy as np
-import torch
-
-from models.model_loader import load_model
+from models.model_loader import load_detector
+from mission_planner import MissionPlan, get_mission_plan
+from mission_summarizer import summarize_results
 from utils.video import extract_frames, write_video


@@ -19,51 +19,67 @@ def draw_boxes(frame: np.ndarray, boxes: np.ndarray) -> np.ndarray:
     return output


-def infer_frame(frame: np.ndarray, prompt: str) -> np.ndarray:
-    processor, model, device = load_model()
-    try:
-        inputs = processor(text=[prompt], images=frame, return_tensors="pt")
-        if hasattr(inputs, "to"):
-            inputs = inputs.to(device)
+def _build_detection_records(
+    boxes: np.ndarray,
+    scores: Sequence[float],
+    labels: Sequence[int],
+    queries: Sequence[str],
+) -> List[Dict[str, Any]]:
+    detections: List[Dict[str, Any]] = []
+    for idx, box in enumerate(boxes):
+        label_idx = int(labels[idx]) if idx < len(labels) else -1
+        if 0 <= label_idx < len(queries):
+            label = queries[label_idx]
         else:
-            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
-        with torch.no_grad():
-            outputs = model(**inputs)
-        results = processor.post_process_object_detection(
-            outputs,
-            threshold=0.3,
-            target_sizes=[frame.shape[:2]],
-        )[0]
-        boxes = results["boxes"]
-        if hasattr(boxes, "cpu"):
-            boxes_np = boxes.cpu().numpy()
-        else:
-            boxes_np = np.asarray(boxes)
+            label = f"label_{label_idx}"
+        detections.append(
+            {
+                "label": label,
+                "score": float(scores[idx]) if idx < len(scores) else 0.0,
+                "bbox": [int(coord) for coord in box.tolist()],
+            }
+        )
+    return detections
+
+
+def infer_frame(frame: np.ndarray, queries: Sequence[str]) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
+    detector = load_detector()
+    text_queries = list(queries) or ["object"]
+    try:
+        result = detector.predict(frame, text_queries)
+        detections = _build_detection_records(result.boxes, result.scores, result.labels, text_queries)
     except Exception:
-        logging.exception("Inference failed for prompt '%s'", prompt)
+        logging.exception("Inference failed for queries %s", text_queries)
         raise
-    return draw_boxes(frame, boxes_np)
+    return draw_boxes(frame, result.boxes), detections


 def run_inference(
     input_video_path: str,
     output_video_path: str,
-    prompt: str,
+    mission_prompt: str,
     max_frames: Optional[int] = None,
-) -> str:
+) -> Tuple[str, MissionPlan, str]:
     try:
         frames, fps, width, height = extract_frames(input_video_path)
     except ValueError as exc:
         logging.exception("Failed to decode video at %s", input_video_path)
         raise

+    mission_plan = get_mission_plan(mission_prompt)
+    logging.info("Mission plan: %s", mission_plan.to_json())
+    queries = mission_plan.queries()
+
     processed_frames: List[np.ndarray] = []
+    detection_log: List[Dict[str, Any]] = []
     for idx, frame in enumerate(frames):
         if max_frames is not None and idx >= max_frames:
             break
         logging.debug("Processing frame %d", idx)
-        processed_frame = infer_frame(frame, prompt)
+        processed_frame, detections = infer_frame(frame, queries)
+        detection_log.append({"frame_index": idx, "detections": detections})
         processed_frames.append(processed_frame)

     write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
-    return output_video_path
+    mission_summary = summarize_results(mission_prompt, mission_plan, detection_log)
+    return output_video_path, mission_plan, mission_summary
mission_planner.py ADDED
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+import json
+import logging
+from dataclasses import asdict, dataclass
+from typing import Dict, List, Tuple
+
+from utils.openai_client import get_openai_client
+
+
+YOLO_CLASSES: Tuple[str, ...] = (
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic light",
+    "fire hydrant",
+    "stop sign",
+    "parking meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "backpack",
+    "umbrella",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports ball",
+    "kite",
+    "baseball bat",
+    "baseball glove",
+    "skateboard",
+    "surfboard",
+    "tennis racket",
+    "bottle",
+    "wine glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted plant",
+    "bed",
+    "dining table",
+    "toilet",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy bear",
+    "hair drier",
+    "toothbrush",
+)
+
+
+DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
+
+
+@dataclass
+class MissionClass:
+    name: str
+    score: float
+    rationale: str
+
+
+@dataclass
+class MissionPlan:
+    mission: str
+    relevant_classes: List[MissionClass]
+
+    def queries(self) -> List[str]:
+        return [entry.name for entry in self.relevant_classes]
+
+    def to_dict(self) -> dict:
+        return {
+            "mission": self.mission,
+            "classes": [asdict(entry) for entry in self.relevant_classes],
+        }
+
+    def to_json(self) -> str:
+        return json.dumps(self.to_dict())
+
+
+class MissionReasoner:
+    def __init__(
+        self,
+        *,
+        model_name: str = DEFAULT_OPENAI_MODEL,
+        top_k: int = 10,
+    ) -> None:
+        self._model_name = model_name
+        self._top_k = top_k
+
+    def plan(self, mission: str) -> MissionPlan:
+        mission = (mission or "").strip()
+        if not mission:
+            raise ValueError("Mission prompt cannot be empty.")
+        response_payload = self._query_llm(mission)
+        relevant = self._parse_plan(response_payload, fallback_mission=mission)
+        return MissionPlan(mission=response_payload.get("mission", mission), relevant_classes=relevant[: self._top_k])
+
+    def _query_llm(self, mission: str) -> Dict[str, object]:
+        client = get_openai_client()
+        system_prompt = (
+            "You are a mission-planning assistant helping a vision system select which YOLO object classes to detect. "
+            "You must only reference the provided list of YOLO classes."
+        )
+        classes_blob = ", ".join(YOLO_CLASSES)
+        user_prompt = (
+            f"Mission: {mission}\n"
+            f"Available YOLO classes: {classes_blob}\n"
+            f"Return JSON with: mission (string) and classes (array). "
+            f"Each entry needs name, score (0-1 float), rationale. "
+            f"Limit to at most {self._top_k} classes. Only choose names from the list."
+        )
+        completion = client.chat.completions.create(
+            model=self._model_name,
+            temperature=0.2,
+            response_format={"type": "json_object"},
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
+        )
+        content = completion.choices[0].message.content or "{}"
+        try:
+            return json.loads(content)
+        except json.JSONDecodeError:
+            logging.exception("LLM returned non-JSON content: %s", content)
+            return {"mission": mission, "classes": []}
+
+    def _parse_plan(self, payload: Dict[str, object], fallback_mission: str) -> List[MissionClass]:
+        entries = payload.get("classes") or payload.get("relevant_classes") or []
+        mission = payload.get("mission") or fallback_mission
+        parsed: List[MissionClass] = []
+        seen = set()
+        for entry in entries:
+            if not isinstance(entry, dict):
+                continue
+            name = str(entry.get("name") or "").strip()
+            if not name or name not in YOLO_CLASSES or name in seen:
+                continue
+            seen.add(name)
+            score_raw = entry.get("score")
+            try:
+                score = float(score_raw)
+            except (TypeError, ValueError):
+                score = 0.5
+            rationale = str(entry.get("rationale") or f"Track '{name}' for mission '{mission}'.")
+            parsed.append(MissionClass(name=name, score=max(0.0, min(1.0, score)), rationale=rationale))
+
+        if not parsed:
+            logging.warning("LLM returned no usable classes. Falling back to default YOLO list.")
+            parsed = [
+                MissionClass(
+                    name=label,
+                    score=1.0 - (idx * 0.05),
+                    rationale=f"Fallback selection for mission '{mission}'.",
+                )
+                for idx, label in enumerate(YOLO_CLASSES[: self._top_k])
+            ]
+        return parsed
+
+
+_REASONER: MissionReasoner | None = None
+
+
+def get_mission_plan(mission: str) -> MissionPlan:
+    global _REASONER
+    if _REASONER is None:
+        _REASONER = MissionReasoner()
+    return _REASONER.plan(mission)
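A hedged usage sketch for the planner; it requires a configured OpenAI key, and the mission text and printed output are illustrative only:

```python
from mission_planner import get_mission_plan

plan = get_mission_plan("Spot delivery trucks and anyone carrying a backpack")

# Top-ranked YOLO/COCO class names that become the OWLv2 text queries.
print(plan.queries())   # e.g. ["truck", "person", "backpack"]

# Same structure that app.py embeds in the x-mission-plan response header.
print(plan.to_json())
```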
mission_summarizer.py ADDED
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import json
+import logging
+from typing import Any, Dict, List
+
+from mission_planner import MissionPlan
+from utils.openai_client import get_openai_client
+
+SUMMARY_MODEL = "gpt-4o-mini"
+
+
+def _trim_detections(detections: List[Dict[str, Any]], max_boxes: int = 5) -> List[Dict[str, Any]]:
+    if len(detections) <= max_boxes:
+        return detections
+    return detections[:max_boxes]
+
+
+def _build_context_snapshot(records: List[Dict[str, Any]]) -> Dict[str, Any]:
+    if not records:
+        return {}
+    first = records[0]
+    middle = records[len(records) // 2]
+    last = records[-1]
+    return {
+        "first_frame": {
+            "frame_index": first["frame_index"],
+            "detections": _trim_detections(first.get("detections", [])),
+        },
+        "middle_frame": {
+            "frame_index": middle["frame_index"],
+            "detections": _trim_detections(middle.get("detections", [])),
+        },
+        "last_frame": {
+            "frame_index": last["frame_index"],
+            "detections": _trim_detections(last.get("detections", [])),
+        },
+    }
+
+
+def summarize_results(
+    mission_prompt: str,
+    mission_plan: MissionPlan,
+    detection_log: List[Dict[str, Any]],
+) -> str:
+    if not detection_log:
+        return "No detections were produced, so no summary is available."
+
+    context_snapshot = _build_context_snapshot(detection_log)
+    payload = {
+        "mission_prompt": mission_prompt,
+        "mission_plan": mission_plan.to_dict(),
+        "global_context": context_snapshot,
+        "detection_log": [
+            {
+                "frame_index": entry["frame_index"],
+                "detections": _trim_detections(entry.get("detections", []), max_boxes=8),
+            }
+            for entry in detection_log
+        ],
+    }
+
+    system_prompt = (
+        "You are a surveillance analyst. Review structured detections aligned to a mission and summarize actionable "
+        "insights, highlighting objects of interest, temporal trends, and any security concerns. "
+        "Base conclusions solely on the provided data; if nothing is detected, explicitly state that."
+    )
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {
+            "role": "user",
+            "content": (
+                "Use this JSON to summarize the mission outcome:\n"
+                f"{json.dumps(payload, ensure_ascii=False)}"
+            ),
+        },
+    ]
+
+    try:
+        client = get_openai_client()
+        completion = client.chat.completions.create(
+            model=SUMMARY_MODEL,
+            temperature=0.2,
+            messages=messages,
+        )
+        return (completion.choices[0].message.content or "").strip()
+    except Exception:
+        logging.exception("Failed to generate mission summary.")
+        return "Mission summary generation failed."
models/detectors/base.py ADDED
@@ -0,0 +1,18 @@
+from typing import NamedTuple, Sequence
+
+import numpy as np
+
+
+class DetectionResult(NamedTuple):
+    boxes: np.ndarray
+    scores: Sequence[float]
+    labels: Sequence[int]
+
+
+class ObjectDetector:
+    """Detector interface to keep inference agnostic to model details."""
+
+    name: str
+
+    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
+        raise NotImplementedError
models/detectors/owlv2.py ADDED
@@ -0,0 +1,56 @@
+import logging
+from typing import Sequence
+
+import numpy as np
+import torch
+from transformers import Owlv2ForObjectDetection, Owlv2Processor
+
+from models.detectors.base import DetectionResult, ObjectDetector
+
+
+class Owlv2Detector(ObjectDetector):
+    MODEL_NAME = "google/owlv2-large-patch14"
+
+    def __init__(self) -> None:
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
+        self.processor = Owlv2Processor.from_pretrained(self.MODEL_NAME)
+        torch_dtype = torch.float16 if self.device.type == "cuda" else torch.float32
+        self.model = Owlv2ForObjectDetection.from_pretrained(
+            self.MODEL_NAME, torch_dtype=torch_dtype
+        )
+        self.model.to(self.device)
+        self.model.eval()
+        self.name = "owlv2"
+
+    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
+        inputs = self.processor(text=queries, images=frame, return_tensors="pt")
+        if hasattr(inputs, "to"):
+            inputs = inputs.to(self.device)
+        else:
+            inputs = {
+                key: value.to(self.device) if hasattr(value, "to") else value
+                for key, value in inputs.items()
+            }
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        processed = self.processor.post_process_object_detection(
+            outputs, threshold=0.3, target_sizes=[frame.shape[:2]]
+        )[0]
+        boxes = processed["boxes"]
+        scores = processed.get("scores", [])
+        labels = processed.get("labels", [])
+        boxes_np = boxes.cpu().numpy() if hasattr(boxes, "cpu") else np.asarray(boxes)
+        if hasattr(scores, "cpu"):
+            scores_seq = scores.cpu().numpy().tolist()
+        elif isinstance(scores, np.ndarray):
+            scores_seq = scores.tolist()
+        else:
+            scores_seq = list(scores)
+        if hasattr(labels, "cpu"):
+            labels_seq = labels.cpu().numpy().tolist()
+        elif isinstance(labels, np.ndarray):
+            labels_seq = labels.tolist()
+        else:
+            labels_seq = list(labels)
+        return DetectionResult(boxes=boxes_np, scores=scores_seq, labels=labels_seq)
models/model_loader.py CHANGED
@@ -1,20 +1,37 @@
-import logging
-from typing import Tuple
-
-import torch
-from transformers import Owlv2ForObjectDetection, Owlv2Processor
-
-MODEL_NAME = "google/owlv2-large-patch14"
-_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-logging.info("Loading %s onto %s", MODEL_NAME, _DEVICE)
-_PROCESSOR = Owlv2Processor.from_pretrained(MODEL_NAME)
-torch_dtype = torch.float16 if _DEVICE.type == "cuda" else torch.float32
-_MODEL = Owlv2ForObjectDetection.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype)
-_MODEL.to(_DEVICE)
-_MODEL.eval()
-
-
-def load_model() -> Tuple[Owlv2Processor, Owlv2ForObjectDetection, torch.device]:
-    """Expose processor/model singletons so the API never reloads weights."""
-    return _PROCESSOR, _MODEL, _DEVICE
+import os
+from functools import lru_cache
+from typing import Callable, Dict, Optional
+
+from models.detectors.base import ObjectDetector
+from models.detectors.owlv2 import Owlv2Detector
+
+DEFAULT_DETECTOR = "owlv2"
+
+_REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
+    "owlv2": Owlv2Detector,
+}
+
+
+def _create_detector(name: str) -> ObjectDetector:
+    try:
+        factory = _REGISTRY[name]
+    except KeyError as exc:
+        available = ", ".join(sorted(_REGISTRY))
+        raise ValueError(f"Unknown detector '{name}'. Available: {available}") from exc
+    return factory()
+
+
+@lru_cache(maxsize=None)
+def _get_cached_detector(name: str) -> ObjectDetector:
+    return _create_detector(name)
+
+
+def load_detector(name: Optional[str] = None) -> ObjectDetector:
+    """Return a cached detector instance selected via arg or OBJECT_DETECTOR env."""
+    detector_name = name or os.getenv("OBJECT_DETECTOR", DEFAULT_DETECTOR)
+    return _get_cached_detector(detector_name)
+
+
+# Backwards compatibility for existing callers.
+def load_model():
+    return load_detector()
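A hedged sketch of what plugging a second backend into this registry could look like; `DummyDetector` is hypothetical and not part of this commit, and a real addition would register its factory inside `model_loader.py` rather than patching the private `_REGISTRY` at runtime:

```python
from typing import Sequence

import numpy as np

from models import model_loader
from models.detectors.base import DetectionResult, ObjectDetector


class DummyDetector(ObjectDetector):
    """Hypothetical backend used only to illustrate the ObjectDetector contract."""

    name = "dummy"

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        # A real detector would run a model here; this one returns no boxes.
        return DetectionResult(boxes=np.empty((0, 4)), scores=[], labels=[])


# Register the factory and select it by name (or export OBJECT_DETECTOR=dummy).
model_loader._REGISTRY["dummy"] = DummyDetector
detector = model_loader.load_detector("dummy")
print(detector.name)  # -> "dummy"
```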
requirements.txt CHANGED
@@ -7,3 +7,4 @@ python-multipart
 accelerate
 pillow
 scipy
+openai
utils/openai_client.py ADDED
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Dict
+
+from openai import OpenAI
+
+ENV_FILE_NAME = ".env"
+_ENV_KEY_CANDIDATES = (
+    "OPENAI_API_KEY",
+    "OpenAI_API_KEY",
+    "OpenAI-API",
+    "OpenAI_API",
+    "OPENAIKEY",
+)
+
+_OPENAI_CLIENT: OpenAI | None = None
+
+
+def _read_env_file(path: Path) -> Dict[str, str]:
+    entries: Dict[str, str] = {}
+    if not path.exists():
+        return entries
+    for raw_line in path.read_text().splitlines():
+        line = raw_line.strip()
+        if not line or line.startswith("#"):
+            continue
+        if "=" in line:
+            key, value = line.split("=", 1)
+        elif ":" in line:
+            key, value = line.split(":", 1)
+        else:
+            continue
+        entries[key.strip()] = value.strip().strip('"').strip("'")
+    return entries
+
+
+def ensure_openai_api_key() -> str:
+    key = os.getenv("OPENAI_API_KEY")
+    if key:
+        return key
+
+    env_path = Path(__file__).resolve().parent.parent / ENV_FILE_NAME
+    env_entries = _read_env_file(env_path)
+    for candidate in _ENV_KEY_CANDIDATES:
+        if env_entries.get(candidate):
+            key = env_entries[candidate]
+            break
+    else:
+        key = None
+
+    if not key:
+        raise RuntimeError(
+            "OpenAI API key is not configured. Set OPENAI_API_KEY or add it to .env (e.g., 'OpenAI-API: sk-...')."
+        )
+
+    os.environ["OPENAI_API_KEY"] = key
+    return key
+
+
+def get_openai_client() -> OpenAI:
+    global _OPENAI_CLIENT
+    if _OPENAI_CLIENT is None:
+        api_key = ensure_openai_api_key()
+        _OPENAI_CLIENT = OpenAI(api_key=api_key)
+    return _OPENAI_CLIENT
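A small, hypothetical smoke test for the helper above (not part of the commit); it assumes a key is available either in the environment or in the repo-root `.env`, in one of the accepted forms such as `OPENAI_API_KEY=sk-...` or `OpenAI-API: sk-...`:

```python
from utils.openai_client import ensure_openai_api_key, get_openai_client

# Environment takes precedence; otherwise the repo-root .env is parsed and the
# key is exported as OPENAI_API_KEY. A RuntimeError is raised when nothing is found.
key = ensure_openai_api_key()

# The OpenAI client is built once and cached; mission_planner and mission_summarizer reuse it.
client = get_openai_client()
assert get_openai_client() is client
print("key loaded:", key[:4] + "...")
```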