# vine_hf/vine_pipeline.py
import os
import uuid
import hashlib
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any, Union
import cv2
import numpy as np
import torch
from transformers import Pipeline
from .vine_config import VineConfig
from .vine_model import VineModel
from .vis_utils import render_dino_frames, render_sam_frames, render_vine_frame_sets
from laser.loading import load_video
from laser.preprocess.mask_generation_grounding_dino import generate_masks_grounding_dino
class VinePipeline(Pipeline):
"""
Pipeline for VINE model that handles end-to-end video understanding.
"""
def __init__(
self,
sam_config_path: Optional[str] = None,
sam_checkpoint_path: Optional[str] = None,
gd_config_path: Optional[str] = None,
gd_checkpoint_path: Optional[str] = None,
**kwargs: Any,
):
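        """
        Store optional SAM2 / GroundingDINO config and checkpoint paths, let
        `transformers.Pipeline.__init__` handle the model, then read segmentation
        thresholds and visualization defaults from the model config.
        """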
self.grounding_model = None
self.sam_predictor = None
self.mask_generator = None
self.sam_config_path = sam_config_path
self.sam_checkpoint_path = sam_checkpoint_path
self.gd_config_path = gd_config_path
self.gd_checkpoint_path = gd_checkpoint_path
super().__init__(**kwargs)
self.segmentation_method = getattr(
self.model.config, "segmentation_method", "grounding_dino_sam2"
)
self.box_threshold = getattr(self.model.config, "box_threshold", 0.35)
self.text_threshold = getattr(self.model.config, "text_threshold", 0.25)
self.target_fps = getattr(self.model.config, "target_fps", 1)
self.visualize = getattr(self.model.config, "visualize", False)
self.visualization_dir = getattr(self.model.config, "visualization_dir", None)
self.debug_visualizations = getattr(
self.model.config, "debug_visualizations", False
)
self._device = getattr(self.model.config, "_device")
if kwargs.get("device") is not None:
self._device = kwargs.get("device")
# ------------------------------------------------------------------ #
# Segmentation model injection
# ------------------------------------------------------------------ #
def set_segmentation_models(
self,
*,
sam_predictor=None,
mask_generator=None,
grounding_model=None,
):
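        """Inject pre-built segmentation components (SAM2 video predictor, SAM2
        automatic mask generator, GroundingDINO model) so the pipeline can skip
        its own lazy initialization."""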
if sam_predictor is not None:
self.sam_predictor = sam_predictor
if mask_generator is not None:
self.mask_generator = mask_generator
if grounding_model is not None:
self.grounding_model = grounding_model
# ------------------------------------------------------------------ #
# Pipeline protocol
# ------------------------------------------------------------------ #
def _sanitize_parameters(self, **kwargs: Any):
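        """Route call-time kwargs to the preprocess, forward, and postprocess stages."""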
preprocess_kwargs: Dict[str, Any] = {}
forward_kwargs: Dict[str, Any] = {}
postprocess_kwargs: Dict[str, Any] = {}
if "segmentation_method" in kwargs:
preprocess_kwargs["segmentation_method"] = kwargs["segmentation_method"]
if "target_fps" in kwargs:
preprocess_kwargs["target_fps"] = kwargs["target_fps"]
if "box_threshold" in kwargs:
preprocess_kwargs["box_threshold"] = kwargs["box_threshold"]
if "text_threshold" in kwargs:
preprocess_kwargs["text_threshold"] = kwargs["text_threshold"]
if "categorical_keywords" in kwargs:
preprocess_kwargs["categorical_keywords"] = kwargs["categorical_keywords"]
if "categorical_keywords" in kwargs:
forward_kwargs["categorical_keywords"] = kwargs["categorical_keywords"]
if "unary_keywords" in kwargs:
forward_kwargs["unary_keywords"] = kwargs["unary_keywords"]
if "binary_keywords" in kwargs:
forward_kwargs["binary_keywords"] = kwargs["binary_keywords"]
if "object_pairs" in kwargs:
forward_kwargs["object_pairs"] = kwargs["object_pairs"]
if "return_flattened_segments" in kwargs:
forward_kwargs["return_flattened_segments"] = kwargs[
"return_flattened_segments"
]
if "return_valid_pairs" in kwargs:
forward_kwargs["return_valid_pairs"] = kwargs["return_valid_pairs"]
if "interested_object_pairs" in kwargs:
forward_kwargs["interested_object_pairs"] = kwargs[
"interested_object_pairs"
]
if "debug_visualizations" in kwargs:
forward_kwargs["debug_visualizations"] = kwargs["debug_visualizations"]
postprocess_kwargs["debug_visualizations"] = kwargs["debug_visualizations"]
if "return_top_k" in kwargs:
postprocess_kwargs["return_top_k"] = kwargs["return_top_k"]
if "self.visualize" in kwargs:
postprocess_kwargs["self.visualize"] = kwargs["self.visualize"]
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
# ------------------------------------------------------------------ #
# Preprocess: video + segmentation
# ------------------------------------------------------------------ #
def preprocess(
self,
video_input: Union[str, np.ndarray, torch.Tensor],
segmentation_method: Optional[str] = None,
target_fps: Optional[int] = None,
box_threshold: Optional[float] = None,
text_threshold: Optional[float] = None,
categorical_keywords: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
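        """Load the video as a (frames, height, width, channels) array, run the
        selected segmentation method, and return frames, per-frame masks, per-frame
        bounding boxes, and any visualization data."""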
if segmentation_method is None:
segmentation_method = self.segmentation_method
if target_fps is None:
target_fps = self.target_fps
else:
self.target_fps = target_fps
if box_threshold is None:
box_threshold = self.box_threshold
else:
self.box_threshold = box_threshold
if text_threshold is None:
text_threshold = self.text_threshold
else:
self.text_threshold = text_threshold
if categorical_keywords is None:
categorical_keywords = ["object"]
if isinstance(video_input, str):
video_tensor = load_video(video_input, target_fps=target_fps)
if isinstance(video_tensor, list):
video_tensor = np.array(video_tensor)
elif isinstance(video_tensor, torch.Tensor):
video_tensor = video_tensor.cpu().numpy()
elif isinstance(video_input, (np.ndarray, torch.Tensor)):
if isinstance(video_input, torch.Tensor):
                video_tensor = video_input.detach().cpu().numpy()
else:
video_tensor = video_input
else:
raise ValueError(f"Unsupported video input type: {type(video_input)}")
if not isinstance(video_tensor, np.ndarray):
video_tensor = np.array(video_tensor)
if len(video_tensor.shape) != 4:
raise ValueError(
f"Expected video tensor shape (frames, height, width, channels), got {video_tensor.shape}"
)
visualization_data: Dict[str, Any] = {}
print(f"Segmentation method: {segmentation_method}")
if segmentation_method == "sam2":
masks, bboxes, vis_data = self._generate_sam2_masks(video_tensor)
elif segmentation_method == "grounding_dino_sam2":
masks, bboxes, vis_data = self._generate_grounding_dino_sam2_masks(
video_tensor,
categorical_keywords,
box_threshold,
text_threshold,
video_input,
)
else:
raise ValueError(f"Unsupported segmentation method: {segmentation_method}")
if vis_data:
visualization_data.update(vis_data)
visualization_data.setdefault("sam_masks", masks)
return {
"video_frames": torch.tensor(video_tensor),
"masks": masks,
"bboxes": bboxes,
"num_frames": len(video_tensor),
"visualization_data": visualization_data,
}
# ------------------------------------------------------------------ #
# Segmentation helpers
# ------------------------------------------------------------------ #
def _generate_sam2_masks(
self, video_tensor: np.ndarray
) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]], Dict[str, Any]]:
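        """Run the SAM2 automatic mask generator on each frame, derive bounding boxes
        from the mask pixels, and link object IDs across frames with IoU tracking."""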
print("Generating SAM2 masks...")
if self.mask_generator is None:
self._initialize_segmentation_models()
if self.mask_generator is None:
raise ValueError("SAM2 mask generator not available")
masks: Dict[int, Dict[int, torch.Tensor]] = {}
bboxes: Dict[int, Dict[int, List[int]]] = {}
for frame_id, frame in enumerate(video_tensor):
if isinstance(frame, np.ndarray) and frame.dtype != np.uint8:
frame = (
(frame * 255).astype(np.uint8)
if frame.max() <= 1
else frame.astype(np.uint8)
)
frame_masks = self.mask_generator.generate(frame)
masks[frame_id] = {}
bboxes[frame_id] = {}
for obj_id, mask_data in enumerate(frame_masks):
mask = mask_data["segmentation"]
if isinstance(mask, np.ndarray):
mask = torch.from_numpy(mask)
if len(mask.shape) == 2:
mask = mask.unsqueeze(-1)
elif len(mask.shape) == 3 and mask.shape[0] == 1:
mask = mask.permute(1, 2, 0)
wrapped_id = obj_id + 1
masks[frame_id][wrapped_id] = mask
mask_np = (
mask.squeeze().numpy()
if isinstance(mask, torch.Tensor)
else mask.squeeze()
)
coords = np.where(mask_np > 0)
if len(coords[0]) > 0:
y1, y2 = coords[0].min(), coords[0].max()
x1, x2 = coords[1].min(), coords[1].max()
bboxes[frame_id][wrapped_id] = [x1, y1, x2, y2]
tracked_masks, tracked_bboxes = self._track_ids_across_frames(masks, bboxes)
return tracked_masks, tracked_bboxes, {"sam_masks": tracked_masks}
def _generate_grounding_dino_sam2_masks(
self,
video_tensor: np.ndarray,
categorical_keywords: List[str],
box_threshold: float,
text_threshold: float,
video_path: Union[str, None],
) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]], Dict[str, Any]]:
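        """Detect objects with GroundingDINO (prompted by the categorical keywords,
        chunked five classes at a time), segment them with SAM2, derive bounding
        boxes, and link object IDs across frames with IoU tracking. A temporary
        video file is written when the input was not already a path."""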
print("Generating Grounding DINO + SAM2 masks...")
if self.grounding_model is None or self.sam_predictor is None:
self._initialize_segmentation_models()
if self.grounding_model is None or self.sam_predictor is None:
raise ValueError("GroundingDINO or SAM2 models not available")
temp_video_path = None
if video_path is None or not isinstance(video_path, str):
temp_video_path = self._create_temp_video(video_tensor)
video_path = temp_video_path
CHUNK = 5
classes_ls = [
categorical_keywords[i : i + CHUNK]
for i in range(0, len(categorical_keywords), CHUNK)
]
base_name = Path(video_path).stem
fps_tag = f"fps{int(self.target_fps)}"
path_hash = hashlib.md5(video_path.encode("utf-8")).hexdigest()[:8]
video_cache_name = f"{base_name}_{fps_tag}_{path_hash}"
video_segments, oid_class_pred, _ = generate_masks_grounding_dino(
self.grounding_model,
box_threshold,
text_threshold,
self.sam_predictor,
self.mask_generator,
video_tensor,
video_path,
video_cache_name,
out_dir=tempfile.gettempdir(),
classes_ls=classes_ls,
target_fps=self.target_fps,
visualize=self.debug_visualizations,
frames=None,
max_prop_time=2,
)
masks: Dict[int, Dict[int, torch.Tensor]] = {}
bboxes: Dict[int, Dict[int, List[int]]] = {}
for frame_id, frame_masks in video_segments.items():
masks[frame_id] = {}
bboxes[frame_id] = {}
for obj_id, mask in frame_masks.items():
if not isinstance(mask, torch.Tensor):
mask = torch.tensor(mask)
masks[frame_id][obj_id] = mask
                mask_np = mask.detach().cpu().numpy()
if mask_np.ndim == 3 and mask_np.shape[0] == 1:
mask_np = np.squeeze(mask_np, axis=0)
coords = np.where(mask_np > 0)
if len(coords[0]) > 0:
y1, y2 = coords[0].min(), coords[0].max()
x1, x2 = coords[1].min(), coords[1].max()
bboxes[frame_id][obj_id] = [x1, y1, x2, y2]
if temp_video_path and os.path.exists(temp_video_path):
os.remove(temp_video_path)
tracked_masks, tracked_bboxes = self._track_ids_across_frames(masks, bboxes)
vis_data: Dict[str, Any] = {
"sam_masks": tracked_masks,
"dino_labels": oid_class_pred,
}
return tracked_masks, tracked_bboxes, vis_data
# ------------------------------------------------------------------ #
# ID tracking across frames
# ------------------------------------------------------------------ #
def _bbox_iou(self, box1: List[int], box2: List[int]) -> float:
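        """Intersection-over-union of two [x1, y1, x2, y2] boxes."""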
x1, y1, x2, y2 = box1
x1b, y1b, x2b, y2b = box2
ix1 = max(x1, x1b)
iy1 = max(y1, y1b)
ix2 = min(x2, x2b)
iy2 = min(y2, y2b)
iw = max(0, ix2 - ix1)
ih = max(0, iy2 - iy1)
inter = iw * ih
if inter <= 0:
return 0.0
area1 = max(0, x2 - x1) * max(0, y2 - y1)
area2 = max(0, x2b - x1b) * max(0, y2b - y1b)
union = area1 + area2 - inter
if union <= 0:
return 0.0
return inter / union
def _track_ids_across_frames(
self,
masks: Dict[int, Dict[int, torch.Tensor]],
bboxes: Dict[int, Dict[int, List[int]]],
iou_threshold: float = 0.3,
) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]]]:
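        """Greedy IoU matching between consecutive frames: detections are matched to
        the previous frame's tracks in descending IoU order (above `iou_threshold`),
        unmatched detections start new track IDs, and frames with no detections reset
        the active tracks."""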
frame_ids = sorted(masks.keys())
tracked_masks: Dict[int, Dict[int, torch.Tensor]] = {}
tracked_bboxes: Dict[int, Dict[int, List[int]]] = {}
next_track_id = 0
prev_tracks: Dict[int, List[int]] = {}
for frame_id in frame_ids:
frame_masks = masks.get(frame_id, {})
frame_boxes = bboxes.get(frame_id, {})
tracked_masks[frame_id] = {}
tracked_bboxes[frame_id] = {}
if not frame_boxes:
prev_tracks = {}
continue
det_ids = list(frame_boxes.keys())
prev_ids = list(prev_tracks.keys())
candidates: List[Tuple[float, int, int]] = []
for tid in prev_ids:
prev_box = prev_tracks[tid]
for det_id in det_ids:
iou = self._bbox_iou(prev_box, frame_boxes[det_id])
if iou > iou_threshold:
candidates.append((iou, tid, det_id))
candidates.sort(reverse=True)
matched_prev = set()
matched_det = set()
for iou, tid, det_id in candidates:
if tid in matched_prev or det_id in matched_det:
continue
matched_prev.add(tid)
matched_det.add(det_id)
tracked_masks[frame_id][tid] = frame_masks[det_id]
tracked_bboxes[frame_id][tid] = frame_boxes[det_id]
for det_id in det_ids:
if det_id in matched_det:
continue
tid = next_track_id
next_track_id += 1
tracked_masks[frame_id][tid] = frame_masks[det_id]
tracked_bboxes[frame_id][tid] = frame_boxes[det_id]
prev_tracks = {
tid: tracked_bboxes[frame_id][tid]
for tid in tracked_bboxes[frame_id].keys()
}
return tracked_masks, tracked_bboxes
# ------------------------------------------------------------------ #
# Segmentation model initialization
# ------------------------------------------------------------------ #
def _initialize_segmentation_models(self):
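        """Lazily build SAM2 and GroundingDINO models if they were not injected."""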
if self.sam_predictor is None or self.mask_generator is None:
self._initialize_sam2_models()
if self.grounding_model is None:
self._initialize_grounding_dino_model()
def _initialize_sam2_models(self):
try:
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
except ImportError as e:
print(f"Warning: Could not import SAM2: {e}")
return
config_path, checkpoint_path = self._resolve_sam2_paths()
if self.sam_config_path is not None and not os.path.exists(config_path):
raise ValueError(f"SAM2 config path not found: {config_path}")
if self.sam_checkpoint_path is not None and not os.path.exists(checkpoint_path):
raise ValueError(f"SAM2 checkpoint path not found: {checkpoint_path}")
        if checkpoint_path is None or not os.path.exists(checkpoint_path):
print(f"Warning: SAM2 checkpoint not found at {checkpoint_path}")
print("SAM2 functionality will be unavailable")
return
try:
device = self._device
self.sam_predictor = build_sam2_video_predictor(
config_path, checkpoint_path, device=device
)
sam2_model = build_sam2(
config_path,
checkpoint_path,
device=device,
apply_postprocessing=False,
)
self.mask_generator = SAM2AutomaticMaskGenerator(
model=sam2_model,
points_per_side=32,
points_per_batch=32,
pred_iou_thresh=0.7,
stability_score_thresh=0.8,
crop_n_layers=2,
box_nms_thresh=0.6,
crop_n_points_downscale_factor=2,
min_mask_region_area=100,
use_m2m=True,
)
print("✓ SAM2 models initialized successfully")
        except Exception as e:
            raise ValueError(f"Failed to initialize SAM2 with custom paths: {e}") from e
def _initialize_grounding_dino_model(self):
try:
from groundingdino.util.inference import Model as gd_Model
except ImportError as e:
print(f"Warning: Could not import GroundingDINO: {e}")
return
config_path, checkpoint_path = self._resolve_grounding_dino_paths()
if self.gd_config_path is not None and not os.path.exists(config_path):
raise ValueError(f"GroundingDINO config path not found: {config_path}")
if self.gd_checkpoint_path is not None and not os.path.exists(checkpoint_path):
raise ValueError(
f"GroundingDINO checkpoint path not found: {checkpoint_path}"
)
        if not (
            config_path
            and checkpoint_path
            and os.path.exists(config_path)
            and os.path.exists(checkpoint_path)
        ):
print(
f"Warning: GroundingDINO models not found at {config_path} / {checkpoint_path}"
)
print("GroundingDINO functionality will be unavailable")
return
try:
device = self._device
self.grounding_model = gd_Model(
model_config_path=config_path,
model_checkpoint_path=checkpoint_path,
device=device,
)
print("✓ GroundingDINO model initialized successfully")
        except Exception as e:
            raise ValueError(f"Failed to initialize GroundingDINO with custom paths: {e}") from e
    def _resolve_sam2_paths(self):
        """Return (config_path, checkpoint_path) for SAM2; either may be None."""
        return self.sam_config_path, self.sam_checkpoint_path
    def _resolve_grounding_dino_paths(self):
        """Return (config_path, checkpoint_path) for GroundingDINO; either may be None."""
        return self.gd_config_path, self.gd_checkpoint_path
# ------------------------------------------------------------------ #
# Video writing helpers
# ------------------------------------------------------------------ #
def _prepare_visualization_dir(self, name: str, enabled: bool) -> Optional[str]:
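        """Return a directory for visualization output: a named subdirectory of
        `visualization_dir` when one is configured, otherwise a fresh temporary
        directory; None when visualization is disabled."""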
if not enabled:
return None
if self.visualization_dir:
target_dir = (
os.path.join(self.visualization_dir, name)
if name
else self.visualization_dir
)
os.makedirs(target_dir, exist_ok=True)
return target_dir
return tempfile.mkdtemp(prefix=f"vine_{name}_")
def _create_temp_video(
self,
video_tensor: np.ndarray,
base_dir: Optional[str] = None,
prefix: str = "temp_video",
) -> str:
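        """Write `video_tensor` to a temporary .mp4, preferring FFmpeg (H.264) and
        falling back to OpenCV's VideoWriter when FFmpeg is unavailable or fails."""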
if base_dir is None:
base_dir = tempfile.mkdtemp(prefix=f"vine_{prefix}_")
else:
os.makedirs(base_dir, exist_ok=True)
file_name = f"{prefix}_{uuid.uuid4().hex}.mp4"
temp_path = os.path.join(base_dir, file_name)
height, width = video_tensor.shape[1:3]
processing_fps = max(1, self.target_fps)
output_fps = processing_fps
video_tensor_for_output = video_tensor
ffmpeg_success = False
try:
ffmpeg_success = self._create_video_with_ffmpeg(
video_tensor_for_output, temp_path, output_fps, width, height
)
except Exception as e:
print(f"FFmpeg method failed: {e}")
if not ffmpeg_success:
print("Using OpenCV fallback")
self._create_temp_video_opencv(
video_tensor_for_output, temp_path, output_fps, width, height
)
return temp_path
def _create_video_with_ffmpeg(
self, video_tensor: np.ndarray, output_path: str, fps: int, width: int, height: int
) -> bool:
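        """Pipe raw RGB24 frames to the `ffmpeg` CLI via stdin and encode them as an
        H.264 .mp4; returns True on success, False if FFmpeg is missing or errors."""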
import subprocess
try:
ffmpeg_cmd = [
"ffmpeg",
"-y",
"-f",
"rawvideo",
"-vcodec",
"rawvideo",
"-s",
f"{width}x{height}",
"-pix_fmt",
"rgb24",
"-r",
str(fps),
"-i",
"pipe:0",
"-c:v",
"libx264",
"-preset",
"fast",
"-crf",
"23",
"-pix_fmt",
"yuv420p",
"-movflags",
"+faststart",
"-loglevel",
"error",
output_path,
]
process = subprocess.Popen(
ffmpeg_cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
frame_data = b""
for frame in video_tensor:
if frame.dtype != np.uint8:
frame = (
(frame * 255).astype(np.uint8)
if frame.max() <= 1
else frame.astype(np.uint8)
)
frame_data += frame.tobytes()
stdout, stderr = process.communicate(input=frame_data, timeout=60)
if process.returncode == 0:
print(f"Video created with FFmpeg (H.264) at {fps} FPS")
return True
else:
error_msg = stderr.decode() if stderr else "Unknown error"
print(f"FFmpeg error: {error_msg}")
return False
except FileNotFoundError:
print("FFmpeg not found in PATH")
return False
except Exception as e:
print(f"FFmpeg exception: {e}")
return False
def _create_temp_video_opencv(
self, video_tensor: np.ndarray, temp_path: str, fps: int, width: int, height: int
) -> str:
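        """OpenCV fallback writer: try the avc1 / X264 / mp4v codecs in order, convert
        frames to BGR uint8, and write them with cv2.VideoWriter."""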
codecs_to_try = ["avc1", "X264", "mp4v"]
out = None
used_codec = None
for codec in codecs_to_try:
try:
fourcc = cv2.VideoWriter_fourcc(*codec)
temp_out = cv2.VideoWriter(temp_path, fourcc, fps, (width, height))
if temp_out.isOpened():
out = temp_out
used_codec = codec
break
else:
temp_out.release()
except Exception as e:
print(f"Warning: Codec {codec} not available: {e}")
continue
if out is None or not out.isOpened():
raise RuntimeError(
f"Failed to initialize VideoWriter with any codec. Tried: {codecs_to_try}"
)
print(f"Using OpenCV with codec: {used_codec}")
for frame in video_tensor:
if len(frame.shape) == 3 and frame.shape[2] == 3:
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
else:
frame_bgr = frame
if frame_bgr.dtype != np.uint8:
frame_bgr = (
(frame_bgr * 255).astype(np.uint8)
if frame_bgr.max() <= 1
else frame_bgr.astype(np.uint8)
)
out.write(frame_bgr)
out.release()
return temp_path
# ------------------------------------------------------------------ #
# Forward + postprocess
# ------------------------------------------------------------------ #
def _forward(self, model_inputs: Dict[str, Any], **forward_kwargs: Any) -> Dict[str, Any]:
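        """Run VINE prediction on the segmented video and carry frames, masks, bboxes,
        and visualization data through to postprocessing."""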
outputs = self.model.predict(
video_frames=model_inputs["video_frames"],
masks=model_inputs["masks"],
bboxes=model_inputs["bboxes"],
**forward_kwargs,
)
outputs.setdefault("video_frames", model_inputs.get("video_frames"))
outputs.setdefault("bboxes", model_inputs.get("bboxes"))
outputs.setdefault("masks", model_inputs.get("masks"))
outputs.setdefault("visualization_data", model_inputs.get("visualization_data"))
return outputs
def postprocess(
self,
model_outputs: Dict[str, Any],
return_top_k: int = 3,
visualize: Optional[bool] = None,
**kwargs: Any,
) -> Dict[str, Any]:
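        """Collect categorical / unary / binary predictions, build a per-object summary,
        and optionally render annotated frames and visualization videos."""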
results: Dict[str, Any] = {
"categorical_predictions": model_outputs.get("categorical_predictions", {}),
"unary_predictions": model_outputs.get("unary_predictions", {}),
"binary_predictions": model_outputs.get("binary_predictions", {}),
"confidence_scores": model_outputs.get("confidence_scores", {}),
"summary": self._generate_summary(model_outputs),
}
print("\n" + "=" * 50)
print("DEBUG: Raw Model Outputs - Categorical Predictions")
cat_preds = model_outputs.get("categorical_predictions", {})
for obj_id, preds in cat_preds.items():
print(f"Object {obj_id}: {preds}")
print("=" * 50 + "\n")
if "flattened_segments" in model_outputs:
results["flattened_segments"] = model_outputs["flattened_segments"]
if "valid_pairs" in model_outputs:
results["valid_pairs"] = model_outputs["valid_pairs"]
if "valid_pairs_metadata" in model_outputs:
results["valid_pairs_metadata"] = model_outputs["valid_pairs_metadata"]
if "visualization_data" in model_outputs:
results["visualization_data"] = model_outputs["visualization_data"]
        if visualize is None:
            visualize = self.visualize
        if visualize and "video_frames" in model_outputs and "bboxes" in model_outputs:
frames_tensor = model_outputs["video_frames"]
if isinstance(frames_tensor, torch.Tensor):
frames_np = frames_tensor.detach().cpu().numpy()
else:
frames_np = np.asarray(frames_tensor)
if frames_np.dtype != np.uint8:
if np.issubdtype(frames_np.dtype, np.floating):
max_val = frames_np.max() if frames_np.size else 0.0
scale = 255.0 if max_val <= 1.0 else 1.0
frames_np = (frames_np * scale).clip(0, 255).astype(np.uint8)
else:
frames_np = frames_np.clip(0, 255).astype(np.uint8)
cat_label_lookup: Dict[int, Tuple[str, float]] = {}
for obj_id, preds in model_outputs.get("categorical_predictions", {}).items():
if preds:
prob, label = preds[0]
cat_label_lookup[obj_id] = (label, prob)
unary_preds = model_outputs.get("unary_predictions", {})
unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]] = {}
for (frame_id, obj_id), preds in unary_preds.items():
if preds:
unary_lookup.setdefault(frame_id, {})[obj_id] = preds[:1]
binary_preds = model_outputs.get("binary_predictions", {})
binary_lookup: Dict[
int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]
] = {}
for (frame_id, obj_pair), preds in binary_preds.items():
if preds:
binary_lookup.setdefault(frame_id, []).append((obj_pair, preds[:1]))
bboxes = model_outputs["bboxes"]
visualization_data = model_outputs.get("visualization_data", {})
visualizations: Dict[str, Dict[str, Any]] = {}
debug_visualizations = kwargs.get("debug_visualizations")
if debug_visualizations is None:
debug_visualizations = self.debug_visualizations
vine_frame_sets = render_vine_frame_sets(
frames_np,
bboxes,
cat_label_lookup,
unary_lookup,
binary_lookup,
visualization_data.get("sam_masks"),
)
vine_visuals: Dict[str, Dict[str, Any]] = {}
final_frames = vine_frame_sets.get("all", [])
if final_frames:
final_entry: Dict[str, Any] = {"frames": final_frames, "video_path": None}
final_dir = self._prepare_visualization_dir(
"all", enabled=self.visualize
)
final_entry["video_path"] = self._create_temp_video(
np.stack(final_frames, axis=0),
base_dir=final_dir,
prefix="all_visualization",
)
vine_visuals["all"] = final_entry
if debug_visualizations:
sam_masks = visualization_data.get("sam_masks")
if sam_masks:
sam_frames = render_sam_frames(
frames_np, sam_masks, visualization_data.get("dino_labels")
)
sam_entry = {"frames": sam_frames, "video_path": None}
if sam_frames:
sam_dir = self._prepare_visualization_dir(
"sam", enabled=self.visualize
)
sam_entry["video_path"] = self._create_temp_video(
np.stack(sam_frames, axis=0),
base_dir=sam_dir,
prefix="sam_visualization",
)
visualizations["sam"] = sam_entry
dino_labels = visualization_data.get("dino_labels")
if dino_labels:
dino_frames = render_dino_frames(frames_np, bboxes, dino_labels)
dino_entry = {"frames": dino_frames, "video_path": None}
if dino_frames:
dino_dir = self._prepare_visualization_dir(
"dino", enabled=self.visualize
)
dino_entry["video_path"] = self._create_temp_video(
np.stack(dino_frames, axis=0),
base_dir=dino_dir,
prefix="dino_visualization",
)
visualizations["dino"] = dino_entry
for name in ("object", "unary", "binary"):
frames_list = vine_frame_sets.get(name, [])
entry: Dict[str, Any] = {"frames": frames_list, "video_path": None}
if frames_list:
vine_dir = self._prepare_visualization_dir(
name, enabled=self.visualize
)
entry["video_path"] = self._create_temp_video(
np.stack(frames_list, axis=0),
base_dir=vine_dir,
prefix=f"{name}_visualization",
)
vine_visuals[name] = entry
if vine_visuals:
visualizations["vine"] = vine_visuals
if visualizations:
results["visualizations"] = visualizations
return results
# ------------------------------------------------------------------ #
# Summary JSON
# ------------------------------------------------------------------ #
def _generate_summary(self, model_outputs: Dict[str, Any]) -> Dict[str, Any]:
"""
Per-object summary:
{
"num_objects_detected": N,
"objects": {
"<obj_id>": {
"top_categories": [{"label": str, "probability": float}, ...],
"top_unary": [{"frame_id": int, "predicate": str, "probability": float}, ...],
}
}
}
"""
categorical_preds = model_outputs.get("categorical_predictions", {})
unary_preds = model_outputs.get("unary_predictions", {})
unary_by_obj: Dict[int, List[Tuple[float, str, int]]] = {}
for (frame_id, obj_id), preds in unary_preds.items():
for prob, predicate in preds:
prob_val = (
float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
)
unary_by_obj.setdefault(obj_id, []).append((prob_val, predicate, frame_id))
objects_summary: Dict[str, Dict[str, Any]] = {}
all_obj_ids = set(categorical_preds.keys()) | set(unary_by_obj.keys())
for obj_id in sorted(all_obj_ids):
cat_list = categorical_preds.get(obj_id, [])
cat_sorted = sorted(
[
(
float(p.detach().cpu()) if torch.is_tensor(p) else float(p),
label,
)
for p, label in cat_list
],
key=lambda x: x[0],
reverse=True,
)[:3]
top_categories = [
{"label": label, "probability": prob} for prob, label in cat_sorted
]
unary_list = unary_by_obj.get(obj_id, [])
unary_sorted = sorted(unary_list, key=lambda x: x[0], reverse=True)[:3]
top_unary = [
{
"frame_id": int(frame_id),
"predicate": predicate,
"probability": prob,
}
for (prob, predicate, frame_id) in unary_sorted
]
objects_summary[str(obj_id)] = {
"top_categories": top_categories,
"top_unary": top_unary,
}
summary = {
"num_objects_detected": len(objects_summary),
"objects": objects_summary,
}
return summary