# Demo-2025 / inference.py
# Last commit by zye0616 (5e3ba22): "update: two stages processing"
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple
import cv2
import numpy as np
from models.model_loader import load_detector
from mission_planner import MissionPlan, get_mission_plan
from mission_summarizer import summarize_results
from utils.video import extract_frames, write_video
def draw_boxes(frame: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Return an annotated copy of ``frame`` with every box outlined in green.

    ``frame`` itself is never modified. Each entry of ``boxes`` is unpacked as
    ``(x1, y1, x2, y2)`` with every coordinate truncated by ``int``, and a
    rectangle of thickness 2 is drawn between the two corners. Passing
    ``boxes=None`` simply yields an unannotated copy.
    """
    canvas = frame.copy()
    if boxes is not None:
        for corners in boxes:
            left, top, right, bottom = map(int, corners)
            cv2.rectangle(canvas, (left, top), (right, bottom), (0, 255, 0), thickness=2)
    return canvas
def _build_detection_records(
boxes: np.ndarray,
scores: Sequence[float],
labels: Sequence[int],
queries: Sequence[str],
label_names: Optional[Sequence[str]] = None,
) -> List[Dict[str, Any]]:
detections: List[Dict[str, Any]] = []
for idx, box in enumerate(boxes):
if label_names is not None and idx < len(label_names):
label = label_names[idx]
else:
label_idx = int(labels[idx]) if idx < len(labels) else -1
if 0 <= label_idx < len(queries):
label = queries[label_idx]
else:
label = f"label_{label_idx}"
detections.append(
{
"label": label,
"score": float(scores[idx]) if idx < len(scores) else 0.0,
"bbox": [int(coord) for coord in box.tolist()],
}
)
return detections
def infer_frame(
    frame: np.ndarray,
    queries: Sequence[str],
    detector_name: Optional[str] = None,
) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
    """Detect objects in a single frame.

    The detector is obtained from ``load_detector(detector_name)`` on every
    call. An empty ``queries`` sequence falls back to the single query
    ``"object"``. Any exception raised while predicting or while building the
    detection records is logged with its traceback and re-raised unchanged.

    Returns:
        ``(annotated_frame, detections)``: a copy of ``frame`` with the
        predicted boxes drawn on it, and the list of detection dicts.
    """
    # NOTE(review): run_inference calls this once per frame, so this lookup
    # repeats per frame -- confirm load_detector caches the loaded model.
    model = load_detector(detector_name)
    prompts = list(queries)
    if not prompts:
        prompts = ["object"]

    try:
        prediction = model.predict(frame, prompts)
        records = _build_detection_records(
            prediction.boxes,
            prediction.scores,
            prediction.labels,
            prompts,
            prediction.label_names,
        )
    except Exception:
        logging.exception("Inference failed for queries %s", prompts)
        raise

    annotated = draw_boxes(frame, prediction.boxes)
    return annotated, records
def run_inference(
    input_video_path: str,
    output_video_path: Optional[str],
    mission_prompt: str,
    max_frames: Optional[int] = None,
    detector_name: Optional[str] = None,
    write_output_video: bool = True,
    generate_summary: bool = True,
    mission_plan: Optional[MissionPlan] = None,
) -> Tuple[Optional[str], MissionPlan, Optional[str]]:
    """Run detection over a video, then optionally write it out and summarize it.

    The mission prompt is turned into a mission plan (unless ``mission_plan``
    is supplied). The plan's queries are fed to the detector frame by frame.

    Args:
        input_video_path: Path of the video decoded via ``extract_frames``.
        output_video_path: Destination for the annotated video. Required when
            ``write_output_video`` is true.
        mission_prompt: Free-text mission description. It must be non-blank
            even when ``mission_plan`` is given, because the stripped prompt is
            also passed to ``summarize_results``.
        max_frames: If set, only the first ``max_frames`` decoded frames are
            run through the detector.
        detector_name: Forwarded to ``infer_frame``.
        write_output_video: Whether to write the annotated frames to
            ``output_video_path``.
        generate_summary: Whether to call ``summarize_results`` on the
            per-frame detection log.
        mission_plan: Pre-computed plan. When ``None``, one is built with
            ``get_mission_plan``.

    Returns:
        ``(video_path, plan, summary)``:
        - ``video_path`` is ``output_video_path`` if a video was written,
          otherwise ``None``.
        - ``plan`` is the plan actually used.
        - ``summary`` is the summarizer output, or ``None`` when
          ``generate_summary`` is false.

    Raises:
        ValueError: Raised in three cases:
            - the prompt is blank;
            - ``write_output_video`` is true without an ``output_video_path``;
            - ``extract_frames`` raises ``ValueError``, which is logged and then
              re-raised.
    """
    # Validate cheap arguments before decoding the video or running the
    # detector, so bad input fails immediately instead of after all the work.
    mission_prompt_clean = (mission_prompt or "").strip()
    if not mission_prompt_clean:
        raise ValueError("Mission prompt is required.")
    if write_output_video and not output_video_path:
        raise ValueError("output_video_path is required when write_output_video=True.")

    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:
        logging.exception("Failed to decode video at %s", input_video_path)
        raise

    # Explicit None check: a caller-supplied plan must never be replaced just
    # because it happens to evaluate as falsy.
    resolved_plan = (
        mission_plan if mission_plan is not None else get_mission_plan(mission_prompt_clean)
    )
    logging.info("Mission plan: %s", resolved_plan.to_json())
    queries = resolved_plan.queries()

    processed_frames: List[np.ndarray] = []
    detection_log: List[Dict[str, Any]] = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break
        logging.debug("Processing frame %d", idx)
        processed_frame, detections = infer_frame(frame, queries, detector_name=detector_name)
        detection_log.append({"frame_index": idx, "detections": detections})
        processed_frames.append(processed_frame)

    video_path_result: Optional[str] = None
    if write_output_video:
        write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
        video_path_result = output_video_path

    mission_summary = (
        summarize_results(mission_prompt_clean, resolved_plan, detection_log)
        if generate_summary
        else None
    )
    return video_path_result, resolved_plan, mission_summary