LASER / vine_hf /flattening.py
ASethi04's picture
updates
f9a6349
raw
history blame
4.08 kB
from __future__ import annotations
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
# Mask inputs accepted by the helpers below: either a numpy array or a torch tensor.
MaskType = Union[np.ndarray, torch.Tensor]
def _to_numpy_mask(mask: MaskType) -> np.ndarray:
"""
Convert assorted mask formats to a 2D numpy boolean array.
"""
if isinstance(mask, torch.Tensor):
mask_np = mask.detach().cpu().numpy()
else:
mask_np = np.asarray(mask)
# Remove singleton dimensions at the front/back
while mask_np.ndim > 2 and mask_np.shape[0] == 1:
mask_np = np.squeeze(mask_np, axis=0)
if mask_np.ndim > 2 and mask_np.shape[-1] == 1:
mask_np = np.squeeze(mask_np, axis=-1)
if mask_np.ndim != 2:
raise ValueError(f"Expected mask to be 2D after squeezing, got shape {mask_np.shape}")
return mask_np.astype(bool)
def _mask_to_bbox(mask: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
"""
Compute a bounding box for a 2D boolean mask.
"""
if not mask.any():
return None
rows, cols = np.nonzero(mask)
y_min, y_max = rows.min(), rows.max()
x_min, x_max = cols.min(), cols.max()
return x_min, y_min, x_max, y_max
def flatten_segments_for_batch(
    video_id: int,
    segments: Dict[int, Dict[int, MaskType]],
    bbox_min_dim: int = 5,
) -> Dict[str, List]:
    """
    Flatten a ``{frame_id: {object_id: mask}}`` mapping into parallel batch lists.

    Each mask is normalised with ``_to_numpy_mask`` and boxed with
    ``_mask_to_bbox``. An object is dropped when its mask is empty, or when its
    box extent (``max - min``) along either axis is below ``bbox_min_dim``.
    Mirrors the notebook helper but tolerates differing mask dtypes/shapes.

    Args:
        video_id: Identifier stamped onto every emitted record.
        segments: Per-frame dictionaries mapping object ids to masks.
        bbox_min_dim: Minimum box extent, in pixels, required on both axes.

    Returns:
        A dict with keys:
          - ``"object_ids"``: ``(video_id, frame_id, object_id)`` per kept object;
          - ``"masks"``: the kept 2D boolean masks;
          - ``"bboxes"``: the kept ``(x_min, y_min, x_max, y_max)`` boxes;
          - ``"pairs"``: ``(video_id, frame_id, (i, j))`` for every ordered pair
            of distinct kept objects within the same frame.
        The first three lists are index-aligned.
    """
    object_ids_out: List[Tuple[int, int, int]] = []
    masks_out: List[np.ndarray] = []
    bboxes_out: List[Tuple[int, int, int, int]] = []
    pairs_out: List[Tuple[int, int, Tuple[int, int]]] = []

    for fid, objects in segments.items():
        kept: List[int] = []
        for oid, raw in objects.items():
            normalised = _to_numpy_mask(raw)
            box = _mask_to_bbox(normalised)
            if box is None:
                continue
            x0, y0, x1, y1 = box
            too_small = abs(y1 - y0) < bbox_min_dim or abs(x1 - x0) < bbox_min_dim
            if too_small:
                continue
            kept.append(oid)
            object_ids_out.append((video_id, fid, oid))
            masks_out.append(normalised)
            bboxes_out.append(box)

        # Every ordered (src, dst) combination of surviving objects, skipping self-pairs.
        pairs_out.extend(
            (video_id, fid, (src, dst))
            for src in kept
            for dst in kept
            if src != dst
        )

    return {
        "object_ids": object_ids_out,
        "masks": masks_out,
        "bboxes": bboxes_out,
        "pairs": pairs_out,
    }
def extract_valid_object_pairs(
    batched_object_ids: Sequence[Tuple[int, int, int]],
    interested_object_pairs: Optional[Iterable[Tuple[int, int]]] = None,
) -> List[Tuple[int, int, Tuple[int, int]]]:
    """
    Build per-frame object pairs from flattened ``(video_id, frame_id, object_id)`` records.

    Object ids are first grouped by ``(video_id, frame_id)``. Then:
      - if ``interested_object_pairs`` is a non-empty iterable, each ``(src, dst)``
        from it is emitted, in the given order, for every frame in which both
        objects are present;
      - otherwise (``None`` or empty), every ordered pair ``(src, dst)`` with
        ``src != dst`` is emitted for each frame.

    NOTE(review): because emptiness is tested by truthiness, an explicitly empty
    ``interested_object_pairs`` means "no filter" rather than "no pairs";
    confirm this is what callers expect.
    """
    # Collect the set of object ids present in each (video, frame).
    objects_by_frame: Dict[Tuple[int, int], set] = defaultdict(set)
    for video, frame, obj in batched_object_ids:
        objects_by_frame[(video, frame)].add(obj)

    # Materialise once so that a one-shot iterator can be reused for every frame.
    wanted = None if interested_object_pairs is None else list(interested_object_pairs)

    result: List[Tuple[int, int, Tuple[int, int]]] = []
    for (video, frame), present in objects_by_frame.items():
        if not wanted:
            result.extend(
                (video, frame, (a, b)) for a in present for b in present if a != b
            )
            continue
        result.extend(
            (video, frame, (a, b)) for a, b in wanted if a in present and b in present
        )
    return result