zye0616 committed
Commit c57c49d · 1 Parent(s): a8d3381

updated object detector

README.md CHANGED
@@ -12,7 +12,7 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
 
 ## Mission-guided detections
 
-1. Send a `POST /process_video` request with fields `video` (file) and `prompt` (mission text).
+1. Send a `POST /process_video` request with fields `video` (file) and `prompt` (mission text). Optionally include `detector` (`owlv2` or `hf_yolov8`) to pick the backend per request; if omitted the server uses its default/`OBJECT_DETECTOR` setting.
 2. The backend feeds the mission text into an OpenAI (`gpt-4o-mini`) reasoning step that scores and ranks every YOLO/COCO class. Place your API key inside `.env` as either `OPENAI_API_KEY=...` or `OpenAI-API: ...`; the server loads it automatically on startup.
 3. The top scored classes become the textual queries for the existing OWLv2 detector so the detections align with the mission.
 4. After object detection finishes, another OpenAI call ingests the detection log plus the first/middle/last frame context and produces a natural-language summary of the mission outcome.
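For orientation, a minimal sketch of what the step-1 request could look like from Python. Only the route and the form-field names (`video`, `prompt`, `detector`) come from the README; the host/port (a local uvicorn default), file names, and mission text are placeholders:

```python
import requests

# Placeholder host/port and inputs; only the route and the form-field names
# ("video", "prompt", "detector") are taken from the README.
with open("flight.mp4", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/process_video",
        files={"video": f},
        data={"prompt": "find all delivery trucks", "detector": "hf_yolov8"},
        timeout=600,  # video processing can take a while
    )
resp.raise_for_status()
with open("annotated.mp4", "wb") as out:
    out.write(resp.content)  # assumes the endpoint streams back the processed video
```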
app.py CHANGED
@@ -2,8 +2,10 @@ import logging
 import os
 import tempfile
 from pathlib import Path
+from typing import Optional
 
 from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse, JSONResponse
 import uvicorn
 
@@ -12,6 +14,14 @@ from inference import run_inference
 logging.basicConfig(level=logging.INFO)
 
 app = FastAPI(title="Video Processing Backend")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+    expose_headers=["x-mission-summary"],
+)
 
 
 def _save_upload_to_tmp(upload: UploadFile) -> str:
@@ -45,6 +55,7 @@ async def process_video(
     background_tasks: BackgroundTasks,
     video: UploadFile = File(...),
     prompt: str = Form(...),
+    detector: Optional[str] = Form(None),
 ):
     if video is None:
         raise HTTPException(status_code=400, detail="Video file is required.")
@@ -63,7 +74,13 @@ async def process_video(
     os.close(fd)
 
     try:
-        output_path, mission_plan, mission_summary = run_inference(input_path, output_path, prompt, max_frames=10)
+        output_path, mission_plan, mission_summary = run_inference(
+            input_path,
+            output_path,
+            prompt,
+            max_frames=10,
+            detector_name=detector,
+        )
     except ValueError as exc:
         logging.exception("Video decoding failed.")
         _safe_delete(input_path)
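The `expose_headers=["x-mission-summary"]` entry implies the mission summary is returned as a response header, and exposing it is what lets browser clients read it cross-origin. A hedged sketch of picking it up from Python; whether and where the server actually sets the header is not visible in this diff:

```python
import requests

# Same placeholder endpoint as the README sketch above; only the header name
# "x-mission-summary" is taken from the CORS configuration.
with open("flight.mp4", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/process_video",
        files={"video": f},
        data={"prompt": "find all delivery trucks"},
        timeout=600,
    )
print(resp.headers.get("x-mission-summary"))  # None if the server did not set it
```

Note: per the CORS spec, a wildcard `Access-Control-Allow-Origin` cannot be combined with credentials, so if credentialed requests are ever needed the origins list will have to be made explicit.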
inference.py CHANGED
@@ -24,14 +24,18 @@ def _build_detection_records(
     scores: Sequence[float],
     labels: Sequence[int],
     queries: Sequence[str],
+    label_names: Optional[Sequence[str]] = None,
 ) -> List[Dict[str, Any]]:
     detections: List[Dict[str, Any]] = []
     for idx, box in enumerate(boxes):
-        label_idx = int(labels[idx]) if idx < len(labels) else -1
-        if 0 <= label_idx < len(queries):
-            label = queries[label_idx]
+        if label_names is not None and idx < len(label_names):
+            label = label_names[idx]
         else:
-            label = f"label_{label_idx}"
+            label_idx = int(labels[idx]) if idx < len(labels) else -1
+            if 0 <= label_idx < len(queries):
+                label = queries[label_idx]
+            else:
+                label = f"label_{label_idx}"
         detections.append(
             {
                 "label": label,
@@ -42,12 +46,18 @@
     return detections
 
 
-def infer_frame(frame: np.ndarray, queries: Sequence[str]) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
-    detector = load_detector()
+def infer_frame(
+    frame: np.ndarray,
+    queries: Sequence[str],
+    detector_name: Optional[str] = None,
+) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
+    detector = load_detector(detector_name)
     text_queries = list(queries) or ["object"]
     try:
         result = detector.predict(frame, text_queries)
-        detections = _build_detection_records(result.boxes, result.scores, result.labels, text_queries)
+        detections = _build_detection_records(
+            result.boxes, result.scores, result.labels, text_queries, result.label_names
+        )
     except Exception:
         logging.exception("Inference failed for queries %s", text_queries)
         raise
@@ -59,6 +69,7 @@ def run_inference(
     output_video_path: str,
     mission_prompt: str,
     max_frames: Optional[int] = None,
+    detector_name: Optional[str] = None,
 ) -> Tuple[str, MissionPlan, str]:
     try:
         frames, fps, width, height = extract_frames(input_video_path)
@@ -76,7 +87,7 @@
         if max_frames is not None and idx >= max_frames:
             break
         logging.debug("Processing frame %d", idx)
-        processed_frame, detections = infer_frame(frame, queries)
+        processed_frame, detections = infer_frame(frame, queries, detector_name=detector_name)
         detection_log.append({"frame_index": idx, "detections": detections})
         processed_frames.append(processed_frame)
 
models/detectors/base.py CHANGED
@@ -1,4 +1,4 @@
-from typing import NamedTuple, Sequence
+from typing import NamedTuple, Optional, Sequence
 
 import numpy as np
 
@@ -7,6 +7,7 @@ class DetectionResult(NamedTuple):
     boxes: np.ndarray
     scores: Sequence[float]
     labels: Sequence[int]
+    label_names: Optional[Sequence[str]] = None
 
 
 class ObjectDetector:
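Because the new field is declared last with a `None` default, the NamedTuple stays backward compatible: detectors that still build `DetectionResult` from three positional values (as Owlv2Detector presumably does) keep working unchanged. A quick standalone illustration:

```python
from typing import NamedTuple, Optional, Sequence

import numpy as np

class DetectionResult(NamedTuple):
    boxes: np.ndarray
    scores: Sequence[float]
    labels: Sequence[int]
    label_names: Optional[Sequence[str]] = None  # new, defaulted field

# Old-style construction without names still works...
legacy = DetectionResult(np.empty((0, 4)), [], [])
assert legacy.label_names is None
# ...while new detectors can attach human-readable class names directly.
named = DetectionResult(np.empty((1, 4)), [0.9], [7], ["truck"])
assert named.label_names == ["truck"]
```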
models/detectors/yolov8.py ADDED
@@ -0,0 +1,69 @@
+import logging
+from typing import List, Sequence
+
+import numpy as np
+import torch
+from huggingface_hub import hf_hub_download
+from ultralytics import YOLO
+
+from models.detectors.base import DetectionResult, ObjectDetector
+
+
+class HuggingFaceYoloV8Detector(ObjectDetector):
+    """YOLOv8 detector whose weights are fetched from the Hugging Face Hub."""
+
+    REPO_ID = "Ultralytics/YOLOv8"
+    WEIGHT_FILE = "yolov8s.pt"
+
+    def __init__(self, score_threshold: float = 0.3) -> None:
+        self.name = "hf_yolov8"
+        self.score_threshold = score_threshold
+        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        logging.info(
+            "Loading Hugging Face YOLOv8 weights %s/%s onto %s",
+            self.REPO_ID,
+            self.WEIGHT_FILE,
+            self.device,
+        )
+        weight_path = hf_hub_download(repo_id=self.REPO_ID, filename=self.WEIGHT_FILE)
+        self.model = YOLO(weight_path)
+        self.model.to(self.device)
+        self.class_names = self.model.names
+
+    def _filter_indices(self, label_names: Sequence[str], queries: Sequence[str]) -> List[int]:
+        if not queries:
+            return list(range(len(label_names)))
+        allowed = {query.lower().strip() for query in queries if query}
+        keep = [idx for idx, name in enumerate(label_names) if name.lower() in allowed]
+        return keep or list(range(len(label_names)))
+
+    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
+        device_arg = 0 if self.device.startswith("cuda") else "cpu"
+        results = self.model.predict(
+            source=frame,
+            device=device_arg,
+            conf=self.score_threshold,
+            verbose=False,
+        )
+        result = results[0]
+        boxes = result.boxes
+        if boxes is None or boxes.xyxy is None:
+            empty = np.empty((0, 4), dtype=np.float32)
+            return DetectionResult(empty, [], [], [])
+
+        xyxy = boxes.xyxy.cpu().numpy()
+        scores = boxes.conf.cpu().numpy().tolist()
+        label_ids = boxes.cls.cpu().numpy().astype(int).tolist()
+        label_names = [self.class_names.get(idx, f"class_{idx}") for idx in label_ids]
+        keep_indices = self._filter_indices(label_names, queries)
+        xyxy = xyxy[keep_indices] if len(xyxy) else xyxy
+        scores = [scores[i] for i in keep_indices]
+        label_ids = [label_ids[i] for i in keep_indices]
+        label_names = [label_names[i] for i in keep_indices]
+        return DetectionResult(
+            boxes=xyxy,
+            scores=scores,
+            labels=label_ids,
+            label_names=label_names,
+        )
+
+
models/model_loader.py CHANGED
@@ -4,11 +4,13 @@ from typing import Callable, Dict, Optional
 
 from models.detectors.base import ObjectDetector
 from models.detectors.owlv2 import Owlv2Detector
+from models.detectors.yolov8 import HuggingFaceYoloV8Detector
 
 DEFAULT_DETECTOR = "owlv2"
 
 _REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
     "owlv2": Owlv2Detector,
+    "hf_yolov8": HuggingFaceYoloV8Detector,
 }
 
 
requirements.txt CHANGED
@@ -8,3 +8,5 @@ accelerate
 pillow
 scipy
 openai
+huggingface-hub
+ultralytics