Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| from collections import defaultdict | |
| from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union | |
| import numpy as np | |
| import torch | |
# Mask inputs accepted by the helpers in this module: numpy arrays or torch
# tensors. `_to_numpy_mask` converts either kind to a 2D boolean numpy array.
MaskType = Union[np.ndarray, torch.Tensor]
def _to_numpy_mask(mask: MaskType) -> np.ndarray:
    """
    Convert assorted mask formats to a 2D numpy boolean array.

    Args:
        mask: A torch tensor, numpy array, or anything ``np.asarray`` accepts.
            Values are binarized with "nonzero means True". Leading and
            trailing singleton dimensions (e.g. batch/channel axes of size 1)
            are dropped.

    Returns:
        A freshly allocated 2D ``bool`` array.

    Raises:
        ValueError: If the mask is not 2D after dropping singleton dimensions
            from both ends.
    """
    if isinstance(mask, torch.Tensor):
        # Binarize on the torch side before converting: ``Tensor.numpy()``
        # rejects dtypes numpy has no equivalent for (e.g. bfloat16), and a
        # bool tensor is also the cheapest thing to copy off the device.
        mask_np = mask.detach().to(torch.bool).cpu().numpy()
    else:
        mask_np = np.asarray(mask)
    # Remove singleton dimensions at the front ...
    while mask_np.ndim > 2 and mask_np.shape[0] == 1:
        mask_np = np.squeeze(mask_np, axis=0)
    # ... and at the back. This is a loop too, so e.g. (H, W, 1, 1) also works.
    while mask_np.ndim > 2 and mask_np.shape[-1] == 1:
        mask_np = np.squeeze(mask_np, axis=-1)
    if mask_np.ndim != 2:
        raise ValueError(f"Expected mask to be 2D after squeezing, got shape {mask_np.shape}")
    # astype copies, so the result never aliases the caller's array/tensor.
    return mask_np.astype(bool)
def _mask_to_bbox(mask: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """
    Compute the tight bounding box of the foreground of a 2D mask.

    Args:
        mask: A 2D array; every nonzero/True element counts as foreground.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as plain Python ints, where x indexes
        columns and y indexes rows and both bounds are inclusive, or ``None``
        if the mask has no foreground.

    Raises:
        ValueError: If ``mask`` is not 2D.
    """
    if mask.ndim != 2:
        raise ValueError(f"Expected a 2D mask, got shape {mask.shape}")
    # Project onto each axis rather than materializing the coordinates of
    # every foreground pixel with np.nonzero.
    rows = np.flatnonzero(mask.any(axis=1))
    if rows.size == 0:
        return None
    cols = np.flatnonzero(mask.any(axis=0))
    # flatnonzero is ascending, so the ends are the min/max. Cast to int:
    # numpy integer scalars don't JSON-serialize and contradict the
    # declared Tuple[int, ...] return type.
    return int(cols[0]), int(rows[0]), int(cols[-1]), int(rows[-1])
def flatten_segments_for_batch(
    video_id: int,
    segments: Dict[int, Dict[int, MaskType]],
    bbox_min_dim: int = 5,
) -> Dict[str, List]:
    """
    Flatten ``{frame_id: {object_id: mask}}`` into parallel per-object lists.

    Each mask is normalized with ``_to_numpy_mask`` and boxed with
    ``_mask_to_bbox``. An object is kept only if its mask is non-empty and its
    box spans at least ``bbox_min_dim`` on both axes, where a span is
    ``max - min`` of the foreground pixel coordinates. For every frame, each
    ordered pair of distinct kept objects is recorded as well. Mirrors the
    notebook helper while tolerating differing mask dtypes/shapes.

    Args:
        video_id: Value stamped into every emitted id tuple and pair.
        segments: Masks keyed by frame id, then by object id.
        bbox_min_dim: Minimum box span required on both the x and y axes.

    Returns:
        A dict with keys ``"object_ids"`` (``(video_id, frame_id, object_id)``),
        ``"masks"`` (2D bool arrays), ``"bboxes"`` (``(x_min, y_min, x_max,
        y_max)``) and ``"pairs"`` (``(video_id, frame_id, (src_id, dst_id))``).
        The first three lists are index-aligned.
    """
    ids_out: List[Tuple[int, int, int]] = []
    masks_out: List[np.ndarray] = []
    boxes_out: List[Tuple[int, int, int, int]] = []
    pairs_out: List[Tuple[int, int, Tuple[int, int]]] = []

    for frame_id, objects_in_frame in segments.items():
        kept_ids: List[int] = []

        for object_id, raw_mask in objects_in_frame.items():
            bool_mask = _to_numpy_mask(raw_mask)
            box = _mask_to_bbox(bool_mask)
            if box is None:
                # Empty mask: nothing to box.
                continue
            left, top, right, bottom = box
            if abs(bottom - top) < bbox_min_dim or abs(right - left) < bbox_min_dim:
                # Too thin in at least one direction.
                continue
            kept_ids.append(object_id)
            ids_out.append((video_id, frame_id, object_id))
            masks_out.append(bool_mask)
            boxes_out.append(box)

        # Every ordered (src, dst) combination of distinct survivors.
        pairs_out.extend(
            (video_id, frame_id, (src, dst))
            for src in kept_ids
            for dst in kept_ids
            if src != dst
        )

    return {
        "object_ids": ids_out,
        "masks": masks_out,
        "bboxes": boxes_out,
        "pairs": pairs_out,
    }
def extract_valid_object_pairs(
    batched_object_ids: Sequence[Tuple[int, int, int]],
    interested_object_pairs: Optional[Iterable[Tuple[int, int]]] = None,
) -> List[Tuple[int, int, Tuple[int, int]]]:
    """
    Build ``(video_id, frame_id, (src_id, dst_id))`` object pairs per frame.

    Frames are visited in the order they first appear in
    ``batched_object_ids``, and objects within a frame in the order they first
    appear. The unfiltered output therefore matches the ``"pairs"`` list that
    ``flatten_segments_for_batch`` builds for the same object ids.

    Args:
        batched_object_ids: ``(video_id, frame_id, object_id)`` triples;
            repeated triples are ignored.
        interested_object_pairs: Optional ``(src_id, dst_id)`` pairs (any
            iterable, including a one-shot iterator).
            - If it contains at least one pair: for each frame, only those
              pairs whose two objects are both present are emitted, in the
              given order. A pair with ``src_id == dst_id`` is emitted as
              given.
            - If it is ``None`` or empty: every ordered pair of distinct
              objects in each frame is emitted.

    Returns:
        A list of ``(video_id, frame_id, (src_id, dst_id))`` tuples.
    """
    # Dict-with-None-values as an insertion-ordered set; a plain set would
    # iterate in hash order and shuffle the emitted pairs.
    frame_to_objects: Dict[Tuple[int, int], Dict[int, None]] = defaultdict(dict)
    for vid, fid, oid in batched_object_ids:
        frame_to_objects[(vid, fid)][oid] = None

    # Materialize once: the pairs are re-scanned for every frame, which would
    # exhaust a one-shot iterator after the first frame.
    interested = (
        list(interested_object_pairs)
        if interested_object_pairs is not None
        else None
    )

    valid_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
    for (vid, fid), object_ids in frame_to_objects.items():
        # Truthiness check: an empty `interested` list intentionally falls
        # through to the unfiltered branch (existing behavior, preserved).
        if interested:
            valid_pairs.extend(
                (vid, fid, (src, dst))
                for src, dst in interested
                if src in object_ids and dst in object_ids
            )
        else:
            valid_pairs.extend(
                (vid, fid, (src, dst))
                for src in object_ids
                for dst in object_ids
                if src != dst
            )
    return valid_pairs