LASER / vine_hf /vis_utils.py
ASethi04's picture
updates
21f4849
raw
history blame
40 kB
import os
import sys
from pathlib import Path
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import random
import math
from matplotlib.patches import Rectangle
import itertools
from typing import Any, Dict, List, Tuple, Optional, Union
# Put the sibling ``src/`` directory on sys.path so that LASER, video-sam2
# and GroundingDINO can be imported by the statements that follow.
current_dir = Path(__file__).resolve().parent
src_dir = current_dir.parent / "src"
if src_dir.is_dir():
    # Only prepend once, even if this module is imported repeatedly.
    if str(src_dir) not in sys.path:
        sys.path.insert(0, str(src_dir))
from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox
########################################################################################
########## Visualization Library ########
########################################################################################
# This module renders SAM masks, GroundingDINO boxes, and VINE predictions.
#
# Conventions (RGB frames, pixel coords):
# - Frames: list[np.ndarray] with shape (H, W, 3) in RGB, or np.ndarray with shape (T, H, W, 3).
# - Masks: 2D boolean arrays (H, W) or tensors convertible to that; (H, W, 1) is also accepted.
# - BBoxes: (x1, y1, x2, y2) integer pixel coordinates with x2 > x1 and y2 > y1.
#
# Per-frame stores use one of:
# - Dict[int(frame_id) -> Dict[int(obj_id) -> value]]
# - List indexed by frame_id (each item may be a dict of obj_id->value or a list in order)
#
# Renderer inputs/outputs:
# 1) render_sam_frames(frames, sam_masks, dino_labels=None) -> List[np.ndarray]
# - sam_masks: Dict[frame_id, Dict[obj_id, Mask]] or a list; Mask can be np.ndarray or torch.Tensor.
# - dino_labels: Optional Dict[obj_id, str] to annotate boxes derived from masks.
#
# 2) render_dino_frames(frames, bboxes, dino_labels=None) -> List[np.ndarray]
# - bboxes: Dict[frame_id, Dict[obj_id, Sequence[float]]] or a list; each bbox as [x1, y1, x2, y2].
#
# 3) render_vine_frames(frames, bboxes, cat_label_lookup, unary_lookup, binary_lookup, masks=None,
#                       binary_confidence_threshold=0.0)
# -> List[np.ndarray] (the "all" view)
# - cat_label_lookup: Dict[obj_id, (label: str, prob: float)]
# - unary_lookup: Dict[frame_id, Dict[obj_id, List[(prob: float, label: str)]]]
# - binary_lookup: Dict[frame_id, List[((sub_id: int, obj_id: int), List[(prob: float, relation: str)])]]
# - masks: Optional; same structure as sam_masks, used for translucent overlays when unary labels exist.
# - binary_confidence_threshold: relations whose first listed probability is below this value are not drawn.
#
# Ground-truth helpers used by plotting utilities:
# - For a single frame, gt_relations is represented as List[(subject_label, object_label, relation_label)].
#
# All rendered frames returned by functions are RGB np.ndarray images suitable for saving or video writing.
########################################################################################
def clean_label(label):
    """Return ``label`` with every underscore and slash turned into a space."""
    cleaned = label
    for separator in ("_", "/"):
        cleaned = cleaned.replace(separator, " ")
    return cleaned
# NOTE: this regrouping arguably belongs with the model post-processing
# rather than in the visualization module.
def format_cate_preds(cate_preds):
    """Group category predictions per object, most probable first.

    Args:
        cate_preds: Mapping of ``(object_id, label) -> probability``.

    Returns:
        Dict mapping each object id to a list of ``(cleaned_label, prob)``
        pairs sorted by probability in descending order.
    """
    grouped = {}
    for (oid, label), prob in cate_preds.items():
        # Labels are normalised so they compare equal to cleaned GT labels.
        grouped.setdefault(oid, []).append((clean_label(label), prob))
    for ranked in grouped.values():
        ranked.sort(key=lambda pair: pair[1], reverse=True)
    return grouped
def format_binary_cate_preds(binary_preds):
    """Flatten binary relation predictions into tuples ranked by confidence.

    Args:
        binary_preds: Mapping whose keys look like
            ``(frame_id, (subject, object), predicted_relation)`` and whose
            values are the prediction scores.

    Returns:
        A list of ``(frame_id, subject, object, predicted_relation, score)``
        tuples sorted by ``score`` in descending order. Keys that do not
        match the expected layout are reported on stdout and skipped.
    """
    frame_binary_preds = []
    for key, score in binary_preds.items():
        try:
            f_id, (subj, obj), pred_rel = key
        except (TypeError, ValueError):
            # Non-iterable key or wrong arity: not a relation prediction.
            print("Skipping key with unexpected format:", key)
            continue
        frame_binary_preds.append((f_id, subj, obj, pred_rel, score))
    # Rank by the score (index 4). Index 3 is the relation name, so sorting
    # on it would order predictions alphabetically rather than by confidence.
    frame_binary_preds.sort(key=lambda pred: pred[4], reverse=True)
    return frame_binary_preds
# Font face shared by the OpenCV text-drawing helpers below.
_FONT = cv2.FONT_HERSHEY_SIMPLEX
def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.ndarray]:
if mask is None:
return None
if isinstance(mask, torch.Tensor):
mask_np = mask.detach().cpu().numpy()
else:
mask_np = np.asarray(mask)
if mask_np.ndim == 0:
return None
if mask_np.ndim == 3:
mask_np = np.squeeze(mask_np)
if mask_np.ndim != 2:
return None
if mask_np.dtype == bool:
return mask_np
return mask_np > 0
def _sanitize_bbox(
bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int
) -> Optional[Tuple[int, int, int, int]]:
if bbox is None:
return None
if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
x1, y1, x2, y2 = [float(b) for b in bbox[:4]]
elif isinstance(bbox, np.ndarray) and bbox.size >= 4:
x1, y1, x2, y2 = [float(b) for b in bbox.flat[:4]]
else:
return None
x1 = int(np.clip(round(x1), 0, width - 1))
y1 = int(np.clip(round(y1), 0, height - 1))
x2 = int(np.clip(round(x2), 0, width - 1))
y2 = int(np.clip(round(y2), 0, height - 1))
if x2 <= x1 or y2 <= y1:
return None
return (x1, y1, x2, y2)
def _object_color_bgr(obj_id: int) -> Tuple[int, int, int]:
    """Return the object's palette colour as 0-255 integers in BGR order."""
    palette_entry = get_color(obj_id)
    red, green, blue = (
        int(np.clip(channel, 0.0, 1.0) * 255) for channel in palette_entry[:3]
    )
    # get_color yields RGB(A); OpenCV drawing calls expect BGR.
    return (blue, green, red)
def _background_color(color: Tuple[int, int, int]) -> Tuple[int, int, int]:
return tuple(int(0.25 * 255 + 0.75 * channel) for channel in color)
def _draw_label_block(
image: np.ndarray,
lines: List[str],
anchor: Tuple[int, int],
color: Tuple[int, int, int],
font_scale: float = 0.5,
thickness: int = 1,
direction: str = "up",
) -> None:
if not lines:
return
img_h, img_w = image.shape[:2]
x, y = anchor
x = int(np.clip(x, 0, img_w - 1))
y_cursor = int(np.clip(y, 0, img_h - 1))
bg_color = _background_color(color)
if direction == "down":
for text in lines:
text = str(text)
(tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
left_x = x
right_x = min(left_x + tw + 8, img_w - 1)
top_y = int(np.clip(y_cursor + 6, 0, img_h - 1))
bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
if bottom_y <= top_y:
break
cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
text_x = left_x + 4
text_y = min(bottom_y - baseline - 2, img_h - 1)
cv2.putText(
image,
text,
(text_x, text_y),
_FONT,
font_scale,
(0, 0, 0),
thickness,
cv2.LINE_AA,
)
y_cursor = bottom_y
else:
for text in lines:
text = str(text)
(tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
top_y = max(y_cursor - th - baseline - 6, 0)
left_x = x
right_x = min(left_x + tw + 8, img_w - 1)
bottom_y = min(top_y + th + baseline + 6, img_h - 1)
cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
text_x = left_x + 4
text_y = min(bottom_y - baseline - 2, img_h - 1)
cv2.putText(
image,
text,
(text_x, text_y),
_FONT,
font_scale,
(0, 0, 0),
thickness,
cv2.LINE_AA,
)
y_cursor = top_y
def _draw_centered_label(
    image: np.ndarray,
    text: str,
    center: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
) -> None:
    """Draw one boxed text label roughly centred on ``center`` (in place).

    The box is filled with a lightened ``color`` and every coordinate is
    clamped to the image bounds.
    """
    caption = str(text)
    img_h, img_w = image.shape[:2]
    max_x, max_y = img_w - 1, img_h - 1
    (tw, th), baseline = cv2.getTextSize(caption, _FONT, font_scale, thickness)

    cx = int(np.clip(center[0], 0, max_x))
    cy = int(np.clip(center[1], 0, max_y))

    box_left = int(np.clip(cx - tw // 2 - 4, 0, max_x))
    box_top = int(np.clip(cy - th // 2 - baseline - 4, 0, max_y))
    box_right = int(np.clip(box_left + tw + 8, 0, max_x))
    box_bottom = int(np.clip(box_top + th + baseline + 6, 0, max_y))

    fill = _background_color(color)
    cv2.rectangle(image, (box_left, box_top), (box_right, box_bottom), fill, -1)

    text_origin = (box_left + 4, min(box_bottom - baseline - 2, max_y))
    cv2.putText(
        image,
        caption,
        text_origin,
        _FONT,
        font_scale,
        (0, 0, 0),
        thickness,
        cv2.LINE_AA,
    )
def _extract_frame_entities(
store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int
) -> Dict[int, Any]:
if isinstance(store, dict):
frame_entry = store.get(frame_idx, {})
elif isinstance(store, list) and 0 <= frame_idx < len(store):
frame_entry = store[frame_idx]
else:
frame_entry = {}
if isinstance(frame_entry, dict):
return frame_entry
if isinstance(frame_entry, list):
return {i: value for i, value in enumerate(frame_entry)}
return {}
def _label_anchor_and_direction(
bbox: Tuple[int, int, int, int],
position: str,
) -> Tuple[Tuple[int, int], str]:
x1, y1, x2, y2 = bbox
if position == "bottom":
return (x1, y2), "down"
return (x1, y1), "up"
def _draw_bbox_with_label(
    image: np.ndarray,
    bbox: Tuple[int, int, int, int],
    obj_id: int,
    title: Optional[str] = None,
    sub_lines: Optional[List[str]] = None,
    label_position: str = "top",
) -> None:
    """Draw an object's box and its label block on ``image`` in place.

    The first label line is ``#<obj_id>`` followed by ``title`` (a title
    that already starts with ``#`` is used verbatim). ``sub_lines`` are
    appended as further lines of the same block.
    """
    color = _object_color_bgr(obj_id)
    top_left = (bbox[0], bbox[1])
    bottom_right = (bbox[2], bbox[3])
    cv2.rectangle(image, top_left, bottom_right, color, 2)

    id_tag = f"#{obj_id}"
    if not title:
        head = id_tag
    elif title.startswith("#"):
        head = title
    else:
        head = f"{id_tag} {title}"

    lines = [head, *sub_lines] if sub_lines else [head]
    anchor, direction = _label_anchor_and_direction(bbox, label_position)
    _draw_label_block(image, lines, anchor, color, direction=direction)
def render_sam_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    sam_masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Overlay SAM masks and their derived boxes on every frame.

    Args:
        frames: RGB frames, as a list or an array iterated along axis 0.
            ``None`` entries are skipped (but still consume a frame index).
        sam_masks: Per-frame store of ``obj_id -> mask`` (dict keyed by frame
            index, or list indexed by it).
        dino_labels: Optional ``obj_id -> label`` used as box titles.

    Returns:
        One annotated RGB image per non-``None`` input frame.
    """
    results: List[np.ndarray] = []
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    dino_labels = dino_labels or {}
    alpha = 0.45  # opacity of the mask tint
    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_bgr = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
        overlay = frame_bgr.astype(np.float32)

        # Convert every mask exactly once; both passes below reuse the result
        # so tensors are not copied to the CPU twice per frame.
        visible_masks = []
        for obj_id, mask in _extract_frame_entities(sam_masks, frame_idx).items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is None or not np.any(mask_np):
                continue
            visible_masks.append((obj_id, mask_np))

        # Pass 1: blend all tints in float space before quantising.
        for obj_id, mask_np in visible_masks:
            tint = np.array(_object_color_bgr(obj_id), dtype=np.float32)
            overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * tint

        annotated = np.clip(overlay, 0, 255).astype(np.uint8)
        frame_h, frame_w = annotated.shape[:2]

        # Pass 2: boxes and labels go on top of the finished overlay.
        for obj_id, mask_np in visible_masks:
            bbox = _sanitize_bbox(mask_to_bbox(mask_np), frame_w, frame_h)
            if not bbox:
                continue
            label = dino_labels.get(obj_id)
            title = f"{label}" if label else None
            _draw_bbox_with_label(annotated, bbox, obj_id, title=title)

        results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    return results
def render_dino_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    """Draw the per-frame detection boxes (with optional labels) on RGB frames.

    ``None`` frames are skipped; boxes that fail sanitisation are ignored.
    Returns one annotated RGB image per rendered frame.
    """
    label_for = dino_labels or {}
    frame_list = frames if isinstance(frames, list) else list(frames)
    rendered: List[np.ndarray] = []

    for frame_idx, frame in enumerate(frame_list):
        if frame is None:
            continue
        canvas = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
        canvas_h, canvas_w = canvas.shape[:2]

        for obj_id, raw_bbox in _extract_frame_entities(bboxes, frame_idx).items():
            bbox = _sanitize_bbox(raw_bbox, canvas_w, canvas_h)
            if not bbox:
                continue
            label = label_for.get(obj_id)
            _draw_bbox_with_label(
                canvas, bbox, obj_id, title=f"{label}" if label else None
            )

        rendered.append(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

    return rendered
def render_vine_frame_sets(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[
        Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None
    ] = None,
    binary_confidence_threshold: float = 0.0,
) -> Dict[str, List[np.ndarray]]:
    """Render four annotated views of every frame.

    The views are ``"object"`` (boxes + category titles), ``"unary"`` (plus
    mask tint and unary lines below the box), ``"binary"`` (boxes plus
    relation arrows) and ``"all"`` (everything). ``None`` frames are skipped.

    For relations, only the first prediction of each pair is considered, it
    must reach ``binary_confidence_threshold``, and when both directions of
    a pair are present only the more confident one is drawn.

    Returns:
        Dict mapping each view name to its list of RGB images.
    """
    view_names = ("object", "unary", "binary", "all")
    frame_groups: Dict[str, List[np.ndarray]] = {name: [] for name in view_names}
    frame_list = frames if isinstance(frames, list) else list(frames)
    overlay_alpha = 0.45

    for frame_idx, frame in enumerate(frame_list):
        if frame is None:
            continue

        base_bgr = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
        frame_h, frame_w = base_bgr.shape[:2]
        frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
        frame_masks = (
            {} if masks is None else _extract_frame_entities(masks, frame_idx)
        )
        # One independent canvas per view, all starting from the same frame.
        canvases = {name: base_bgr.copy() for name in view_names}

        # --- Gather per-object drawing data --------------------------------
        bbox_lookup: Dict[int, Tuple[int, int, int, int]] = {}
        unary_lines_lookup: Dict[int, List[str]] = {}
        titles_lookup: Dict[int, Optional[str]] = {}
        for obj_id, raw_bbox in frame_bboxes.items():
            bbox = _sanitize_bbox(raw_bbox, frame_w, frame_h)
            if not bbox:
                continue
            bbox_lookup[obj_id] = bbox

            cat_label, cat_prob = cat_label_lookup.get(obj_id, (None, None))
            if not cat_label:
                titles_lookup[obj_id] = None
            elif cat_prob is None:
                titles_lookup[obj_id] = cat_label
            else:
                titles_lookup[obj_id] = f"{cat_label} {cat_prob:.2f}"

            unary_lines_lookup[obj_id] = [
                f"{label} {prob:.2f}"
                for prob, label in unary_lookup.get(frame_idx, {}).get(obj_id, [])
            ]

        # --- Tint masks of objects that carry unary predictions ------------
        for obj_id in bbox_lookup:
            if not unary_lines_lookup.get(obj_id, []):
                continue
            mask_np = _to_numpy_mask(frame_masks.get(obj_id))
            if mask_np is None or not np.any(mask_np):
                continue
            tint = np.array(_object_color_bgr(obj_id), dtype=np.float32)
            for name in ("unary", "all"):
                canvas = canvases[name]
                pixels = canvas[mask_np].astype(np.float32)
                mixed = (1.0 - overlay_alpha) * pixels + overlay_alpha * tint
                canvas[mask_np] = np.clip(mixed, 0, 255).astype(np.uint8)

        # --- Boxes, titles and unary label blocks --------------------------
        for obj_id, bbox in bbox_lookup.items():
            title = titles_lookup.get(obj_id)
            unary_lines = unary_lines_lookup.get(obj_id, [])
            for name in view_names:
                _draw_bbox_with_label(
                    canvases[name], bbox, obj_id, title=title, label_position="top"
                )
            if not unary_lines:
                continue
            anchor, direction = _label_anchor_and_direction(bbox, "bottom")
            for name in ("unary", "all"):
                _draw_label_block(
                    canvases[name],
                    unary_lines,
                    anchor,
                    _object_color_bgr(obj_id),
                    direction=direction,
                )

        # --- Select one relation per unordered object pair -----------------
        # Key: (min_id, max_id) -> (subj_id, obj_id, prob, relation)
        best_by_pair = {}
        for obj_pair, relation_preds in binary_lookup.get(frame_idx, []):
            if len(obj_pair) != 2 or not relation_preds:
                continue
            subj_id, obj_id = obj_pair
            if not bbox_lookup.get(subj_id) or not bbox_lookup.get(obj_id):
                continue
            prob, relation = relation_preds[0]
            if prob < binary_confidence_threshold:
                continue
            pair_key = (min(subj_id, obj_id), max(subj_id, obj_id))
            current = best_by_pair.get(pair_key)
            if current is None or prob > current[2]:
                best_by_pair[pair_key] = (subj_id, obj_id, prob, relation)

        # --- Draw the selected relations -----------------------------------
        for subj_id, obj_id, prob, relation in best_by_pair.values():
            start, end = relation_line(
                bbox_lookup.get(subj_id), bbox_lookup.get(obj_id)
            )
            # Arrow colour is the mean of the two object colours.
            mean_bgr = (
                np.array(_object_color_bgr(subj_id), dtype=np.float32)
                + np.array(_object_color_bgr(obj_id), dtype=np.float32)
            ) / 2.0
            color = tuple(int(c) for c in np.clip(mean_bgr, 0, 255))
            label_text = f"{relation} {prob:.2f}"
            mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
            for name in ("binary", "all"):
                # Arrow points from subject to object, with a small tip.
                cv2.arrowedLine(
                    canvases[name], start, end, color, 6, cv2.LINE_AA, tipLength=0.05
                )
                _draw_centered_label(canvases[name], label_text, mid_point, color)

        for name in view_names:
            frame_groups[name].append(
                cv2.cvtColor(canvases[name], cv2.COLOR_BGR2RGB)
            )

    return frame_groups
def render_vine_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[
        Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None
    ] = None,
    binary_confidence_threshold: float = 0.0,
) -> List[np.ndarray]:
    """Return only the combined ``"all"`` view of ``render_vine_frame_sets``."""
    frame_sets = render_vine_frame_sets(
        frames,
        bboxes,
        cat_label_lookup,
        unary_lookup,
        binary_lookup,
        masks,
        binary_confidence_threshold,
    )
    return frame_sets.get("all", [])
def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
    """Pick a box colour and caption per ground-truth object.

    Args:
        obj_pred_dict: ``obj_id -> [(label, prob), ...]`` ranked best first.
        gt_labels: Iterable of ``(obj_id, bbox, gt_label)``.
        topk_object: How many ranked predictions count as a partial match.

    Returns:
        ``(colors, texts)``: green when the top-1 label matches the ground
        truth (case-insensitive), orange when it only appears in the top-k,
        red otherwise or when there is no prediction.
    """
    green, orange, red = (0, 255, 0), (0, 165, 255), (0, 0, 255)
    all_colors, all_texts = [], []

    for obj_id, _bbox, gt_label in gt_labels:
        preds = obj_pred_dict.get(obj_id, [])
        if not preds:
            top1, box_color = "N/A", red
        else:
            top1, _top1_prob = preds[0]
            wanted = gt_label.lower()
            topk_lowered = {pred[0].lower() for pred in preds[:topk_object]}
            if top1.lower() == wanted:
                box_color = green
            elif wanted in topk_lowered:
                box_color = orange
            else:
                box_color = red
        all_colors.append(box_color)
        all_texts.append(f"ID:{obj_id}/P:{top1}/GT:{gt_label}")

    return all_colors, all_texts
def plot_unary(frame_img, gt_labels, all_colors, all_texts):
    """Draw each object's box and caption on ``frame_img`` and return it.

    ``gt_labels`` holds ``(obj_id, bbox, gt_label)`` triples; colours and
    captions are paired with them positionally. The image is edited in place.
    """
    for entry, box_color, label_text in zip(gt_labels, all_colors, all_texts):
        _obj_id, bbox, _gt_label = entry
        x1, y1, x2, y2 = map(int, bbox)
        cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)

        (tw, th), baseline = cv2.getTextSize(
            label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        )
        # Filled caption strip sitting directly on top of the box.
        strip_top_left = (x1, y1 - th - baseline - 4)
        strip_bottom_right = (x1 + tw, y1)
        cv2.rectangle(frame_img, strip_top_left, strip_bottom_right, box_color, -1)
        cv2.putText(
            frame_img,
            label_text,
            (x1, y1 - 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
    return frame_img
def get_white_pane(
    pane_height,
    pane_width=600,
    header_height=50,
    header_font=cv2.FONT_HERSHEY_SIMPLEX,
    header_font_scale=0.7,
    header_thickness=2,
    header_color=(0, 0, 0),
):
    """Create the white side pane with its two column headers.

    The pane is split 60% / 40%: the left column is headed
    "Binary Predictions" and the right one "Ground Truth".

    Returns:
        A ``(pane_height, pane_width, 3)`` uint8 image with the headers drawn.
    """
    white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)
    # Predictions column is wider (60% vs. 40%).
    left_width = int(pane_width * 0.6)
    header_baseline_y = header_height - 30
    # Draw straight onto the returned pane. The headers used to be written to
    # copies of the two columns, so they never appeared in the result.
    headers = (
        ("Binary Predictions", 10),
        ("Ground Truth", left_width + 10),
    )
    for header_text, header_x in headers:
        cv2.putText(
            white_pane,
            header_text,
            (header_x, header_baseline_y),
            header_font,
            header_font_scale,
            header_color,
            header_thickness,
            cv2.LINE_AA,
        )
    return white_pane
def plot_binary_sg(
    frame_img,
    white_pane,
    bin_preds,
    gt_relations,
    topk_binary,
    header_height=50,
    indicator_size=20,
    pane_width=600,
):
    """Plot binary predictions next to a frame as a text scene graph.

    Args:
        frame_img: Frame image; must have the same height as ``white_pane``.
        white_pane: Side pane image of width ``pane_width``.
        bin_preds: Ranked predictions. Each entry is either
            ``(subj, pred_rel, obj, score)`` or the 5-tuple
            ``(frame_id, subj, obj, pred_rel, score)`` produced by
            ``format_binary_cate_preds``.
        gt_relations: ``(subject, object, relation)`` triples for the frame.
        topk_binary: Number of predictions to list.
        header_height: Vertical space reserved for the pane headers.
        indicator_size: Side length of the correct/incorrect square.
        pane_width: Width of ``white_pane``; the left column takes 60% of it.

    Returns:
        ``frame_img`` with the annotated pane stacked to its right.
    """
    line_height = 30  # vertical spacing per line
    x_text = 10  # left margin for text
    y_text_left = header_height + 10
    y_text_right = header_height + 10

    left_width = int(pane_width * 0.6)
    left_pane = white_pane[:, :left_width, :].copy()
    right_pane = white_pane[:, left_width:, :].copy()

    # Left column: top-k predictions with a green/red correctness marker.
    for pred in bin_preds[:topk_binary]:
        if len(pred) == 5:
            # Layout emitted by format_binary_cate_preds.
            _frame_id, subj, obj, pred_rel, score = pred
        else:
            subj, pred_rel, obj, score = pred
        correct = any(
            len(gt) == 3
            and subj == gt[0]
            and obj == gt[1]
            and pred_rel.lower() == gt[2].lower()
            for gt in gt_relations
        )
        indicator_color = (0, 255, 0) if correct else (0, 0, 255)
        cv2.rectangle(
            left_pane,
            (x_text, y_text_left - indicator_size + 5),
            (x_text + indicator_size, y_text_left + 5),
            indicator_color,
            -1,
        )
        text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
        cv2.putText(
            left_pane,
            text,
            (x_text + indicator_size + 5, y_text_left + 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
        y_text_left += line_height

    # Right column: ground-truth relations (malformed entries are skipped).
    for gt in gt_relations:
        if len(gt) != 3:
            continue
        text = f"{gt[0]} - {gt[2]} - {gt[1]}"
        cv2.putText(
            right_pane,
            text,
            (x_text, y_text_right + 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
        y_text_right += line_height

    combined_pane = np.hstack((left_pane, right_pane))
    return np.hstack((frame_img, combined_pane))
def visualized_frame(
    frame_img,
    bboxes,
    object_ids,
    gt_labels,
    cate_preds,
    binary_preds,
    gt_relations,
    topk_object,
    topk_binary,
    phase="unary",
):
    """Return the annotated visualization of one frame.

    Args:
        frame_img: Frame image to annotate (drawn on in place for ``"unary"``).
        bboxes: Ground-truth boxes, aligned with ``object_ids``/``gt_labels``.
        object_ids: Triples whose last element is the object id.
        gt_labels: Ground-truth category label per object.
        cate_preds: ``(obj_id, label) -> prob`` category predictions.
        binary_preds: ``(frame_id, (subj, obj), relation) -> score``.
        gt_relations: ``(subject, object, relation)`` triples for the frame.
        topk_object: Top-k used for partial category matches.
        topk_binary: Number of binary predictions to list.
        phase: ``"unary"`` draws category boxes; anything else builds the
            binary-relation text pane.

    Returns:
        The annotated frame (unary) or the frame with the text pane appended
        on its right (binary).
    """
    if phase == "unary":
        # Pair every object with its box and cleaned GT label. The helpers
        # below expect these triples, not the raw label list.
        objs = [
            (obj_id, bbox, clean_label(gt_label))
            for (_, _f_id, obj_id), bbox, gt_label in zip(
                object_ids, bboxes, gt_labels
            )
        ]
        formatted_cate_preds = format_cate_preds(cate_preds)
        all_colors, all_texts = color_for_cate_correctness(
            formatted_cate_preds, objs, topk_object
        )
        return plot_unary(frame_img, objs, all_colors, all_texts)

    # --- Binary predictions & ground truth for the text pane ---
    # Rank by score so the top-k listing shows the most confident relations,
    # then reshape to the (subj, relation, obj, score) layout the pane expects.
    ranked_preds = sorted(
        format_binary_cate_preds(binary_preds),
        key=lambda pred: pred[4],
        reverse=True,
    )
    pane_preds = [
        (subj, pred_rel, obj, score)
        for _f_id, subj, obj, pred_rel, score in ranked_preds
    ]
    gt_relations = [
        (clean_label(str(s)), clean_label(str(o)), clean_label(rel))
        for s, o, rel in gt_relations
    ]
    pane_width = 600
    pane_height = frame_img.shape[0]
    header_height = 50
    white_pane = get_white_pane(
        pane_height, pane_width, header_height=header_height
    )
    return plot_binary_sg(
        frame_img,
        white_pane,
        pane_preds,
        gt_relations,
        topk_binary,
        header_height=header_height,
        pane_width=pane_width,
    )
def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
    """Overlay one mask on a matplotlib axes.

    ``mask`` may be ``(H, W)``, ``(1, H, W)`` or ``(H, W, 1)``. The colour is
    random when ``random_color`` is set, otherwise it is derived from
    ``obj_id``. When ``det_class`` is given, the mask's bounding box and the
    class text are drawn too.
    """
    mask = np.array(mask)
    if mask.ndim == 3:
        if mask.shape[0] == 1:
            mask = mask[0]  # (1, H, W) -> (H, W)
        elif mask.shape[2] == 1:
            mask = mask[..., 0]  # (H, W, 1) -> (H, W)
    assert mask.ndim == 2, f"Mask must be 2D after squeezing, got shape {mask.shape}"

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        palette = plt.get_cmap("gist_rainbow")
        palette_idx = obj_id if obj_id is not None else 0
        rgba = list(palette((palette_idx * 47) % 256))
        rgba[3] = 0.5
        color = np.array(rgba)

    # (H, W, 1) * (1, 1, C) broadcasts the colour over the mask pixels.
    mask_image = mask[..., None] * color.reshape(1, 1, -1)

    if det_class is not None:
        rows, cols = np.where(mask > 0)
        if rows.size > 0 and cols.size > 0:
            x_min, x_max = cols.min(), cols.max()
            y_min, y_max = rows.min(), rows.max()
            outline = Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                linewidth=1.5,
                edgecolor=color[:3],
                facecolor="none",
                alpha=color[3],
            )
            ax.add_patch(outline)
            ax.text(
                x_min,
                y_min - 5,
                f"{det_class}",
                color="white",
                fontsize=6,
                backgroundcolor=np.array(color),
                alpha=1,
            )
    ax.imshow(mask_image)
def save_mask_one_image(frame_image, masks, save_path):
    """Render masks on top of a frame and store the visualization on disk.

    Args:
        frame_image: Frame as an array-like or a torch tensor.
        masks: ``obj_id -> mask`` dict, or a sequence indexed by object id.
        save_path: Destination passed to ``Figure.savefig``.

    Returns:
        ``save_path``.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))
    # Close the figure even if conversion, drawing or saving fails, so a
    # failing frame does not leave an open matplotlib figure behind.
    try:
        frame_np = (
            frame_image.detach().cpu().numpy()
            if torch.is_tensor(frame_image)
            else np.asarray(frame_image)
        )
        frame_np = np.ascontiguousarray(frame_np)
        mask_iter = masks.items() if isinstance(masks, dict) else enumerate(masks)
        prepared_masks = {
            obj_id: (
                mask.detach().cpu().numpy()
                if torch.is_tensor(mask)
                else np.asarray(mask)
            )
            for obj_id, mask in mask_iter
        }
        ax.imshow(frame_np)
        ax.axis("off")
        for obj_id, mask_np in prepared_masks.items():
            show_mask(mask_np, ax, obj_id=obj_id, det_class=None, random_color=False)
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)
    return save_path
def get_video_masks_visualization(
    video_tensor,
    video_masks,
    video_id,
    video_save_base_dir,
    oid_class_pred=None,
    sample_rate=1,
):
    """Save a labelled mask visualization for the frames of one video.

    Frames are written to ``<video_save_base_dir>/<video_id>/<frame_id>.jpg``.
    Frames without an entry in ``video_masks`` are reported and skipped.

    Args:
        video_tensor: Iterable of frame images.
        video_masks: Container supporting ``frame_id in ...`` and indexing by
            frame id, yielding the masks of that frame.
        video_id: Sub-directory name for this video.
        video_save_base_dir: Root output directory.
        oid_class_pred: Optional ``obj_id -> class`` used to label the masks.
        sample_rate: Probability of keeping a frame; values >= 1 keep all.
    """
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(video_save_dir):
        os.makedirs(video_save_dir, exist_ok=True)
    for frame_id, image in enumerate(video_tensor):
        # Only draw random numbers when sub-sampling was actually requested.
        if sample_rate < 1 and random.random() > sample_rate:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue
        masks = video_masks[frame_id]
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        fig, _ax = get_mask_one_image(image, masks, oid_class_pred)
        # The figure used to be dropped without being saved or closed, which
        # produced no output and leaked one open figure per frame.
        try:
            fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
        finally:
            plt.close(fig)
def get_mask_one_image(frame_image, masks, oid_class_pred=None):
    """Build a figure showing ``frame_image`` with all masks overlaid.

    A list of masks is keyed by position. When ``oid_class_pred`` is given,
    every mask is labelled ``"<obj_id>. <class>"``. The figure is returned
    open; the caller is responsible for saving and closing it.

    Returns:
        ``(fig, ax)`` of the created matplotlib figure.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))
    ax.imshow(frame_image)
    ax.axis("off")

    indexed_masks = dict(enumerate(masks)) if type(masks) is list else masks
    for obj_id, mask in indexed_masks.items():
        det_class = None
        if oid_class_pred is not None:
            det_class = f"{obj_id}. {oid_class_pred[obj_id]}"
        show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)

    return fig, ax
def save_video(frames, output_filename, output_fps):
    """Write ``frames`` to ``output_filename`` as an H.264 ('avc1') video.

    Args:
        frames: Sequence of equally sized ``(H, W, 3)`` images, or an array
            stacked along axis 0. Frames are written as given;
            NOTE(review): OpenCV's writer expects BGR channel order, so RGB
            frames should be converted by the caller - confirm against usage.
        output_filename: Destination video path.
        output_fps: Frames per second of the output video.

    Raises:
        RuntimeError: If the video writer cannot be opened.
    """
    num_frames = len(frames)
    if num_frames == 0:
        print("No frames to write; skipping video export.")
        return

    # Size comes from the first frame; frames.shape[:2] would be (T, H) for a
    # stacked array and does not exist for a list.
    frame_h, frame_w = np.asarray(frames[0]).shape[:2]

    # Use a codec supported by VS Code (H.264 via 'avc1').
    fourcc = cv2.VideoWriter_fourcc(*"avc1")
    out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))
    if not out.isOpened():
        out.release()
        raise RuntimeError(f"Could not open video writer for {output_filename}")

    print(f"Processing {num_frames} frames...")
    try:
        for i in range(num_frames):
            # The writer needs a contiguous buffer.
            out.write(np.ascontiguousarray(frames[i]))
            if i % 10 == 0:
                print(f"Processed frame {i + 1}/{num_frames}")
    finally:
        out.release()
    print(f"Video saved as {output_filename}")
def list_depth(lst):
    """Return the nesting depth of a list or tensor.

    Anything that is neither a list nor a tensor has depth 0; an empty list
    or a zero-dimensional tensor has depth 1; otherwise the depth is one
    more than that of the deepest element.
    """
    if not isinstance(lst, (list, torch.Tensor)):
        return 0

    if isinstance(lst, torch.Tensor):
        is_leaf = lst.shape == torch.Size([])
    else:
        is_leaf = len(lst) == 0
    if is_leaf:
        return 1

    return 1 + max(map(list_depth, lst))
def normalize_prompt(points, labels):
    """Stack per-object prompts into tensors with an extra singleton axis.

    Only inputs whose ``points`` have nesting depth 3 are converted; every
    entry is unsqueezed at dim 0 and the results are stacked. Other inputs
    are returned unchanged.
    """
    if list_depth(points) != 3:
        return points, labels

    stacked_points = torch.stack([pt.unsqueeze(0) for pt in points])
    stacked_labels = torch.stack([lbl.unsqueeze(0) for lbl in labels])
    return stacked_points, stacked_labels
def show_box(box, ax, object_id):
    """Add an unfilled rectangle for ``box`` (x0, y0, x1, y1) to ``ax``.

    Empty boxes are ignored. The edge colour is derived from ``object_id``.
    """
    if len(box) == 0:
        return

    palette = plt.get_cmap("gist_rainbow")
    palette_idx = object_id if object_id is not None else 0
    edge = list(palette((palette_idx * 47) % 256))

    x0, y0 = box[0], box[1]
    box_w = box[2] - box[0]
    box_h = box[3] - box[1]
    outline = plt.Rectangle(
        (x0, y0), box_w, box_h, edgecolor=edge, facecolor=(0, 0, 0, 0), lw=2
    )
    ax.add_patch(outline)
def show_points(coords, labels, ax, object_id=None, marker_size=375):
    """Scatter prompt points on ``ax``.

    Points labelled 1 are drawn as green plus markers, points labelled 0 as
    red squares; both are outlined in a colour derived from ``object_id``.
    Nothing is drawn when ``labels`` is empty.
    """
    if len(labels) == 0:
        return

    positives = coords[labels == 1]
    negatives = coords[labels == 0]

    palette = plt.get_cmap("gist_rainbow")
    palette_idx = object_id if object_id is not None else 0
    outline = list(palette((palette_idx * 47) % 256))

    groups = ((positives, "green", "P"), (negatives, "red", "s"))
    for selected, face, marker in groups:
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=face,
            marker=marker,
            s=marker_size,
            edgecolor=outline,
            linewidth=1.25,
        )
def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
    """Draw box and point prompts over a frame and save the figure.

    Args:
        frame_image: Image shown as the background.
        boxes: Tensor of boxes (indexed by object id), a dict
            ``object_id -> box``, or an empty list. Entries that are ``None``
            are skipped.
        points: Per-object point prompts (normalised via ``normalize_prompt``).
        labels: Per-object point labels matching ``points``.
        save_path: Destination passed to ``Figure.savefig``.

    Raises:
        TypeError: If ``boxes`` is of an unsupported type.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))
    # Close the figure on every path, including the error ones below.
    try:
        ax.imshow(frame_image)
        ax.axis("off")
        points, labels = normalize_prompt(points, labels)

        if isinstance(boxes, torch.Tensor):
            indexed_boxes = enumerate(boxes)
        elif isinstance(boxes, dict):
            indexed_boxes = boxes.items()
        elif isinstance(boxes, list) and len(boxes) == 0:
            indexed_boxes = ()
        else:
            raise TypeError(
                "boxes must be a torch.Tensor, a dict or an empty list, "
                f"got {type(boxes).__name__}"
            )
        for object_id, box in indexed_boxes:
            if box is not None:
                show_box(box.cpu(), ax, object_id=object_id)

        for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
            if len(point_ls) != 0:
                show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)

        # Save this figure explicitly rather than whichever one is "current".
        fig.savefig(save_path)
    finally:
        plt.close(fig)
def save_video_prompts_visualization(
    video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir
):
    """Save one prompt visualization per frame of a video.

    Images go to ``<video_save_base_dir>/<video_id>/<frame_id>.jpg``. Frames
    missing from a prompt container get an empty list for that prompt type.
    """
    out_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        boxes = video_boxes[frame_id] if frame_id in video_boxes else []
        points = video_points[frame_id] if frame_id in video_points else []
        labels = video_labels[frame_id] if frame_id in video_labels else []
        target = os.path.join(out_dir, f"{frame_id}.jpg")
        save_prompts_one_image(image, boxes, points, labels, target)
def save_video_masks_visualization(
    video_tensor,
    video_masks,
    video_id,
    video_save_base_dir,
    oid_class_pred=None,
    sample_rate=1,
):
    """Save mask overlays for a random sample of a video's frames.

    Every frame is kept with probability ``sample_rate``; kept frames that
    have masks are written to
    ``<video_save_base_dir>/<video_id>/<frame_id>.jpg``. ``oid_class_pred``
    is accepted but not used here.
    """
    out_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        dropped_by_sampling = random.random() > sample_rate
        if dropped_by_sampling:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue
        target = os.path.join(out_dir, f"{frame_id}.jpg")
        save_mask_one_image(image, video_masks[frame_id], target)
def get_color(obj_id, cmap_name="gist_rainbow", alpha=0.5):
    """Return the RGBA colour assigned to an object id.

    Args:
        obj_id: Object id; ``None`` is treated as 0.
        cmap_name: Name of the matplotlib colormap to sample.
        alpha: Value stored in the alpha channel of the result.

    Returns:
        NumPy array ``[r, g, b, a]`` with float components.
    """
    cmap = plt.get_cmap(cmap_name)
    cmap_idx = 0 if obj_id is None else obj_id
    # Multiplying by 47 spreads consecutive ids across the 256 entries.
    color = list(cmap((cmap_idx * 47) % 256))
    # Honour the caller's alpha; it used to be hard-coded to 0.5.
    color[3] = alpha
    return np.array(color)
def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0)
def relation_line(
    bbox1: Tuple[int, int, int, int],
    bbox2: Tuple[int, int, int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """
    Compute integer start/end pixels for a line joining two box centres.

    When the centres coincide, the end point is shifted to the right (by 5%
    of the second box's width, at least one pixel) so the segment is never
    degenerate; a final one-pixel nudge covers centres that only collide
    after rounding.
    """
    src_x, src_y = _bbox_center(bbox1)
    dst_x, dst_y = _bbox_center(bbox2)

    same_x = math.isclose(src_x, dst_x, abs_tol=1e-3)
    same_y = math.isclose(src_y, dst_y, abs_tol=1e-3)
    if same_x and same_y:
        dst_x = dst_x + max(1.0, (bbox2[2] - bbox2[0]) * 0.05)

    start = (int(round(src_x)), int(round(src_y)))
    end = (int(round(dst_x)), int(round(dst_y)))
    if start == end:
        end = (end[0] + 1, end[1])
    return start, end
def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
    """Build a figure showing related objects' masks joined by labelled lines.

    Args:
        frame_image: Image shown as the background.
        masks: Container indexed by object id yielding that object's mask.
        rel_pred_ls: Mapping ``(from_obj_id, to_obj_id) -> relation text``.
            ``None`` or empty draws only the frame.

    Returns:
        ``(fig, ax)``; the figure is left open for the caller to save/close.
    """
    fig, ax = plt.subplots(1, figsize=(6, 6))
    ax.imshow(frame_image)
    ax.axis("off")

    # The default of None used to crash on ``.items()``.
    rel_pred_ls = rel_pred_ls or {}

    all_objs_to_show = set()
    all_lines_to_show = []
    for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
        all_objs_to_show.add(from_obj_id)
        all_objs_to_show.add(to_obj_id)
        bbox1 = mask_to_bbox(masks[from_obj_id])
        bbox2 = mask_to_bbox(masks[to_obj_id])
        # NOTE(review): assumes mask_to_bbox may return None when no box can
        # be derived - confirm against its implementation.
        if bbox1 is None or bbox2 is None:
            continue
        # The name this used to call, shortest_line_between_bboxes, is not
        # defined or imported in this module; join the box centres with the
        # local relation_line helper instead.
        c1, c2 = relation_line(bbox1, bbox2)
        line_color = get_color(from_obj_id)
        face_color = get_color(to_obj_id)
        all_lines_to_show.append((c1, c2, face_color, line_color, rel_text))

    # Masks first so the relation lines are drawn on top of them.
    for obj_id in all_objs_to_show:
        show_mask(masks[obj_id], ax, obj_id=obj_id, random_color=False)

    for (
        (from_pt_x, from_pt_y),
        (to_pt_x, to_pt_y),
        face_color,
        line_color,
        rel_text,
    ) in all_lines_to_show:
        # Draw on this figure's axes rather than the implicit current axes.
        ax.plot(
            [from_pt_x, to_pt_x],
            [from_pt_y, to_pt_y],
            color=line_color,
            linestyle="-",
            linewidth=3,
        )
        mid_pt_x = (from_pt_x + to_pt_x) / 2
        mid_pt_y = (from_pt_y + to_pt_y) / 2
        ax.text(
            mid_pt_x - 5,
            mid_pt_y,
            rel_text,
            color="white",
            fontsize=6,
            backgroundcolor=np.array(line_color),
            bbox=dict(
                facecolor=face_color, edgecolor=line_color, boxstyle="round,pad=1"
            ),
            alpha=1,
        )
    return fig, ax