import os
import sys
import math
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle

# Add src/ to sys.path so LASER, video-sam2, and GroundingDINO are importable.
current_dir = Path(__file__).resolve().parent
src_dir = current_dir.parent / "src"
if src_dir.is_dir() and str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))

from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox

########################################################################################
##########                    Visualization Library                          ##########
########################################################################################
# This module renders SAM masks, GroundingDINO boxes, and VINE predictions.
#
# Conventions (RGB frames, pixel coords):
#   - Frames: list[np.ndarray] with shape (H, W, 3) in RGB, or np.ndarray with shape (T, H, W, 3).
#   - Masks: 2D boolean arrays (H, W) or tensors convertible to that; (H, W, 1) is also accepted.
#   - BBoxes: (x1, y1, x2, y2) integer pixel coordinates with x2 > x1 and y2 > y1.
#
# Per-frame stores use one of:
#   - Dict[int(frame_id) -> Dict[int(obj_id) -> value]]
#   - List indexed by frame_id (each item may be a dict of obj_id -> value or a list in order)
#
# Renderer inputs/outputs:
#   1) render_sam_frames(frames, sam_masks, dino_labels=None) -> List[np.ndarray]
#      - sam_masks: Dict[frame_id, Dict[obj_id, Mask]] or a list; Mask can be np.ndarray or torch.Tensor.
#      - dino_labels: Optional Dict[obj_id, str] to annotate boxes derived from masks.
#
#   2) render_dino_frames(frames, bboxes, dino_labels=None) -> List[np.ndarray]
#      - bboxes: Dict[frame_id, Dict[obj_id, Sequence[float]]] or a list; each bbox as [x1, y1, x2, y2].
#
#   3) render_vine_frames(frames, bboxes, cat_label_lookup, unary_lookup, binary_lookup, masks=None)
#      -> List[np.ndarray] (the "all" view)
#      - cat_label_lookup: Dict[obj_id, (label: str, prob: float)]
#      - unary_lookup: Dict[frame_id, Dict[obj_id, List[(prob: float, label: str)]]]
#      - binary_lookup: Dict[frame_id, List[((sub_id: int, obj_id: int), List[(prob: float, relation: str)])]]
#      - masks: Optional; same structure as sam_masks, used for translucent overlays when unary labels exist.
#
# Ground-truth helpers used by plotting utilities:
#   - For a single frame, gt_relations is represented as List[(subject_label, object_label, relation_label)].
#
# All rendered frames returned by these functions are RGB np.ndarray images suitable
# for saving or video writing.
########################################################################################

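# A minimal end-to-end sketch of the renderer API (hypothetical data; the exact
# structures follow the conventions above):
#
#   frames = [np.zeros((240, 320, 3), dtype=np.uint8) for _ in range(2)]
#   mask = np.zeros((240, 320), dtype=bool)
#   mask[50:100, 60:120] = True
#   sam_masks = {0: {1: mask}, 1: {1: mask}}
#   sam_vis = render_sam_frames(frames, sam_masks, dino_labels={1: "person"})
#
#   bboxes = {0: {1: [60, 50, 120, 100], 2: [150, 60, 220, 140]}}
#   dino_vis = render_dino_frames(frames, bboxes, dino_labels={1: "person", 2: "chair"})
#
#   vine_vis = render_vine_frames(
#       frames,
#       bboxes,
#       cat_label_lookup={1: ("person", 0.97), 2: ("chair", 0.88)},
#       unary_lookup={0: {1: [(0.88, "standing")]}},
#       binary_lookup={0: [((1, 2), [(0.50, "next to")])]},
#   )
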
def clean_label(label):
    """Replace underscores and slashes with spaces for uniformity."""
    return label.replace("_", " ").replace("/", " ")


# TODO: this formatting arguably belongs upstream, where predictions are produced.
def format_cate_preds(cate_preds):
    # Group object predictions from the model output, keyed by object id.
    obj_pred_dict = {}
    for (oid, label), prob in cate_preds.items():
        # Clean the predicted label as well.
        clean_pred = clean_label(label)
        if oid not in obj_pred_dict:
            obj_pred_dict[oid] = []
        obj_pred_dict[oid].append((clean_pred, prob))
    # Sort each object's predictions by descending probability.
    for oid in obj_pred_dict:
        obj_pred_dict[oid].sort(key=lambda x: x[1], reverse=True)
    return obj_pred_dict

def format_binary_cate_preds(binary_preds):
    frame_binary_preds = []
    for key, score in binary_preds.items():
        # Expect key format: (frame_id, (subject, object), predicted_relation)
        try:
            f_id, (subj, obj), pred_rel = key
            frame_binary_preds.append((f_id, subj, obj, pred_rel, score))
        except Exception:
            print("Skipping key with unexpected format:", key)
            continue
    # Sort by descending confidence score (the last tuple element).
    frame_binary_preds.sort(key=lambda x: x[4], reverse=True)
    return frame_binary_preds

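# A minimal sketch of the expected input/output (hypothetical values):
#
#   binary_preds = {(0, (1, 2), "holding"): 0.91, (0, (2, 1), "held by"): 0.40}
#   format_binary_cate_preds(binary_preds)
#   # -> [(0, 1, 2, "holding", 0.91), (0, 2, 1, "held by", 0.40)]
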
_FONT = cv2.FONT_HERSHEY_SIMPLEX


def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.ndarray]:
    if mask is None:
        return None
    if isinstance(mask, torch.Tensor):
        mask_np = mask.detach().cpu().numpy()
    else:
        mask_np = np.asarray(mask)
    if mask_np.ndim == 0:
        return None
    if mask_np.ndim == 3:
        mask_np = np.squeeze(mask_np)
    if mask_np.ndim != 2:
        return None
    if mask_np.dtype == bool:
        return mask_np
    return mask_np > 0

def _sanitize_bbox(
    bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int
) -> Optional[Tuple[int, int, int, int]]:
    if bbox is None:
        return None
    if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
        x1, y1, x2, y2 = [float(b) for b in bbox[:4]]
    elif isinstance(bbox, np.ndarray) and bbox.size >= 4:
        x1, y1, x2, y2 = [float(b) for b in bbox.flat[:4]]
    else:
        return None
    # Clamp to image bounds and reject degenerate boxes.
    x1 = int(np.clip(round(x1), 0, width - 1))
    y1 = int(np.clip(round(y1), 0, height - 1))
    x2 = int(np.clip(round(x2), 0, width - 1))
    y2 = int(np.clip(round(y2), 0, height - 1))
    if x2 <= x1 or y2 <= y1:
        return None
    return (x1, y1, x2, y2)

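# For example (hypothetical values), out-of-bounds coordinates are clamped and
# degenerate boxes are rejected:
#
#   _sanitize_bbox([-10.0, 5.0, 400.5, 120.0], width=320, height=240)
#   # -> (0, 5, 319, 120)
#   _sanitize_bbox([50, 50, 50, 80], width=320, height=240)
#   # -> None (zero width)
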
def _object_color_bgr(obj_id: int) -> Tuple[int, int, int]:
    color = get_color(obj_id)
    rgb = [int(np.clip(c, 0.0, 1.0) * 255) for c in color[:3]]
    # OpenCV expects BGR channel order.
    return (rgb[2], rgb[1], rgb[0])


def _background_color(color: Tuple[int, int, int]) -> Tuple[int, int, int]:
    # Lighten the object color for use as a label background.
    return tuple(int(0.25 * 255 + 0.75 * channel) for channel in color)

def _draw_label_block(
    image: np.ndarray,
    lines: List[str],
    anchor: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
    direction: str = "up",
) -> None:
    if not lines:
        return
    img_h, img_w = image.shape[:2]
    x, y = anchor
    x = int(np.clip(x, 0, img_w - 1))
    y_cursor = int(np.clip(y, 0, img_h - 1))
    bg_color = _background_color(color)
    if direction == "down":
        # Stack label lines downward from the anchor.
        for text in lines:
            text = str(text)
            (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
            left_x = x
            right_x = min(left_x + tw + 8, img_w - 1)
            top_y = int(np.clip(y_cursor + 6, 0, img_h - 1))
            bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
            if bottom_y <= top_y:
                break
            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(
                image,
                text,
                (text_x, text_y),
                _FONT,
                font_scale,
                (0, 0, 0),
                thickness,
                cv2.LINE_AA,
            )
            y_cursor = bottom_y
    else:
        # Stack label lines upward from the anchor.
        for text in lines:
            text = str(text)
            (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
            top_y = max(y_cursor - th - baseline - 6, 0)
            left_x = x
            right_x = min(left_x + tw + 8, img_w - 1)
            bottom_y = min(top_y + th + baseline + 6, img_h - 1)
            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(
                image,
                text,
                (text_x, text_y),
                _FONT,
                font_scale,
                (0, 0, 0),
                thickness,
                cv2.LINE_AA,
            )
            y_cursor = top_y

def _draw_centered_label(
    image: np.ndarray,
    text: str,
    center: Tuple[int, int],
    color: Tuple[int, int, int],
    font_scale: float = 0.5,
    thickness: int = 1,
) -> None:
    text = str(text)
    img_h, img_w = image.shape[:2]
    (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
    cx = int(np.clip(center[0], 0, img_w - 1))
    cy = int(np.clip(center[1], 0, img_h - 1))
    left_x = int(np.clip(cx - tw // 2 - 4, 0, img_w - 1))
    top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
    right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
    bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
    cv2.rectangle(
        image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1
    )
    text_x = left_x + 4
    text_y = min(bottom_y - baseline - 2, img_h - 1)
    cv2.putText(
        image,
        text,
        (text_x, text_y),
        _FONT,
        font_scale,
        (0, 0, 0),
        thickness,
        cv2.LINE_AA,
    )

def _extract_frame_entities(
    store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int
) -> Dict[int, Any]:
    # Accept either a dict keyed by frame_id or a list indexed by frame_id.
    if isinstance(store, dict):
        frame_entry = store.get(frame_idx, {})
    elif isinstance(store, list) and 0 <= frame_idx < len(store):
        frame_entry = store[frame_idx]
    else:
        frame_entry = {}
    if isinstance(frame_entry, dict):
        return frame_entry
    if isinstance(frame_entry, list):
        return {i: value for i, value in enumerate(frame_entry)}
    return {}

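# Both per-frame store layouts normalize to the same dict (hypothetical values):
#
#   _extract_frame_entities({0: {7: "mask"}}, 0)   # -> {7: "mask"}
#   _extract_frame_entities([["a", "b"]], 0)       # -> {0: "a", 1: "b"}
#   _extract_frame_entities(None, 3)               # -> {}
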
def _label_anchor_and_direction(
    bbox: Tuple[int, int, int, int],
    position: str,
) -> Tuple[Tuple[int, int], str]:
    x1, y1, x2, y2 = bbox
    if position == "bottom":
        return (x1, y2), "down"
    return (x1, y1), "up"

def _draw_bbox_with_label(
    image: np.ndarray,
    bbox: Tuple[int, int, int, int],
    obj_id: int,
    title: Optional[str] = None,
    sub_lines: Optional[List[str]] = None,
    label_position: str = "top",
) -> None:
    color = _object_color_bgr(obj_id)
    cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
    # Always prefix the label with the object id.
    head = title if title else f"#{obj_id}"
    if not head.startswith("#"):
        head = f"#{obj_id} {head}"
    lines = [head]
    if sub_lines:
        lines.extend(sub_lines)
    anchor, direction = _label_anchor_and_direction(bbox, label_position)
    _draw_label_block(image, lines, anchor, color, direction=direction)

def render_sam_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    sam_masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    results: List[np.ndarray] = []
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    dino_labels = dino_labels or {}
    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        overlay = frame_bgr.astype(np.float32)
        masks_for_frame = _extract_frame_entities(sam_masks, frame_idx)
        # First pass: blend each mask into the frame as a translucent overlay.
        for obj_id, mask in masks_for_frame.items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is None or not np.any(mask_np):
                continue
            color = _object_color_bgr(obj_id)
            alpha = 0.45
            overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(
                color, dtype=np.float32
            )
        annotated = np.clip(overlay, 0, 255).astype(np.uint8)
        frame_h, frame_w = annotated.shape[:2]
        # Second pass: draw a labeled bounding box around each mask.
        for obj_id, mask in masks_for_frame.items():
            mask_np = _to_numpy_mask(mask)
            if mask_np is None or not np.any(mask_np):
                continue
            bbox = mask_to_bbox(mask_np)
            bbox = _sanitize_bbox(bbox, frame_w, frame_h)
            if not bbox:
                continue
            label = dino_labels.get(obj_id)
            title = f"{label}" if label else None
            _draw_bbox_with_label(annotated, bbox, obj_id, title=title)
        results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    return results

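# The returned frames are RGB, so they can go straight to matplotlib or, after an
# RGB->BGR conversion, to cv2.imwrite (a sketch with hypothetical paths):
#
#   rendered = render_sam_frames(frames, sam_masks)
#   for i, img in enumerate(rendered):
#       cv2.imwrite(f"sam_{i:04d}.jpg", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
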
def render_dino_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    dino_labels: Optional[Dict[int, str]] = None,
) -> List[np.ndarray]:
    results: List[np.ndarray] = []
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    dino_labels = dino_labels or {}
    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        annotated = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        frame_h, frame_w = annotated.shape[:2]
        frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
        for obj_id, bbox_values in frame_bboxes.items():
            bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
            if not bbox:
                continue
            label = dino_labels.get(obj_id)
            title = f"{label}" if label else None
            _draw_bbox_with_label(annotated, bbox, obj_id, title=title)
        results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    return results

def render_vine_frame_sets(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[
        Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None
    ] = None,
    binary_confidence_threshold: float = 0.0,
) -> Dict[str, List[np.ndarray]]:
    frame_groups: Dict[str, List[np.ndarray]] = {
        "object": [],
        "unary": [],
        "binary": [],
        "all": [],
    }
    frames_iterable = frames if isinstance(frames, list) else list(frames)
    for frame_idx, frame in enumerate(frames_iterable):
        if frame is None:
            continue
        frame_rgb = np.asarray(frame)
        base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        frame_h, frame_w = base_bgr.shape[:2]
        frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
        frame_masks = (
            _extract_frame_entities(masks, frame_idx) if masks is not None else {}
        )
        objects_bgr = base_bgr.copy()
        unary_bgr = base_bgr.copy()
        binary_bgr = base_bgr.copy()
        all_bgr = base_bgr.copy()
        bbox_lookup: Dict[int, Tuple[int, int, int, int]] = {}
        unary_lines_lookup: Dict[int, List[str]] = {}
        titles_lookup: Dict[int, Optional[str]] = {}
        for obj_id, bbox_values in frame_bboxes.items():
            bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
            if not bbox:
                continue
            bbox_lookup[obj_id] = bbox
            cat_label, cat_prob = cat_label_lookup.get(obj_id, (None, None))
            title_parts = []
            if cat_label:
                if cat_prob is not None:
                    title_parts.append(f"{cat_label} {cat_prob:.2f}")
                else:
                    title_parts.append(cat_label)
            titles_lookup[obj_id] = " ".join(title_parts) if title_parts else None
            unary_preds = unary_lookup.get(frame_idx, {}).get(obj_id, [])
            unary_lines = [f"{label} {prob:.2f}" for prob, label in unary_preds]
            unary_lines_lookup[obj_id] = unary_lines
        # Blend mask overlays for objects that carry unary predictions.
        for obj_id, bbox in bbox_lookup.items():
            unary_lines = unary_lines_lookup.get(obj_id, [])
            if not unary_lines:
                continue
            mask_raw = frame_masks.get(obj_id)
            mask_np = _to_numpy_mask(mask_raw)
            if mask_np is None or not np.any(mask_np):
                continue
            color = np.array(_object_color_bgr(obj_id), dtype=np.float32)
            alpha = 0.45
            for target in (unary_bgr, all_bgr):
                target_vals = target[mask_np].astype(np.float32)
                blended = (1.0 - alpha) * target_vals + alpha * color
                target[mask_np] = np.clip(blended, 0, 255).astype(np.uint8)
        # Draw labeled boxes on every view; unary lines go below the box.
        for obj_id, bbox in bbox_lookup.items():
            title = titles_lookup.get(obj_id)
            unary_lines = unary_lines_lookup.get(obj_id, [])
            _draw_bbox_with_label(
                objects_bgr, bbox, obj_id, title=title, label_position="top"
            )
            _draw_bbox_with_label(
                unary_bgr, bbox, obj_id, title=title, label_position="top"
            )
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                _draw_label_block(
                    unary_bgr,
                    unary_lines,
                    anchor,
                    _object_color_bgr(obj_id),
                    direction=direction,
                )
            _draw_bbox_with_label(
                binary_bgr, bbox, obj_id, title=title, label_position="top"
            )
            _draw_bbox_with_label(
                all_bgr, bbox, obj_id, title=title, label_position="top"
            )
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                _draw_label_block(
                    all_bgr,
                    unary_lines,
                    anchor,
                    _object_color_bgr(obj_id),
                    direction=direction,
                )
        # First pass: collect all pairs above threshold and deduplicate bidirectional pairs.
        pairs_to_draw = {}  # (min_id, max_id) -> (subj_id, obj_id, prob, relation)
        for obj_pair, relation_preds in binary_lookup.get(frame_idx, []):
            if len(obj_pair) != 2 or not relation_preds:
                continue
            subj_id, obj_id = obj_pair
            subj_bbox = bbox_lookup.get(subj_id)
            obj_bbox = bbox_lookup.get(obj_id)
            if not subj_bbox or not obj_bbox:
                continue
            prob, relation = relation_preds[0]
            # Filter by confidence threshold.
            if prob < binary_confidence_threshold:
                continue
            # Create a canonical key (smaller_id, larger_id) for deduplication.
            pair_key = (min(subj_id, obj_id), max(subj_id, obj_id))
            # Keep the higher-confidence direction.
            if pair_key not in pairs_to_draw or prob > pairs_to_draw[pair_key][2]:
                pairs_to_draw[pair_key] = (subj_id, obj_id, prob, relation)
        # Second pass: draw the selected pairs.
        for subj_id, obj_id, prob, relation in pairs_to_draw.values():
            subj_bbox = bbox_lookup.get(subj_id)
            obj_bbox = bbox_lookup.get(obj_id)
            start, end = relation_line(subj_bbox, obj_bbox)
            # Average the two object colors for the relation line.
            color = tuple(
                int(c)
                for c in np.clip(
                    (
                        np.array(_object_color_bgr(subj_id), dtype=np.float32)
                        + np.array(_object_color_bgr(obj_id), dtype=np.float32)
                    )
                    / 2.0,
                    0,
                    255,
                )
            )
            label_text = f"{relation} {prob:.2f}"
            mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
            # Draw arrowed lines showing direction from subject to object (small arrow tip).
            cv2.arrowedLine(
                binary_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05
            )
            cv2.arrowedLine(all_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05)
            _draw_centered_label(binary_bgr, label_text, mid_point, color)
            _draw_centered_label(all_bgr, label_text, mid_point, color)
        frame_groups["object"].append(cv2.cvtColor(objects_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["unary"].append(cv2.cvtColor(unary_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["binary"].append(cv2.cvtColor(binary_bgr, cv2.COLOR_BGR2RGB))
        frame_groups["all"].append(cv2.cvtColor(all_bgr, cv2.COLOR_BGR2RGB))
    return frame_groups

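# The four views can be inspected or saved independently (a sketch with the
# hypothetical inputs from the module header):
#
#   groups = render_vine_frame_sets(
#       frames, bboxes, cat_label_lookup, unary_lookup, binary_lookup, masks=sam_masks
#   )
#   object_only = groups["object"]   # boxes + category labels
#   with_unary = groups["unary"]     # + unary predicates and mask overlays
#   with_binary = groups["binary"]   # + relation arrows
#   everything = groups["all"]       # union of the above
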
def render_vine_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[
        Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None
    ] = None,
    binary_confidence_threshold: float = 0.0,
) -> List[np.ndarray]:
    # Convenience wrapper that returns only the combined "all" view.
    return render_vine_frame_sets(
        frames,
        bboxes,
        cat_label_lookup,
        unary_lookup,
        binary_lookup,
        masks,
        binary_confidence_threshold,
    ).get("all", [])

def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
    all_colors = []
    all_texts = []
    for obj_id, bbox, gt_label in gt_labels:
        preds = obj_pred_dict.get(obj_id, [])
        if len(preds) == 0:
            top1 = "N/A"
            box_color = (0, 0, 255)  # bright red (BGR) if no prediction
        else:
            top1 = preds[0][0]
            topk_labels = [p[0] for p in preds[:topk_object]]
            # Compare cleaned labels, case-insensitively.
            if top1.lower() == gt_label.lower():
                box_color = (0, 255, 0)  # bright green for correct
            elif gt_label.lower() in [p.lower() for p in topk_labels]:
                box_color = (0, 165, 255)  # bright orange for a top-k match
            else:
                box_color = (0, 0, 255)  # bright red for incorrect
        label_text = f"ID:{obj_id}/P:{top1}/GT:{gt_label}"
        all_colors.append(box_color)
        all_texts.append(label_text)
    return all_colors, all_texts

def plot_unary(frame_img, gt_labels, all_colors, all_texts):
    for (obj_id, bbox, gt_label), box_color, label_text in zip(
        gt_labels, all_colors, all_texts
    ):
        x1, y1, x2, y2 = map(int, bbox)
        cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
        (tw, th), baseline = cv2.getTextSize(
            label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        )
        cv2.rectangle(
            frame_img, (x1, y1 - th - baseline - 4), (x1 + tw, y1), box_color, -1
        )
        cv2.putText(
            frame_img,
            label_text,
            (x1, y1 - 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
    return frame_img

def get_white_pane(
    pane_height,
    pane_width=600,
    header_height=50,
    header_font=cv2.FONT_HERSHEY_SIMPLEX,
    header_font_scale=0.7,
    header_thickness=2,
    header_color=(0, 0, 0),
):
    # Create an expanded white pane to display text info.
    white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)
    # Pane split: make the predictions column wider (60% vs. 40%).
    left_width = int(pane_width * 0.6)
    # Draw the headers directly on the pane; drawing on column slices copied out
    # of the pane would only modify throwaway arrays.
    cv2.putText(
        white_pane,
        "Binary Predictions",
        (10, header_height - 30),
        header_font,
        header_font_scale,
        header_color,
        header_thickness,
        cv2.LINE_AA,
    )
    cv2.putText(
        white_pane,
        "Ground Truth",
        (left_width + 10, header_height - 30),
        header_font,
        header_font_scale,
        header_color,
        header_thickness,
        cv2.LINE_AA,
    )
    return white_pane

# This is for plotting binary prediction results with frame-based scene graphs.
def plot_binary_sg(
    frame_img,
    white_pane,
    bin_preds,
    gt_relations,
    topk_binary,
    header_height=50,
    indicator_size=20,
    pane_width=600,
):
    # Leave vertical space for the headers.
    line_height = 30  # vertical spacing per line
    x_text = 10  # left margin for text
    y_text_left = header_height + 10  # starting y for left pane text
    y_text_right = header_height + 10  # starting y for right pane text
    left_width = int(pane_width * 0.6)
    left_pane = white_pane[:, :left_width, :].copy()
    right_pane = white_pane[:, left_width:, :].copy()
    # Left section: top-k binary predictions. bin_preds holds
    # (frame_id, subject, object, relation, score) tuples from format_binary_cate_preds.
    for _f_id, subj, obj, pred_rel, score in bin_preds[:topk_binary]:
        correct = any(
            (subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
            for gt in gt_relations
        )
        indicator_color = (0, 255, 0) if correct else (0, 0, 255)
        cv2.rectangle(
            left_pane,
            (x_text, y_text_left - indicator_size + 5),
            (x_text + indicator_size, y_text_left + 5),
            indicator_color,
            -1,
        )
        text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
        cv2.putText(
            left_pane,
            text,
            (x_text + indicator_size + 5, y_text_left + 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
        y_text_left += line_height
    # Right section: ground-truth binary relations.
    for gt in gt_relations:
        if len(gt) != 3:
            continue
        text = f"{gt[0]} - {gt[2]} - {gt[1]}"
        cv2.putText(
            right_pane,
            text,
            (x_text, y_text_right + 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
        y_text_right += line_height
    # Combine the two text panes, then attach them to the frame image.
    combined_pane = np.hstack((left_pane, right_pane))
    combined_image = np.hstack((frame_img, combined_pane))
    return combined_image

def visualized_frame(
    frame_img,
    bboxes,
    object_ids,
    gt_labels,
    cate_preds,
    binary_preds,
    gt_relations,
    topk_object,
    topk_binary,
    phase="unary",
):
    """Return the combined annotated frame as an image (in BGR)."""
    if phase == "unary":
        # --- Process object predictions (for overlaying bboxes) ---
        objs = []
        for (_, _f_id, obj_id), bbox, gt_label in zip(object_ids, bboxes, gt_labels):
            gt_label = clean_label(gt_label)
            objs.append((obj_id, bbox, gt_label))
        formatted_cate_preds = format_cate_preds(cate_preds)
        all_colors, all_texts = color_for_cate_correctness(
            formatted_cate_preds, objs, topk_object
        )
        updated_frame_img = plot_unary(frame_img, objs, all_colors, all_texts)
        return updated_frame_img
    else:
        # --- Process binary predictions & ground truth for the text pane ---
        formatted_binary_preds = format_binary_cate_preds(binary_preds)
        # Clean the ground-truth relations for the frame.
        gt_relations = [
            (clean_label(str(s)), clean_label(str(o)), clean_label(rel))
            for s, o, rel in gt_relations
        ]
        pane_width = 600  # wide pane for more horizontal space
        pane_height = frame_img.shape[0]
        header_height = 50  # extra space for the header labels
        white_pane = get_white_pane(
            pane_height, pane_width, header_height=header_height
        )
        combined_image = plot_binary_sg(
            frame_img, white_pane, formatted_binary_preds, gt_relations, topk_binary
        )
        return combined_image

def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
    # Ensure mask is a numpy array.
    mask = np.array(mask)
    # Handle different mask shapes.
    if mask.ndim == 3:
        # (1, H, W) -> (H, W)
        if mask.shape[0] == 1:
            mask = mask.squeeze(0)
        # (H, W, 1) -> (H, W)
        elif mask.shape[2] == 1:
            mask = mask.squeeze(2)
    # Now mask should be (H, W).
    assert mask.ndim == 2, f"Mask must be 2D after squeezing, got shape {mask.shape}"
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        cmap = plt.get_cmap("gist_rainbow")
        cmap_idx = 0 if obj_id is None else obj_id
        color = list(cmap((cmap_idx * 47) % 256))
        color[3] = 0.5
        color = np.array(color)
    # Expand mask to (H, W, 1) for broadcasting.
    mask_expanded = mask[..., None]
    mask_image = mask_expanded * color.reshape(1, 1, -1)
    # Draw a box around the mask with det_class as the label.
    if det_class is not None:
        # Find the bounding box coordinates.
        y_indices, x_indices = np.where(mask > 0)
        if y_indices.size > 0 and x_indices.size > 0:
            x_min, x_max = x_indices.min(), x_indices.max()
            y_min, y_max = y_indices.min(), y_indices.max()
            rect = Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                linewidth=1.5,
                edgecolor=color[:3],
                facecolor="none",
                alpha=color[3],
            )
            ax.add_patch(rect)
            ax.text(
                x_min,
                y_min - 5,
                f"{det_class}",
                color="white",
                fontsize=6,
                backgroundcolor=np.array(color),
                alpha=1,
            )
    ax.imshow(mask_image)

def save_mask_one_image(frame_image, masks, save_path):
    """Render masks on top of a frame and store the visualization on disk."""
    fig, ax = plt.subplots(1, figsize=(6, 6))
    frame_np = (
        frame_image.detach().cpu().numpy()
        if torch.is_tensor(frame_image)
        else np.asarray(frame_image)
    )
    frame_np = np.ascontiguousarray(frame_np)
    if isinstance(masks, dict):
        mask_iter = masks.items()
    else:
        mask_iter = enumerate(masks)
    prepared_masks = {
        obj_id: (
            mask.detach().cpu().numpy() if torch.is_tensor(mask) else np.asarray(mask)
        )
        for obj_id, mask in mask_iter
    }
    ax.imshow(frame_np)
    ax.axis("off")
    for obj_id, mask_np in prepared_masks.items():
        show_mask(mask_np, ax, obj_id=obj_id, det_class=None, random_color=False)
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    return save_path

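# A minimal usage sketch (hypothetical path and data):
#
#   frame = np.zeros((240, 320, 3), dtype=np.uint8)
#   mask = np.zeros((240, 320), dtype=bool)
#   mask[40:90, 30:100] = True
#   save_mask_one_image(frame, {0: mask}, "frame_0000.jpg")
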
def get_video_masks_visualization(
    video_tensor,
    video_masks,
    video_id,
    video_save_base_dir,
    oid_class_pred=None,
    sample_rate=1,
):
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(video_save_dir):
        os.makedirs(video_save_dir, exist_ok=True)
    for frame_id, image in enumerate(video_tensor):
        # Randomly subsample frames at the given rate (mirrors
        # save_video_masks_visualization below).
        if random.random() > sample_rate:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue
        masks = video_masks[frame_id]
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        # Render the figure, write it to disk, and close it to avoid leaking figures.
        fig, _ax = get_mask_one_image(image, masks, oid_class_pred)
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
        plt.close(fig)

def get_mask_one_image(frame_image, masks, oid_class_pred=None):
    # Create a figure and axis.
    fig, ax = plt.subplots(1, figsize=(6, 6))
    # Display the frame image.
    ax.imshow(frame_image)
    ax.axis("off")
    if isinstance(masks, list):
        masks = {i: m for i, m in enumerate(masks)}
    # Add the masks.
    for obj_id, mask in masks.items():
        det_class = (
            f"{obj_id}. {oid_class_pred[obj_id]}"
            if oid_class_pred is not None
            else None
        )
        show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)
    # Return the plot for the caller to save or show.
    return fig, ax

def save_video(frames, output_filename, output_fps):
    # --- Create a video from all frames ---
    num_frames = len(frames)
    frame_h, frame_w = frames[0].shape[:2]
    # Use a codec supported by VS Code (H.264 via 'avc1').
    fourcc = cv2.VideoWriter_fourcc(*"avc1")
    out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))
    print(f"Processing {num_frames} frames...")
    for i, frame in enumerate(frames):
        # Frames produced by this module are RGB; VideoWriter expects BGR.
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        if i % 10 == 0:
            print(f"Processed frame {i + 1}/{num_frames}")
    out.release()
    print(f"Video saved as {output_filename}")

def list_depth(lst):
    """Calculate the depth of a nested list (or tensor)."""
    if not isinstance(lst, (list, torch.Tensor)):
        return 0
    elif (isinstance(lst, torch.Tensor) and lst.shape == torch.Size([])) or (
        isinstance(lst, list) and len(lst) == 0
    ):
        return 1
    else:
        return 1 + max(list_depth(item) for item in lst)

def normalize_prompt(points, labels):
    # Stack per-object prompts into a single batched tensor when they arrive
    # as a flat (depth-3) collection.
    if list_depth(points) == 3:
        points = torch.stack([p.unsqueeze(0) for p in points])
        labels = torch.stack([l.unsqueeze(0) for l in labels])
    return points, labels

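# A sketch of the shape change (hypothetical tensors): two objects, one point each.
#
#   points = [torch.tensor([10.0, 20.0]), torch.tensor([30.0, 40.0])]
#   labels = [torch.tensor(1), torch.tensor(1)]
#   points, labels = normalize_prompt(points, labels)
#   # points.shape -> torch.Size([2, 1, 2]); labels.shape -> torch.Size([2, 1])
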
def show_box(box, ax, object_id):
    if len(box) == 0:
        return
    cmap = plt.get_cmap("gist_rainbow")
    cmap_idx = 0 if object_id is None else object_id
    color = list(cmap((cmap_idx * 47) % 256))
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0, 0, 0, 0), lw=2)
    )

def show_points(coords, labels, ax, object_id=None, marker_size=375):
    if len(labels) == 0:
        return
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    cmap = plt.get_cmap("gist_rainbow")
    cmap_idx = 0 if object_id is None else object_id
    color = list(cmap((cmap_idx * 47) % 256))
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="P",
        s=marker_size,
        edgecolor=color,
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="s",
        s=marker_size,
        edgecolor=color,
        linewidth=1.25,
    )

def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
    # Create a figure and axis.
    fig, ax = plt.subplots(1, figsize=(6, 6))
    # Display the frame image.
    ax.imshow(frame_image)
    ax.axis("off")
    points, labels = normalize_prompt(points, labels)
    # Add the bounding boxes.
    if isinstance(boxes, torch.Tensor):
        for object_id, box in enumerate(boxes):
            if box is not None:
                show_box(box.cpu(), ax, object_id=object_id)
    elif isinstance(boxes, dict):
        for object_id, box in boxes.items():
            if box is not None:
                show_box(box.cpu(), ax, object_id=object_id)
    elif isinstance(boxes, list) and len(boxes) == 0:
        pass
    else:
        raise TypeError(f"Unsupported boxes type: {type(boxes)}")
    # Add the point prompts.
    for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
        if len(point_ls) != 0:
            show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)
    # Save the plot.
    plt.savefig(save_path)
    plt.close()

def save_video_prompts_visualization(
    video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir
):
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(video_save_dir):
        os.makedirs(video_save_dir, exist_ok=True)
    for frame_id, image in enumerate(video_tensor):
        boxes, points, labels = [], [], []
        if frame_id in video_boxes:
            boxes = video_boxes[frame_id]
        if frame_id in video_points:
            points = video_points[frame_id]
        if frame_id in video_labels:
            labels = video_labels[frame_id]
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        save_prompts_one_image(image, boxes, points, labels, save_path)

def save_video_masks_visualization(
    video_tensor,
    video_masks,
    video_id,
    video_save_base_dir,
    oid_class_pred=None,
    sample_rate=1,
):
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(video_save_dir):
        os.makedirs(video_save_dir, exist_ok=True)
    for frame_id, image in enumerate(video_tensor):
        # Randomly subsample frames at the given rate.
        if random.random() > sample_rate:
            continue
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue
        masks = video_masks[frame_id]
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        save_mask_one_image(image, masks, save_path)

def get_color(obj_id, cmap_name="gist_rainbow", alpha=0.5):
    cmap = plt.get_cmap(cmap_name)
    cmap_idx = 0 if obj_id is None else obj_id
    # Stride by 47 to spread consecutive ids across the colormap.
    color = list(cmap((cmap_idx * 47) % 256))
    color[3] = alpha
    return np.array(color)

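# get_color returns an RGBA array with components in [0, 1] and the requested
# alpha in the last slot (a quick sanity check):
#
#   rgba = get_color(3, alpha=0.8)
#   # rgba.shape -> (4,); rgba[3] -> 0.8
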
def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
    return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0)


def relation_line(
    bbox1: Tuple[int, int, int, int],
    bbox2: Tuple[int, int, int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """
    Return integer pixel centers suitable for drawing a relation line. For
    coincident boxes, nudge the target center so the segment has nonzero span.
    """
    center1 = _bbox_center(bbox1)
    center2 = _bbox_center(bbox2)
    if math.isclose(center1[0], center2[0], abs_tol=1e-3) and math.isclose(
        center1[1], center2[1], abs_tol=1e-3
    ):
        offset = max(1.0, (bbox2[2] - bbox2[0]) * 0.05)
        center2 = (center2[0] + offset, center2[1])
    start = (int(round(center1[0])), int(round(center1[1])))
    end = (int(round(center2[0])), int(round(center2[1])))
    if start == end:
        end = (end[0] + 1, end[1])
    return start, end

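# For example, two distinct boxes map to their centers, and coincident boxes
# are nudged apart (hypothetical coordinates):
#
#   relation_line((0, 0, 10, 10), (20, 0, 30, 10))  # -> ((5, 5), (25, 5))
#   relation_line((0, 0, 10, 10), (0, 0, 10, 10))   # -> ((5, 5), (6, 5))
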
def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
    # Create a figure and axis.
    fig, ax = plt.subplots(1, figsize=(6, 6))
    # Display the frame image.
    ax.imshow(frame_image)
    ax.axis("off")
    rel_pred_ls = rel_pred_ls or {}
    all_objs_to_show = set()
    all_lines_to_show = []
    for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
        all_objs_to_show.add(from_obj_id)
        all_objs_to_show.add(to_obj_id)
        from_mask = masks[from_obj_id]
        bbox1 = mask_to_bbox(from_mask)
        to_mask = masks[to_obj_id]
        bbox2 = mask_to_bbox(to_mask)
        # Connect the boxes center-to-center via relation_line (a stand-in for
        # the shortest_line_between_bboxes helper, which is not defined here).
        c1, c2 = relation_line(bbox1, bbox2)
        line_color = get_color(from_obj_id)
        face_color = get_color(to_obj_id)
        line = c1, c2, face_color, line_color, rel_text
        all_lines_to_show.append(line)
    masks_to_show = {}
    for oid in all_objs_to_show:
        masks_to_show[oid] = masks[oid]
    # Add the masks.
    for obj_id, mask in masks_to_show.items():
        show_mask(mask, ax, obj_id=obj_id, random_color=False)
    for (from_pt_x, from_pt_y), (
        to_pt_x,
        to_pt_y,
    ), face_color, line_color, rel_text in all_lines_to_show:
        plt.plot(
            [from_pt_x, to_pt_x],
            [from_pt_y, to_pt_y],
            color=line_color,
            linestyle="-",
            linewidth=3,
        )
        mid_pt_x = (from_pt_x + to_pt_x) / 2
        mid_pt_y = (from_pt_y + to_pt_y) / 2
        ax.text(
            mid_pt_x - 5,
            mid_pt_y,
            rel_text,
            color="white",
            fontsize=6,
            backgroundcolor=np.array(line_color),
            bbox=dict(
                facecolor=face_color, edgecolor=line_color, boxstyle="round,pad=1"
            ),
            alpha=1,
        )
    # Return the plot for the caller to save or show.
    return fig, ax