import os import uuid import hashlib import tempfile from pathlib import Path from typing import Dict, List, Tuple, Optional, Any, Union import cv2 import numpy as np import torch from transformers import Pipeline from .vine_config import VineConfig from .vine_model import VineModel from .vis_utils import render_dino_frames, render_sam_frames, render_vine_frame_sets from laser.loading import load_video from laser.preprocess.mask_generation_grounding_dino import generate_masks_grounding_dino class VinePipeline(Pipeline): """ Pipeline for VINE model that handles end-to-end video understanding. """ def __init__( self, sam_config_path: Optional[str] = None, sam_checkpoint_path: Optional[str] = None, gd_config_path: Optional[str] = None, gd_checkpoint_path: Optional[str] = None, **kwargs: Any, ): self.grounding_model = None self.sam_predictor = None self.mask_generator = None self.sam_config_path = sam_config_path self.sam_checkpoint_path = sam_checkpoint_path self.gd_config_path = gd_config_path self.gd_checkpoint_path = gd_checkpoint_path super().__init__(**kwargs) self.segmentation_method = getattr( self.model.config, "segmentation_method", "grounding_dino_sam2" ) self.box_threshold = getattr(self.model.config, "box_threshold", 0.35) self.text_threshold = getattr(self.model.config, "text_threshold", 0.25) self.target_fps = getattr(self.model.config, "target_fps", 1) self.visualize = getattr(self.model.config, "visualize", False) self.visualization_dir = getattr(self.model.config, "visualization_dir", None) self.debug_visualizations = getattr( self.model.config, "debug_visualizations", False ) self._device = getattr(self.model.config, "_device") if kwargs.get("device") is not None: self._device = kwargs.get("device") # ------------------------------------------------------------------ # # Segmentation model injection # ------------------------------------------------------------------ # def set_segmentation_models( self, *, sam_predictor=None, mask_generator=None, grounding_model=None, ): if sam_predictor is not None: self.sam_predictor = sam_predictor if mask_generator is not None: self.mask_generator = mask_generator if grounding_model is not None: self.grounding_model = grounding_model # ------------------------------------------------------------------ # # Pipeline protocol # ------------------------------------------------------------------ # def _sanitize_parameters(self, **kwargs: Any): preprocess_kwargs: Dict[str, Any] = {} forward_kwargs: Dict[str, Any] = {} postprocess_kwargs: Dict[str, Any] = {} if "segmentation_method" in kwargs: preprocess_kwargs["segmentation_method"] = kwargs["segmentation_method"] if "target_fps" in kwargs: preprocess_kwargs["target_fps"] = kwargs["target_fps"] if "box_threshold" in kwargs: preprocess_kwargs["box_threshold"] = kwargs["box_threshold"] if "text_threshold" in kwargs: preprocess_kwargs["text_threshold"] = kwargs["text_threshold"] if "categorical_keywords" in kwargs: preprocess_kwargs["categorical_keywords"] = kwargs["categorical_keywords"] if "categorical_keywords" in kwargs: forward_kwargs["categorical_keywords"] = kwargs["categorical_keywords"] if "unary_keywords" in kwargs: forward_kwargs["unary_keywords"] = kwargs["unary_keywords"] if "binary_keywords" in kwargs: forward_kwargs["binary_keywords"] = kwargs["binary_keywords"] if "object_pairs" in kwargs: forward_kwargs["object_pairs"] = kwargs["object_pairs"] if "return_flattened_segments" in kwargs: forward_kwargs["return_flattened_segments"] = kwargs[ "return_flattened_segments" ] if "return_valid_pairs" in kwargs: forward_kwargs["return_valid_pairs"] = kwargs["return_valid_pairs"] if "interested_object_pairs" in kwargs: forward_kwargs["interested_object_pairs"] = kwargs[ "interested_object_pairs" ] if "debug_visualizations" in kwargs: forward_kwargs["debug_visualizations"] = kwargs["debug_visualizations"] postprocess_kwargs["debug_visualizations"] = kwargs["debug_visualizations"] if "return_top_k" in kwargs: postprocess_kwargs["return_top_k"] = kwargs["return_top_k"] if "self.visualize" in kwargs: postprocess_kwargs["self.visualize"] = kwargs["self.visualize"] return preprocess_kwargs, forward_kwargs, postprocess_kwargs # ------------------------------------------------------------------ # # Preprocess: video + segmentation # ------------------------------------------------------------------ # def preprocess( self, video_input: Union[str, np.ndarray, torch.Tensor], segmentation_method: Optional[str] = None, target_fps: Optional[int] = None, box_threshold: Optional[float] = None, text_threshold: Optional[float] = None, categorical_keywords: Optional[List[str]] = None, **kwargs: Any, ) -> Dict[str, Any]: if segmentation_method is None: segmentation_method = self.segmentation_method if target_fps is None: target_fps = self.target_fps else: self.target_fps = target_fps if box_threshold is None: box_threshold = self.box_threshold else: self.box_threshold = box_threshold if text_threshold is None: text_threshold = self.text_threshold else: self.text_threshold = text_threshold if categorical_keywords is None: categorical_keywords = ["object"] if isinstance(video_input, str): video_tensor = load_video(video_input, target_fps=target_fps) if isinstance(video_tensor, list): video_tensor = np.array(video_tensor) elif isinstance(video_tensor, torch.Tensor): video_tensor = video_tensor.cpu().numpy() elif isinstance(video_input, (np.ndarray, torch.Tensor)): if isinstance(video_input, torch.Tensor): video_tensor = video_input.numpy() else: video_tensor = video_input else: raise ValueError(f"Unsupported video input type: {type(video_input)}") if not isinstance(video_tensor, np.ndarray): video_tensor = np.array(video_tensor) if len(video_tensor.shape) != 4: raise ValueError( f"Expected video tensor shape (frames, height, width, channels), got {video_tensor.shape}" ) visualization_data: Dict[str, Any] = {} print(f"Segmentation method: {segmentation_method}") if segmentation_method == "sam2": masks, bboxes, vis_data = self._generate_sam2_masks(video_tensor) elif segmentation_method == "grounding_dino_sam2": masks, bboxes, vis_data = self._generate_grounding_dino_sam2_masks( video_tensor, categorical_keywords, box_threshold, text_threshold, video_input, ) else: raise ValueError(f"Unsupported segmentation method: {segmentation_method}") if vis_data: visualization_data.update(vis_data) visualization_data.setdefault("sam_masks", masks) return { "video_frames": torch.tensor(video_tensor), "masks": masks, "bboxes": bboxes, "num_frames": len(video_tensor), "visualization_data": visualization_data, } # ------------------------------------------------------------------ # # Segmentation helpers # ------------------------------------------------------------------ # def _generate_sam2_masks( self, video_tensor: np.ndarray ) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]], Dict[str, Any]]: print("Generating SAM2 masks...") if self.mask_generator is None: self._initialize_segmentation_models() if self.mask_generator is None: raise ValueError("SAM2 mask generator not available") masks: Dict[int, Dict[int, torch.Tensor]] = {} bboxes: Dict[int, Dict[int, List[int]]] = {} for frame_id, frame in enumerate(video_tensor): if isinstance(frame, np.ndarray) and frame.dtype != np.uint8: frame = ( (frame * 255).astype(np.uint8) if frame.max() <= 1 else frame.astype(np.uint8) ) frame_masks = self.mask_generator.generate(frame) masks[frame_id] = {} bboxes[frame_id] = {} for obj_id, mask_data in enumerate(frame_masks): mask = mask_data["segmentation"] if isinstance(mask, np.ndarray): mask = torch.from_numpy(mask) if len(mask.shape) == 2: mask = mask.unsqueeze(-1) elif len(mask.shape) == 3 and mask.shape[0] == 1: mask = mask.permute(1, 2, 0) wrapped_id = obj_id + 1 masks[frame_id][wrapped_id] = mask mask_np = ( mask.squeeze().numpy() if isinstance(mask, torch.Tensor) else mask.squeeze() ) coords = np.where(mask_np > 0) if len(coords[0]) > 0: y1, y2 = coords[0].min(), coords[0].max() x1, x2 = coords[1].min(), coords[1].max() bboxes[frame_id][wrapped_id] = [x1, y1, x2, y2] tracked_masks, tracked_bboxes = self._track_ids_across_frames(masks, bboxes) return tracked_masks, tracked_bboxes, {"sam_masks": tracked_masks} def _generate_grounding_dino_sam2_masks( self, video_tensor: np.ndarray, categorical_keywords: List[str], box_threshold: float, text_threshold: float, video_path: Union[str, None], ) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]], Dict[str, Any]]: print("Generating Grounding DINO + SAM2 masks...") if self.grounding_model is None or self.sam_predictor is None: self._initialize_segmentation_models() if self.grounding_model is None or self.sam_predictor is None: raise ValueError("GroundingDINO or SAM2 models not available") temp_video_path = None if video_path is None or not isinstance(video_path, str): temp_video_path = self._create_temp_video(video_tensor) video_path = temp_video_path CHUNK = 5 classes_ls = [ categorical_keywords[i : i + CHUNK] for i in range(0, len(categorical_keywords), CHUNK) ] base_name = Path(video_path).stem fps_tag = f"fps{int(self.target_fps)}" path_hash = hashlib.md5(video_path.encode("utf-8")).hexdigest()[:8] video_cache_name = f"{base_name}_{fps_tag}_{path_hash}" video_segments, oid_class_pred, _ = generate_masks_grounding_dino( self.grounding_model, box_threshold, text_threshold, self.sam_predictor, self.mask_generator, video_tensor, video_path, video_cache_name, out_dir=tempfile.gettempdir(), classes_ls=classes_ls, target_fps=self.target_fps, visualize=self.debug_visualizations, frames=None, max_prop_time=2, ) masks: Dict[int, Dict[int, torch.Tensor]] = {} bboxes: Dict[int, Dict[int, List[int]]] = {} for frame_id, frame_masks in video_segments.items(): masks[frame_id] = {} bboxes[frame_id] = {} for obj_id, mask in frame_masks.items(): if not isinstance(mask, torch.Tensor): mask = torch.tensor(mask) masks[frame_id][obj_id] = mask mask_np = mask.numpy() if mask_np.ndim == 3 and mask_np.shape[0] == 1: mask_np = np.squeeze(mask_np, axis=0) coords = np.where(mask_np > 0) if len(coords[0]) > 0: y1, y2 = coords[0].min(), coords[0].max() x1, x2 = coords[1].min(), coords[1].max() bboxes[frame_id][obj_id] = [x1, y1, x2, y2] if temp_video_path and os.path.exists(temp_video_path): os.remove(temp_video_path) tracked_masks, tracked_bboxes = self._track_ids_across_frames(masks, bboxes) vis_data: Dict[str, Any] = { "sam_masks": tracked_masks, "dino_labels": oid_class_pred, } return tracked_masks, tracked_bboxes, vis_data # ------------------------------------------------------------------ # # ID tracking across frames # ------------------------------------------------------------------ # def _bbox_iou(self, box1: List[int], box2: List[int]) -> float: x1, y1, x2, y2 = box1 x1b, y1b, x2b, y2b = box2 ix1 = max(x1, x1b) iy1 = max(y1, y1b) ix2 = min(x2, x2b) iy2 = min(y2, y2b) iw = max(0, ix2 - ix1) ih = max(0, iy2 - iy1) inter = iw * ih if inter <= 0: return 0.0 area1 = max(0, x2 - x1) * max(0, y2 - y1) area2 = max(0, x2b - x1b) * max(0, y2b - y1b) union = area1 + area2 - inter if union <= 0: return 0.0 return inter / union def _track_ids_across_frames( self, masks: Dict[int, Dict[int, torch.Tensor]], bboxes: Dict[int, Dict[int, List[int]]], iou_threshold: float = 0.3, ) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]]]: frame_ids = sorted(masks.keys()) tracked_masks: Dict[int, Dict[int, torch.Tensor]] = {} tracked_bboxes: Dict[int, Dict[int, List[int]]] = {} next_track_id = 0 prev_tracks: Dict[int, List[int]] = {} for frame_id in frame_ids: frame_masks = masks.get(frame_id, {}) frame_boxes = bboxes.get(frame_id, {}) tracked_masks[frame_id] = {} tracked_bboxes[frame_id] = {} if not frame_boxes: prev_tracks = {} continue det_ids = list(frame_boxes.keys()) prev_ids = list(prev_tracks.keys()) candidates: List[Tuple[float, int, int]] = [] for tid in prev_ids: prev_box = prev_tracks[tid] for det_id in det_ids: iou = self._bbox_iou(prev_box, frame_boxes[det_id]) if iou > iou_threshold: candidates.append((iou, tid, det_id)) candidates.sort(reverse=True) matched_prev = set() matched_det = set() for iou, tid, det_id in candidates: if tid in matched_prev or det_id in matched_det: continue matched_prev.add(tid) matched_det.add(det_id) tracked_masks[frame_id][tid] = frame_masks[det_id] tracked_bboxes[frame_id][tid] = frame_boxes[det_id] for det_id in det_ids: if det_id in matched_det: continue tid = next_track_id next_track_id += 1 tracked_masks[frame_id][tid] = frame_masks[det_id] tracked_bboxes[frame_id][tid] = frame_boxes[det_id] prev_tracks = { tid: tracked_bboxes[frame_id][tid] for tid in tracked_bboxes[frame_id].keys() } return tracked_masks, tracked_bboxes # ------------------------------------------------------------------ # # Segmentation model initialization # ------------------------------------------------------------------ # def _initialize_segmentation_models(self): if self.sam_predictor is None or self.mask_generator is None: self._initialize_sam2_models() if self.grounding_model is None: self._initialize_grounding_dino_model() def _initialize_sam2_models(self): try: from sam2.build_sam import build_sam2_video_predictor, build_sam2 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator except ImportError as e: print(f"Warning: Could not import SAM2: {e}") return config_path, checkpoint_path = self._resolve_sam2_paths() if self.sam_config_path is not None and not os.path.exists(config_path): raise ValueError(f"SAM2 config path not found: {config_path}") if self.sam_checkpoint_path is not None and not os.path.exists(checkpoint_path): raise ValueError(f"SAM2 checkpoint path not found: {checkpoint_path}") if not os.path.exists(checkpoint_path): print(f"Warning: SAM2 checkpoint not found at {checkpoint_path}") print("SAM2 functionality will be unavailable") return try: device = self._device self.sam_predictor = build_sam2_video_predictor( config_path, checkpoint_path, device=device ) sam2_model = build_sam2( config_path, checkpoint_path, device=device, apply_postprocessing=False, ) self.mask_generator = SAM2AutomaticMaskGenerator( model=sam2_model, points_per_side=32, points_per_batch=32, pred_iou_thresh=0.7, stability_score_thresh=0.8, crop_n_layers=2, box_nms_thresh=0.6, crop_n_points_downscale_factor=2, min_mask_region_area=100, use_m2m=True, ) print("✓ SAM2 models initialized successfully") except Exception as e: raise ValueError(f"Failed to initialize SAM2 with custom paths: {e}") def _initialize_grounding_dino_model(self): try: from groundingdino.util.inference import Model as gd_Model except ImportError as e: print(f"Warning: Could not import GroundingDINO: {e}") return config_path, checkpoint_path = self._resolve_grounding_dino_paths() if self.gd_config_path is not None and not os.path.exists(config_path): raise ValueError(f"GroundingDINO config path not found: {config_path}") if self.gd_checkpoint_path is not None and not os.path.exists(checkpoint_path): raise ValueError( f"GroundingDINO checkpoint path not found: {checkpoint_path}" ) if not (os.path.exists(config_path) and os.path.exists(checkpoint_path)): print( f"Warning: GroundingDINO models not found at {config_path} / {checkpoint_path}" ) print("GroundingDINO functionality will be unavailable") return try: device = self._device self.grounding_model = gd_Model( model_config_path=config_path, model_checkpoint_path=checkpoint_path, device=device, ) print("✓ GroundingDINO model initialized successfully") except Exception as e: raise ValueError(f"Failed to initialize GroundingDINO with custom paths: {e}") def _resolve_sam2_paths(self): if self.sam_config_path and self.sam_checkpoint_path: return self.sam_config_path, self.sam_checkpoint_path def _resolve_grounding_dino_paths(self): if self.gd_config_path and self.gd_checkpoint_path: return self.gd_config_path, self.gd_checkpoint_path # ------------------------------------------------------------------ # # Video writing helpers # ------------------------------------------------------------------ # def _prepare_visualization_dir(self, name: str, enabled: bool) -> Optional[str]: if not enabled: return None if self.visualization_dir: target_dir = ( os.path.join(self.visualization_dir, name) if name else self.visualization_dir ) os.makedirs(target_dir, exist_ok=True) return target_dir return tempfile.mkdtemp(prefix=f"vine_{name}_") def _create_temp_video( self, video_tensor: np.ndarray, base_dir: Optional[str] = None, prefix: str = "temp_video", ) -> str: import subprocess if base_dir is None: base_dir = tempfile.mkdtemp(prefix=f"vine_{prefix}_") else: os.makedirs(base_dir, exist_ok=True) file_name = f"{prefix}_{uuid.uuid4().hex}.mp4" temp_path = os.path.join(base_dir, file_name) height, width = video_tensor.shape[1:3] processing_fps = max(1, self.target_fps) output_fps = processing_fps video_tensor_for_output = video_tensor ffmpeg_success = False try: ffmpeg_success = self._create_video_with_ffmpeg( video_tensor_for_output, temp_path, output_fps, width, height ) except Exception as e: print(f"FFmpeg method failed: {e}") if not ffmpeg_success: print("Using OpenCV fallback") self._create_temp_video_opencv( video_tensor_for_output, temp_path, output_fps, width, height ) return temp_path def _create_video_with_ffmpeg( self, video_tensor: np.ndarray, output_path: str, fps: int, width: int, height: int ) -> bool: import subprocess try: ffmpeg_cmd = [ "ffmpeg", "-y", "-f", "rawvideo", "-vcodec", "rawvideo", "-s", f"{width}x{height}", "-pix_fmt", "rgb24", "-r", str(fps), "-i", "pipe:0", "-c:v", "libx264", "-preset", "fast", "-crf", "23", "-pix_fmt", "yuv420p", "-movflags", "+faststart", "-loglevel", "error", output_path, ] process = subprocess.Popen( ffmpeg_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) frame_data = b"" for frame in video_tensor: if frame.dtype != np.uint8: frame = ( (frame * 255).astype(np.uint8) if frame.max() <= 1 else frame.astype(np.uint8) ) frame_data += frame.tobytes() stdout, stderr = process.communicate(input=frame_data, timeout=60) if process.returncode == 0: print(f"Video created with FFmpeg (H.264) at {fps} FPS") return True else: error_msg = stderr.decode() if stderr else "Unknown error" print(f"FFmpeg error: {error_msg}") return False except FileNotFoundError: print("FFmpeg not found in PATH") return False except Exception as e: print(f"FFmpeg exception: {e}") return False def _create_temp_video_opencv( self, video_tensor: np.ndarray, temp_path: str, fps: int, width: int, height: int ) -> str: codecs_to_try = ["avc1", "X264", "mp4v"] out = None used_codec = None for codec in codecs_to_try: try: fourcc = cv2.VideoWriter_fourcc(*codec) temp_out = cv2.VideoWriter(temp_path, fourcc, fps, (width, height)) if temp_out.isOpened(): out = temp_out used_codec = codec break else: temp_out.release() except Exception as e: print(f"Warning: Codec {codec} not available: {e}") continue if out is None or not out.isOpened(): raise RuntimeError( f"Failed to initialize VideoWriter with any codec. Tried: {codecs_to_try}" ) print(f"Using OpenCV with codec: {used_codec}") for frame in video_tensor: if len(frame.shape) == 3 and frame.shape[2] == 3: frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) else: frame_bgr = frame if frame_bgr.dtype != np.uint8: frame_bgr = ( (frame_bgr * 255).astype(np.uint8) if frame_bgr.max() <= 1 else frame_bgr.astype(np.uint8) ) out.write(frame_bgr) out.release() return temp_path # ------------------------------------------------------------------ # # Forward + postprocess # ------------------------------------------------------------------ # def _forward(self, model_inputs: Dict[str, Any], **forward_kwargs: Any) -> Dict[str, Any]: outputs = self.model.predict( video_frames=model_inputs["video_frames"], masks=model_inputs["masks"], bboxes=model_inputs["bboxes"], **forward_kwargs, ) outputs.setdefault("video_frames", model_inputs.get("video_frames")) outputs.setdefault("bboxes", model_inputs.get("bboxes")) outputs.setdefault("masks", model_inputs.get("masks")) outputs.setdefault("visualization_data", model_inputs.get("visualization_data")) return outputs def postprocess( self, model_outputs: Dict[str, Any], return_top_k: int = 3, visualize: Optional[bool] = None, **kwargs: Any, ) -> Dict[str, Any]: results: Dict[str, Any] = { "categorical_predictions": model_outputs.get("categorical_predictions", {}), "unary_predictions": model_outputs.get("unary_predictions", {}), "binary_predictions": model_outputs.get("binary_predictions", {}), "confidence_scores": model_outputs.get("confidence_scores", {}), "summary": self._generate_summary(model_outputs), } print("\n" + "=" * 50) print("DEBUG: Raw Model Outputs - Categorical Predictions") cat_preds = model_outputs.get("categorical_predictions", {}) for obj_id, preds in cat_preds.items(): print(f"Object {obj_id}: {preds}") print("=" * 50 + "\n") if "flattened_segments" in model_outputs: results["flattened_segments"] = model_outputs["flattened_segments"] if "valid_pairs" in model_outputs: results["valid_pairs"] = model_outputs["valid_pairs"] if "valid_pairs_metadata" in model_outputs: results["valid_pairs_metadata"] = model_outputs["valid_pairs_metadata"] if "visualization_data" in model_outputs: results["visualization_data"] = model_outputs["visualization_data"] if self.visualize and "video_frames" in model_outputs and "bboxes" in model_outputs: frames_tensor = model_outputs["video_frames"] if isinstance(frames_tensor, torch.Tensor): frames_np = frames_tensor.detach().cpu().numpy() else: frames_np = np.asarray(frames_tensor) if frames_np.dtype != np.uint8: if np.issubdtype(frames_np.dtype, np.floating): max_val = frames_np.max() if frames_np.size else 0.0 scale = 255.0 if max_val <= 1.0 else 1.0 frames_np = (frames_np * scale).clip(0, 255).astype(np.uint8) else: frames_np = frames_np.clip(0, 255).astype(np.uint8) cat_label_lookup: Dict[int, Tuple[str, float]] = {} for obj_id, preds in model_outputs.get("categorical_predictions", {}).items(): if preds: prob, label = preds[0] cat_label_lookup[obj_id] = (label, prob) unary_preds = model_outputs.get("unary_predictions", {}) unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]] = {} for (frame_id, obj_id), preds in unary_preds.items(): if preds: unary_lookup.setdefault(frame_id, {})[obj_id] = preds[:1] binary_preds = model_outputs.get("binary_predictions", {}) binary_lookup: Dict[ int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]] ] = {} for (frame_id, obj_pair), preds in binary_preds.items(): if preds: binary_lookup.setdefault(frame_id, []).append((obj_pair, preds[:1])) bboxes = model_outputs["bboxes"] visualization_data = model_outputs.get("visualization_data", {}) visualizations: Dict[str, Dict[str, Any]] = {} debug_visualizations = kwargs.get("debug_visualizations") if debug_visualizations is None: debug_visualizations = self.debug_visualizations vine_frame_sets = render_vine_frame_sets( frames_np, bboxes, cat_label_lookup, unary_lookup, binary_lookup, visualization_data.get("sam_masks"), ) vine_visuals: Dict[str, Dict[str, Any]] = {} final_frames = vine_frame_sets.get("all", []) if final_frames: final_entry: Dict[str, Any] = {"frames": final_frames, "video_path": None} final_dir = self._prepare_visualization_dir( "all", enabled=self.visualize ) final_entry["video_path"] = self._create_temp_video( np.stack(final_frames, axis=0), base_dir=final_dir, prefix="all_visualization", ) vine_visuals["all"] = final_entry if debug_visualizations: sam_masks = visualization_data.get("sam_masks") if sam_masks: sam_frames = render_sam_frames( frames_np, sam_masks, visualization_data.get("dino_labels") ) sam_entry = {"frames": sam_frames, "video_path": None} if sam_frames: sam_dir = self._prepare_visualization_dir( "sam", enabled=self.visualize ) sam_entry["video_path"] = self._create_temp_video( np.stack(sam_frames, axis=0), base_dir=sam_dir, prefix="sam_visualization", ) visualizations["sam"] = sam_entry dino_labels = visualization_data.get("dino_labels") if dino_labels: dino_frames = render_dino_frames(frames_np, bboxes, dino_labels) dino_entry = {"frames": dino_frames, "video_path": None} if dino_frames: dino_dir = self._prepare_visualization_dir( "dino", enabled=self.visualize ) dino_entry["video_path"] = self._create_temp_video( np.stack(dino_frames, axis=0), base_dir=dino_dir, prefix="dino_visualization", ) visualizations["dino"] = dino_entry for name in ("object", "unary", "binary"): frames_list = vine_frame_sets.get(name, []) entry: Dict[str, Any] = {"frames": frames_list, "video_path": None} if frames_list: vine_dir = self._prepare_visualization_dir( name, enabled=self.visualize ) entry["video_path"] = self._create_temp_video( np.stack(frames_list, axis=0), base_dir=vine_dir, prefix=f"{name}_visualization", ) vine_visuals[name] = entry if vine_visuals: visualizations["vine"] = vine_visuals if visualizations: results["visualizations"] = visualizations return results # ------------------------------------------------------------------ # # Summary JSON # ------------------------------------------------------------------ # def _generate_summary(self, model_outputs: Dict[str, Any]) -> Dict[str, Any]: """ Per-object summary: { "num_objects_detected": N, "objects": { "": { "top_categories": [{"label": str, "probability": float}, ...], "top_unary": [{"frame_id": int, "predicate": str, "probability": float}, ...], } } } """ categorical_preds = model_outputs.get("categorical_predictions", {}) unary_preds = model_outputs.get("unary_predictions", {}) unary_by_obj: Dict[int, List[Tuple[float, str, int]]] = {} for (frame_id, obj_id), preds in unary_preds.items(): for prob, predicate in preds: prob_val = ( float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob) ) unary_by_obj.setdefault(obj_id, []).append((prob_val, predicate, frame_id)) objects_summary: Dict[str, Dict[str, Any]] = {} all_obj_ids = set(categorical_preds.keys()) | set(unary_by_obj.keys()) for obj_id in sorted(all_obj_ids): cat_list = categorical_preds.get(obj_id, []) cat_sorted = sorted( [ ( float(p.detach().cpu()) if torch.is_tensor(p) else float(p), label, ) for p, label in cat_list ], key=lambda x: x[0], reverse=True, )[:3] top_categories = [ {"label": label, "probability": prob} for prob, label in cat_sorted ] unary_list = unary_by_obj.get(obj_id, []) unary_sorted = sorted(unary_list, key=lambda x: x[0], reverse=True)[:3] top_unary = [ { "frame_id": int(frame_id), "predicate": predicate, "probability": prob, } for (prob, predicate, frame_id) in unary_sorted ] objects_summary[str(obj_id)] = { "top_categories": top_categories, "top_unary": top_unary, } summary = { "num_objects_detected": len(objects_summary), "objects": objects_summary, } return summary