import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple

import cv2
import numpy as np

from models.model_loader import load_detector
from mission_planner import MissionPlan, get_mission_plan
from mission_summarizer import summarize_results
from utils.video import extract_frames, write_video


def draw_boxes(frame: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Return a copy of ``frame`` with each detection box drawn as a green rectangle."""
    output = frame.copy()
    if boxes is None:
        return output
    for box in boxes:
        x1, y1, x2, y2 = [int(coord) for coord in box]
        cv2.rectangle(output, (x1, y1), (x2, y2), (0, 255, 0), thickness=2)
    return output


def _build_detection_records(
    boxes: np.ndarray,
    scores: Sequence[float],
    labels: Sequence[int],
    queries: Sequence[str],
    label_names: Optional[Sequence[str]] = None,
) -> List[Dict[str, Any]]:
    """Convert raw detector outputs into per-detection dicts with label, score, and bbox."""
    detections: List[Dict[str, Any]] = []
    for idx, box in enumerate(boxes):
        # Prefer explicit label names from the detector; otherwise map the label
        # index back to the text query, falling back to a generic placeholder.
        if label_names is not None and idx < len(label_names):
            label = label_names[idx]
        else:
            label_idx = int(labels[idx]) if idx < len(labels) else -1
            if 0 <= label_idx < len(queries):
                label = queries[label_idx]
            else:
                label = f"label_{label_idx}"
        detections.append(
            {
                "label": label,
                "score": float(scores[idx]) if idx < len(scores) else 0.0,
                "bbox": [int(coord) for coord in box.tolist()],
            }
        )
    return detections


def infer_frame(
    frame: np.ndarray,
    queries: Sequence[str],
    detector_name: Optional[str] = None,
) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
    """Run the detector on one frame; return the annotated frame and its detection records."""
    detector = load_detector(detector_name)
    text_queries = list(queries) or ["object"]
    try:
        result = detector.predict(frame, text_queries)
        detections = _build_detection_records(
            result.boxes, result.scores, result.labels, text_queries, result.label_names
        )
    except Exception:
        logging.exception("Inference failed for queries %s", text_queries)
        raise
    return draw_boxes(frame, result.boxes), detections


def run_inference(
    input_video_path: str,
    output_video_path: Optional[str],
    mission_prompt: str,
    max_frames: Optional[int] = None,
    detector_name: Optional[str] = None,
    write_output_video: bool = True,
    generate_summary: bool = True,
    mission_plan: Optional[MissionPlan] = None,
) -> Tuple[Optional[str], MissionPlan, Optional[str]]:
    """Run detection over a video for a mission prompt.

    Returns the output video path (or ``None`` when no video is written), the
    resolved mission plan, and an optional mission summary.
    """
    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:
        logging.exception("Failed to decode video at %s", input_video_path)
        raise

    mission_prompt_clean = (mission_prompt or "").strip()
    if not mission_prompt_clean:
        raise ValueError("Mission prompt is required.")

    resolved_plan = mission_plan or get_mission_plan(mission_prompt_clean)
    logging.info("Mission plan: %s", resolved_plan.to_json())
    queries = resolved_plan.queries()

    processed_frames: List[np.ndarray] = []
    detection_log: List[Dict[str, Any]] = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break
        logging.debug("Processing frame %d", idx)
        processed_frame, detections = infer_frame(frame, queries, detector_name=detector_name)
        detection_log.append({"frame_index": idx, "detections": detections})
        processed_frames.append(processed_frame)

    if write_output_video:
        if not output_video_path:
            raise ValueError("output_video_path is required when write_output_video=True.")
        write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
        video_path_result: Optional[str] = output_video_path
    else:
        video_path_result = None

    mission_summary = (
        summarize_results(mission_prompt_clean, resolved_plan, detection_log) if generate_summary else None
    )
    return video_path_result, resolved_plan, mission_summary
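

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the file paths,
    # mission prompt, and frame cap below are hypothetical placeholders chosen
    # only to show how run_inference is called end to end with the default detector.
    logging.basicConfig(level=logging.INFO)
    video_path, plan, summary = run_inference(
        input_video_path="input.mp4",
        output_video_path="annotated.mp4",
        mission_prompt="find and count delivery trucks",
        max_frames=100,
    )
    print("Annotated video:", video_path)
    print("Mission plan:", plan.to_json())
    print("Summary:", summary)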