from __future__ import annotations

from collections import defaultdict
from itertools import permutations
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch

MaskType = Union[np.ndarray, torch.Tensor]


def _to_numpy_mask(mask: MaskType) -> np.ndarray:
    """
    Convert assorted mask formats to a 2D numpy boolean array.

    Accepts a ``torch.Tensor`` (any device or dtype) or anything
    ``np.asarray`` understands. Leading and trailing singleton dimensions are
    squeezed away, e.g. ``(1, H, W)``, ``(H, W, 1)`` or ``(1, H, W, 1, 1)``.
    Squeezing stops at 2D. Non-zero entries become ``True``.

    Raises:
        ValueError: if the mask is not 2D after squeezing singleton axes.
    """
    if isinstance(mask, torch.Tensor):
        # Binarize before leaving the tensor's device. A bool tensor is a
        # smaller host copy, and this also covers dtypes numpy cannot hold
        # (e.g. bfloat16, where ``.numpy()`` raises). Nonzero -> True, same
        # as numpy's ``astype(bool)``.
        mask_np = mask.detach().to(torch.bool).cpu().numpy()
    else:
        mask_np = np.asarray(mask)

    # Remove singleton dimensions at the front, then at the back, but never
    # go below 2D (a genuine 1xW or Hx1 mask stays 2D).
    while mask_np.ndim > 2 and mask_np.shape[0] == 1:
        mask_np = np.squeeze(mask_np, axis=0)
    while mask_np.ndim > 2 and mask_np.shape[-1] == 1:
        mask_np = np.squeeze(mask_np, axis=-1)

    if mask_np.ndim != 2:
        raise ValueError(f"Expected mask to be 2D after squeezing, got shape {mask_np.shape}")
    # astype() copies by default, so the result never aliases caller memory.
    return mask_np.astype(bool)


def _mask_to_bbox(mask: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """
    Compute an inclusive bounding box for a 2D boolean mask.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as plain Python ints (x indexes
        columns, y indexes rows), or ``None`` if no pixel is set.
    """
    if not mask.any():
        return None
    rows, cols = np.nonzero(mask)
    # Cast numpy scalars to int so the tuple matches the annotation and is
    # safe to serialize (e.g. to JSON).
    return int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())


def flatten_segments_for_batch(
    video_id: int,
    segments: Dict[int, Dict[int, MaskType]],
    bbox_min_dim: int = 5,
) -> Dict[str, List]:
    """
    Flatten nested segmentation data into batched lists suitable for
    predicate models or downstream visualizations.

    Mirrors the notebook helper but is robust to differing mask dtypes/shapes.

    Args:
        video_id: Identifier stamped onto every emitted object id and pair.
        segments: ``{frame_id: {object_id: mask}}``. Each mask may be a numpy
            array or torch tensor that is 2D after squeezing singleton axes.
        bbox_min_dim: An object is dropped when ``y_max - y_min`` or
            ``x_max - x_min`` of its bounding box is below this value. The
            extent is max minus min, i.e. one less than the pixel span.
            Objects with empty masks are always dropped.

    Returns:
        A dict with three parallel lists and one pair list:

        - ``"object_ids"``: ``(video_id, frame_id, object_id)``.
        - ``"masks"``: 2D bool arrays.
        - ``"bboxes"``: ``(x_min, y_min, x_max, y_max)``.
        - ``"pairs"``: every ordered ``(video_id, frame_id, (i, j))`` with
          ``i != j`` among the kept objects of the same frame.

    Raises:
        ValueError: if a mask is not 2D after squeezing singleton axes.
    """
    batched_object_ids: List[Tuple[int, int, int]] = []
    batched_masks: List[np.ndarray] = []
    batched_bboxes: List[Tuple[int, int, int, int]] = []
    frame_pairs: List[Tuple[int, int, Tuple[int, int]]] = []

    for frame_id, frame_objects in segments.items():
        valid_objects: List[int] = []
        for object_id, raw_mask in frame_objects.items():
            mask = _to_numpy_mask(raw_mask)
            bbox = _mask_to_bbox(mask)
            if bbox is None:
                continue
            x_min, y_min, x_max, y_max = bbox
            # max >= min always holds for a non-empty mask, so no abs() needed.
            if y_max - y_min < bbox_min_dim or x_max - x_min < bbox_min_dim:
                continue
            valid_objects.append(object_id)
            batched_object_ids.append((video_id, frame_id, object_id))
            batched_masks.append(mask)
            batched_bboxes.append(bbox)

        # Same order as a nested i/j loop that skips i == j. Object ids are
        # dict keys (unique), so skipping by position equals skipping by value.
        frame_pairs.extend(
            (video_id, frame_id, pair) for pair in permutations(valid_objects, 2)
        )

    return {
        "object_ids": batched_object_ids,
        "masks": batched_masks,
        "bboxes": batched_bboxes,
        "pairs": frame_pairs,
    }


def extract_valid_object_pairs(
    batched_object_ids: Sequence[Tuple[int, int, int]],
    interested_object_pairs: Optional[Iterable[Tuple[int, int]]] = None,
) -> List[Tuple[int, int, Tuple[int, int]]]:
    """
    Filter object pairs per frame.

    If `interested_object_pairs` is provided (even if empty), only emit those
    combinations when both objects are present in a frame. If it is ``None``,
    emit all permutations (i, j) with i != j for each frame.

    Args:
        batched_object_ids: ``(video_id, frame_id, object_id)`` triples, such
            as the ``"object_ids"`` list from ``flatten_segments_for_batch``.
        interested_object_pairs: Optional ``(src, dst)`` pairs to keep. It is
            materialized once, so a one-shot iterator is fine.

    Returns:
        ``(video_id, frame_id, (src, dst))`` triples. They are grouped by
        frame, in order of each frame's first appearance in
        ``batched_object_ids``.
    """
    frame_to_objects: Dict[Tuple[int, int], set] = defaultdict(set)
    for vid, fid, oid in batched_object_ids:
        frame_to_objects[(vid, fid)].add(oid)

    # Materialize once: a generator would otherwise be exhausted after the
    # first frame.
    interested = (
        list(interested_object_pairs) if interested_object_pairs is not None else None
    )

    valid_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
    for (vid, fid), object_ids in frame_to_objects.items():
        # Compare against None rather than testing truthiness. An explicit but
        # empty filter must yield no pairs, not fall back to all pairs.
        if interested is not None:
            valid_pairs.extend(
                (vid, fid, (src, dst))
                for src, dst in interested
                if src in object_ids and dst in object_ids
            )
        else:
            valid_pairs.extend((vid, fid, pair) for pair in permutations(object_ids, 2))
    return valid_pairs