File size: 3,824 Bytes
2ab6e0a
a8d3381
2ab6e0a
 
 
a8d3381
 
 
2ab6e0a
 
 
 
 
 
 
 
 
 
 
 
 
a8d3381
 
 
 
 
c57c49d
a8d3381
 
 
c57c49d
 
2ab6e0a
c57c49d
 
 
 
 
a8d3381
 
 
 
 
 
 
 
 
 
c57c49d
 
 
 
 
 
a8d3381
 
 
c57c49d
 
 
2ab6e0a
a8d3381
2ab6e0a
a8d3381
2ab6e0a
 
 
 
51780f2
a8d3381
2ab6e0a
c57c49d
51780f2
 
 
2ab6e0a
 
 
 
 
 
a8d3381
 
 
 
2ab6e0a
a8d3381
2ab6e0a
 
 
 
c57c49d
a8d3381
2ab6e0a
 
51780f2
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple

import cv2
import numpy as np
from models.model_loader import load_detector
from mission_planner import MissionPlan, get_mission_plan
from mission_summarizer import summarize_results
from utils.video import extract_frames, write_video


def draw_boxes(frame: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Return a copy of ``frame`` with every box outlined.

    Args:
        frame: Image to annotate. It is copied first, so the caller's array is
            never modified.
        boxes: Iterable of ``(x1, y1, x2, y2)`` coordinate groups, or ``None``
            when there is nothing to draw. Each coordinate is truncated to
            ``int`` before drawing.

    Returns:
        The annotated copy of ``frame``. When ``boxes`` is ``None`` or empty it
        is a plain, unmodified copy.
    """
    canvas = frame.copy()
    if boxes is not None:
        for corners in boxes:
            left, top, right, bottom = map(int, corners)
            # Colour (0, 255, 0) with a 2-pixel outline.
            cv2.rectangle(canvas, (left, top), (right, bottom), (0, 255, 0), thickness=2)
    return canvas


def _build_detection_records(
    boxes: np.ndarray,
    scores: Sequence[float],
    labels: Sequence[int],
    queries: Sequence[str],
    label_names: Optional[Sequence[str]] = None,
) -> List[Dict[str, Any]]:
    detections: List[Dict[str, Any]] = []
    for idx, box in enumerate(boxes):
        if label_names is not None and idx < len(label_names):
            label = label_names[idx]
        else:
            label_idx = int(labels[idx]) if idx < len(labels) else -1
            if 0 <= label_idx < len(queries):
                label = queries[label_idx]
            else:
                label = f"label_{label_idx}"
        detections.append(
            {
                "label": label,
                "score": float(scores[idx]) if idx < len(scores) else 0.0,
                "bbox": [int(coord) for coord in box.tolist()],
            }
        )
    return detections


def infer_frame(
    frame: np.ndarray,
    queries: Sequence[str],
    detector_name: Optional[str] = None,
) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
    """Detect objects in a single frame and annotate it.

    Args:
        frame: Image handed to ``detector.predict`` and then to ``draw_boxes``.
        queries: Text prompts for the detector. An empty sequence falls back to
            ``["object"]``.
        detector_name: Forwarded to ``load_detector`` to select the model.

    Returns:
        A pair ``(annotated_frame, detections)``:

        - ``annotated_frame`` is a copy of ``frame`` with the predicted boxes
          drawn.
        - ``detections`` holds the records built by
          ``_build_detection_records``.

    Raises:
        Any exception from prediction or record building. It is logged with
        its traceback and then re-raised unchanged.
    """
    # NOTE(review): load_detector runs on every frame; presumably it caches
    # the model internally. Confirm, since this is called once per video frame.
    detector = load_detector(detector_name)
    prompts = list(queries)
    if not prompts:
        prompts = ["object"]
    try:
        prediction = detector.predict(frame, prompts)
        records = _build_detection_records(
            prediction.boxes,
            prediction.scores,
            prediction.labels,
            prompts,
            prediction.label_names,
        )
    except Exception:
        logging.exception("Inference failed for queries %s", prompts)
        raise
    annotated = draw_boxes(frame, prediction.boxes)
    return annotated, records


def run_inference(
    input_video_path: str,
    output_video_path: Optional[str],
    mission_prompt: str,
    max_frames: Optional[int] = None,
    detector_name: Optional[str] = None,
    write_output_video: bool = True,
    generate_summary: bool = True,
) -> Tuple[Optional[str], MissionPlan, Optional[str]]:
    """Run mission-driven detection over a video.

    The steps are:

    1. Decode the input video.
    2. Derive detector queries from the mission prompt.
    3. Run ``infer_frame`` on each frame, up to ``max_frames``.
    4. Optionally write the annotated video and summarise the detections.

    Args:
        input_video_path: Path passed to ``extract_frames``.
        output_video_path: Destination for the annotated video. It is required
            when ``write_output_video`` is true.
        mission_prompt: Free-text mission. It is passed to
            ``get_mission_plan`` and, if summarising, to ``summarize_results``.
        max_frames: If not ``None``, only frames with index ``< max_frames``
            are processed.
        detector_name: Forwarded to ``infer_frame``.
        write_output_video: Whether to write the annotated frames with
            ``write_video``.
        generate_summary: Whether to call ``summarize_results`` on the
            detection log.

    Returns:
        A tuple ``(video_path, mission_plan, summary)``:

        - ``video_path`` is ``output_video_path`` when a video was written,
          otherwise ``None``.
        - ``summary`` is ``None`` when ``generate_summary`` is false.

    Raises:
        ValueError: If ``write_output_video`` is true but ``output_video_path``
            is empty. This is raised before any decoding or inference. Any
            ``ValueError`` from ``extract_frames`` is logged and re-raised.
    """
    # Validate output arguments up front so a misconfigured call fails
    # immediately instead of after decoding and running inference on every frame.
    if write_output_video and not output_video_path:
        raise ValueError("output_video_path is required when write_output_video=True.")

    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:
        logging.exception("Failed to decode video at %s", input_video_path)
        raise

    mission_plan = get_mission_plan(mission_prompt)
    logging.info("Mission plan: %s", mission_plan.to_json())
    queries = mission_plan.queries()

    processed_frames: List[np.ndarray] = []
    detection_log: List[Dict[str, Any]] = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break
        logging.debug("Processing frame %d", idx)
        processed_frame, detections = infer_frame(frame, queries, detector_name=detector_name)
        detection_log.append({"frame_index": idx, "detections": detections})
        # Annotated copies are only consumed by write_video; don't hold every
        # one of them in memory when no output video is requested.
        if write_output_video:
            processed_frames.append(processed_frame)

    video_path_result: Optional[str] = None
    if write_output_video:
        write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
        video_path_result = output_video_path

    mission_summary = (
        summarize_results(mission_prompt, mission_plan, detection_log) if generate_summary else None
    )
    return video_path_result, mission_plan, mission_summary