from flax import config
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor
from typing import Dict, List, Tuple, Optional, Any, Union
import numpy as np
import os
import cv2
from collections import defaultdict
import builtins
import sys
from laser.models import llava_clip_model_v3
sys.modules["llava_clip_model_v3"] = llava_clip_model_v3
from safetensors.torch import load_file
import inspect
from transformers.models.clip import modeling_clip
import transformers
from huggingface_hub import snapshot_download
from .vine_config import VineConfig
from laser.models.model_utils import (
    extract_single_object,
    extract_object_subject,
    crop_image_contain_bboxes,
    segment_list
)
from .flattening import (
    extract_valid_object_pairs,
    flatten_segments_for_batch,
)
from .vis_utils import save_mask_one_image


class VineModel(PreTrainedModel):
    """
    VINE (Video Understanding with Natural Language) Model

    This model processes videos along with categorical, unary, and binary keywords
    to return probability distributions over those keywords for detected objects
    and their relationships in the video.
    """
    config_class = VineConfig

    def __init__(self, config: VineConfig):
        super().__init__(config)
        self.config = config
        self.visualize = getattr(config, "visualize", False)
        self.visualization_dir = getattr(config, "visualization_dir", None)
        self.debug_visualizations = getattr(config, "debug_visualizations", False)
        self._device = getattr(config, "_device")

        # Initialize CLIP components
        self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if self.clip_tokenizer.pad_token is None:
            self.clip_tokenizer.pad_token = (
                self.clip_tokenizer.unk_token
                if self.clip_tokenizer.unk_token
                else self.clip_tokenizer.eos_token
            )
        self.clip_processor = AutoProcessor.from_pretrained(config.model_name)
        self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
        self.clip_unary_model = AutoModel.from_pretrained(config.model_name)
        self.clip_binary_model = AutoModel.from_pretrained(config.model_name)

        # Then try to load pretrained VINE weights if specified
        if config.use_hf_repo:
            self._load_huggingface_vine_weights(config.model_repo, config.model_file)
        else:
            self._load_local_pretrained_vine_weights(config.local_dir, config.local_filename)

        # Move models to device
        self.to(self._device)
    def _load_huggingface_vine_weights(self, model_repo: str, model_file: Optional[str] = None):
        """
        Load pretrained VINE weights from HuggingFace Hub.
        """
        try:
            print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
            repo_path = snapshot_download(model_repo, revision=model_file or "main")
            weights = load_file(os.path.join(repo_path, "model.safetensors"))
            self.load_state_dict(weights, strict=False)
            print("✓ Successfully loaded VINE weights from HuggingFace Hub")
            return True
        except Exception as e:
            print(f"✗ Error loading VINE weights from HuggingFace Hub: {e}")
            print("Using base CLIP models instead")
            return False
    def _load_local_pretrained_vine_weights(self, local_dir: str, local_filename: Optional[str] = None, epoch: int = 0):
        """
        Load pretrained VINE weights from a saved .pt file or ensemble format.
        """
        # try:  # simple .pt or .pth checkpoint
        #     x = torch.load(pretrained_path, map_location=self._device, weights_only=False)
        #     print(f"Loaded VINE checkpoint type: {type(x)}")
        full_path = os.path.join(local_dir, local_filename) if local_filename else local_dir
        if full_path.endswith(".pkl"):
            print(f"Loading VINE weights from: {full_path}")
            loaded_vine_model = torch.load(full_path, map_location=self._device, weights_only=False)
            print(f"Loaded state type: {type(loaded_vine_model)}")
            if not isinstance(loaded_vine_model, dict):
                if hasattr(loaded_vine_model, 'clip_cate_model'):
                    self.clip_cate_model.load_state_dict(loaded_vine_model.clip_cate_model.state_dict())
                if hasattr(loaded_vine_model, 'clip_unary_model'):
                    self.clip_unary_model.load_state_dict(loaded_vine_model.clip_unary_model.state_dict())
                if hasattr(loaded_vine_model, 'clip_binary_model'):
                    self.clip_binary_model.load_state_dict(loaded_vine_model.clip_binary_model.state_dict())
                return True
        elif full_path.endswith(".pt") or full_path.endswith(".pth"):
            state = torch.load(full_path, map_location=self._device, weights_only=True)
            print(f"Loaded state type: {type(state)}")
            self.load_state_dict(state)
            return True

        # handle directory + epoch format
        if os.path.isdir(full_path):
            model_files = [f for f in os.listdir(full_path) if f.endswith(f'.{epoch}.model')]
            if model_files:
                model_file = os.path.join(full_path, model_files[0])
                print(f"Loading VINE weights from: {model_file}")
                pretrained_model = torch.load(model_file, map_location="cpu")
                # Conversion from PredicateModel-like object to VineModel
                # Only copy if attributes exist
                if hasattr(pretrained_model, 'clip_cate_model'):
                    self.clip_cate_model.load_state_dict(pretrained_model.clip_cate_model.state_dict())
                if hasattr(pretrained_model, 'clip_unary_model'):
                    self.clip_unary_model.load_state_dict(pretrained_model.clip_unary_model.state_dict())
                if hasattr(pretrained_model, 'clip_binary_model'):
                    self.clip_binary_model.load_state_dict(pretrained_model.clip_binary_model.state_dict())
                print("✓ Loaded all sub-model weights from ensemble format")
                return True
            else:
                print(f"No model file found for epoch {epoch} in {full_path}")
                return False

        print("Unsupported format for pretrained_vine_path")
        return False
        # except Exception as e:
        #     print(f"✗ Error loading VINE weights: {e}")
        #     print("Using base CLIP models instead")
        #     return False
    # def _load_pretrained_vine_weights(self, pretrained_path: str, epoch: int = 0):
    #     """
    #     Load pretrained VINE weights from local ensemble format.
    #     Args:
    #         pretrained_path: Path to the pretrained model directory or HF model name
    #         epoch: Epoch number to load (for ensemble format)
    #     """
    #     if pretrained_path == "video-fm/vine_v0":
    #         # Try to load from HuggingFace Hub
    #         try:
    #     # ✅ TODO FIXED: Added support for loading .pt/.pth checkpoints with state dicts
    #     if pretrained_path.endswith(".pt") or pretrained_path.endswith(".pth"):
    #         print(f"Loading VINE weights from: {pretrained_path}")
    #         state = torch.load(pretrained_path, map_location="cpu")
    #         if "clip_cate_model" in state:
    #             self.clip_cate_model.load_state_dict(state["clip_cate_model"])
    #             print("✓ Loaded categorical model weights")
    #         if "clip_unary_model" in state:
    #             self.clip_unary_model.load_state_dict(state["clip_unary_model"])
    #             print("✓ Loaded unary model weights")
    #         if "clip_binary_model" in state:
    #             self.clip_binary_model.load_state_dict(state["clip_binary_model"])
    #             print("✓ Loaded binary model weights")
    #         if "clip_tokenizer" in state:
    #             self.clip_tokenizer = state["clip_tokenizer"]
    #             print("✓ Loaded tokenizer")
    #         if "clip_processor" in state:
    #             self.clip_processor = state["clip_processor"]
    #             print("✓ Loaded processor")
    #         print("✓ All VINE weights loaded successfully")
    #         return True
    #
    #     # Load from local ensemble format
    #     try:
    #         if os.path.isdir(pretrained_path):
    #             # Directory format - look for ensemble file
    #             model_files = [f for f in os.listdir(pretrained_path) if f.endswith(f'.{epoch}.model')]
    #             if model_files:
    #                 model_file = os.path.join(pretrained_path, model_files[0])
    #             else:
    #                 print(f"No model file found for epoch {epoch} in {pretrained_path}")
    #                 return False
    #         else:
    #             # Direct file path
    #             model_file = pretrained_path
    #         print(f"Loading VINE weights from: {model_file}")
    #         # Load the ensemble model (PredicateModel instance)
    #         # TODO: conversion from PredicateModel to VineModel
    #         pretrained_model = torch.load(model_file, map_location='cpu', weights_only=False)
    #         # Transfer weights from the pretrained model to our HuggingFace models
    #         if hasattr(pretrained_model, 'clip_cate_model'):
    #             self.clip_cate_model.load_state_dict(pretrained_model.clip_cate_model.state_dict())
    #             print("✓ Loaded categorical model weights")
    #         if hasattr(pretrained_model, 'clip_unary_model'):
    #             self.clip_unary_model.load_state_dict(pretrained_model.clip_unary_model.state_dict())
    #             print("✓ Loaded unary model weights")
    #         if hasattr(pretrained_model, 'clip_binary_model'):
    #             self.clip_binary_model.load_state_dict(pretrained_model.clip_binary_model.state_dict())
    #             print("✓ Loaded binary model weights")
    #         # Also transfer tokenizer and processor if available
    #         if hasattr(pretrained_model, 'clip_tokenizer'):
    #             self.clip_tokenizer = pretrained_model.clip_tokenizer
    #             print("✓ Loaded tokenizer")
    #         if hasattr(pretrained_model, 'clip_processor'):
    #             self.clip_processor = pretrained_model.clip_processor
    #             print("✓ Loaded processor")
    #         print("✓ Successfully loaded all VINE weights")
    #         return True
    #     except Exception as e:
    #         print(f"✗ Error loading VINE weights: {e}")
    #         print("Using base CLIP models instead")
    #         return False
    @classmethod
    def from_pretrained_vine(
        cls,
        model_path: str,
        config: Optional[VineConfig] = None,
        epoch: int = 0,
        **kwargs
    ):
        """
        Create a VineModel from pretrained VINE weights.

        Args:
            model_path: Path to the pretrained VINE model (local path or HF repo id)
            config: Optional config; a default one is created if None
            epoch: Epoch number to load
            **kwargs: Additional arguments

        Returns:
            VineModel instance with loaded weights
        """
        # Normalize the incoming model_path into the new VineConfig fields.
        if config is None:
            # Heuristics: if the path looks like a HF repo (contains a "/" and
            # doesn't exist on disk), treat it as a repo. Otherwise treat it as local.
            if model_path and ("/" in model_path and not os.path.exists(model_path)):
                config = VineConfig(use_hf_repo=True, model_repo=model_path)
            else:
                # Local path: could be a file or a directory
                if os.path.isdir(model_path):
                    config = VineConfig(use_hf_repo=False, local_dir=model_path)
                else:
                    config = VineConfig(
                        use_hf_repo=False,
                        local_dir=os.path.dirname(model_path) or None,
                        local_filename=os.path.basename(model_path) or None,
                    )
        else:
            # Update the provided config to reflect the requested pretrained path
            if model_path and ("/" in model_path and not os.path.exists(model_path)):
                config.use_hf_repo = True
                config.model_repo = model_path
                config.model_file = None
                config.local_dir = None
                config.local_filename = None
            else:
                config.use_hf_repo = False
                if os.path.isdir(model_path):
                    config.local_dir = model_path
                    config.local_filename = None
                else:
                    config.local_dir = os.path.dirname(model_path) or None
                    config.local_filename = os.path.basename(model_path) or None

        # Create the model instance (weights are loaded automatically in __init__)
        model = cls(config, **kwargs)
        return model
    def _text_features_checkpoint(self, model, tokens):
        """Extract text features with gradient checkpointing."""
        token_keys = list(tokens.keys())

        def get_text_features_wrapped(*inputs):
            kwargs = {key: value for key, value in zip(token_keys, inputs)}
            return model.get_text_features(**kwargs)

        token_values = [tokens[key] for key in token_keys]
        return cp.checkpoint(get_text_features_wrapped, *token_values, use_reentrant=False)
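
    # Both checkpoint helpers wrap the CLIP feature extractors in
    # torch.utils.checkpoint, so activations are recomputed during the backward
    # pass rather than stored, trading extra compute for lower memory. The text
    # variant above turns the tokenizer's keyword arguments into positional
    # tensors and rebuilds the kwargs inside the wrapped call, since
    # cp.checkpoint passes its extra arguments through positionally.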

    def _image_features_checkpoint(self, model, images):
        """Extract image features with gradient checkpointing."""
        return cp.checkpoint(model.get_image_features, images, use_reentrant=False)

    def clip_sim(self, model, nl_feat, img_feat):
        img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
        nl_feat = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
        logits = torch.matmul(img_feat, nl_feat.T)
        if hasattr(model, "logit_scale"):
            logits = logits * model.logit_scale.exp()
        return logits
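
    # Shape sketch for clip_sim (illustrative): with K keyword embeddings
    # nl_feat of shape (K, D) and a single image embedding img_feat of shape
    # (1, D), the matmul yields logits of shape (1, K); forward() then applies
    # a softmax over the last dimension to get one probability per keyword.
    # When the underlying CLIP model exposes logit_scale, the cosine
    # similarities are scaled by exp(logit_scale), i.e. CLIP's learned
    # temperature.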

    def forward(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Forward pass of the VINE model.

        Args:
            video_frames: Tensor of shape (num_frames, height, width, 3)
            masks: Dict mapping frame_id -> object_id -> mask tensor
            bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2]
            categorical_keywords: List of category names to classify objects
            unary_keywords: Optional list of unary predicates (actions on single objects)
            binary_keywords: Optional list of binary predicates (relations between objects)
            object_pairs: Optional list of (obj1_id, obj2_id) pairs for binary classification
            return_flattened_segments: Whether to include flattened mask/bbox tensors (defaults to config)
            return_valid_pairs: Whether to compute valid object pairs per frame (defaults to config)
            interested_object_pairs: Optional subset of object pairs to track
            debug_visualizations: Whether to emit debug visualizations (defaults to config)

        Returns:
            Dict containing probability distributions for categorical, unary, and binary predictions
        """
        if unary_keywords is None:
            unary_keywords = []
        if binary_keywords is None:
            binary_keywords = []
        if object_pairs is None:
            object_pairs = []
        if return_flattened_segments is None:
            return_flattened_segments = self.config.return_flattened_segments
        if return_valid_pairs is None:
            return_valid_pairs = self.config.return_valid_pairs
        if interested_object_pairs is None or len(interested_object_pairs) == 0:
            interested_object_pairs = getattr(self.config, "interested_object_pairs", []) or []
        if debug_visualizations is None:
            debug_visualizations = self.debug_visualizations

        # Prepare dummy strings for empty categories
        dummy_str = ""
        # Fill empty categories with dummy strings
        if len(categorical_keywords) == 0:
            categorical_keywords = [dummy_str]
        if len(unary_keywords) == 0:
            unary_keywords = [dummy_str]
        if len(binary_keywords) == 0:
            binary_keywords = [dummy_str]

        # Extract text features for all keyword types
        categorical_features = self._extract_text_features(
            self.clip_cate_model, categorical_keywords
        )
        unary_features = self._extract_text_features(
            self.clip_unary_model, unary_keywords
        )
        binary_features = self._extract_text_features(
            self.clip_binary_model, binary_keywords
        )

        # Process video frames and extract object features
        categorical_probs = {}
        unary_probs = {}
        binary_probs = {}

        # Process each frame
        for frame_id, frame_masks in masks.items():
            if frame_id >= len(video_frames):
                continue
            frame = self._frame_to_numpy(video_frames[frame_id])
            frame_bboxes = bboxes.get(frame_id, {})

            # Extract object features for categorical classification
            for obj_id, mask in frame_masks.items():
                if obj_id not in frame_bboxes:
                    continue
                bbox = frame_bboxes[obj_id]

                # Extract single object image
                mask_np = self._mask_to_numpy(mask)
                obj_image = extract_single_object(
                    frame, mask_np, alpha=self.config.alpha
                )

                # Get image features
                obj_features = self._extract_image_features(
                    self.clip_cate_model, obj_image
                )

                # Compute similarities for categorical classification
                cat_similarities = self.clip_sim(
                    self.clip_cate_model, categorical_features, obj_features
                )
                cat_probs = F.softmax(cat_similarities, dim=-1)

                # Store categorical predictions
                for i, keyword in enumerate(categorical_keywords):
                    if keyword != dummy_str:
                        categorical_probs[(obj_id, keyword)] = cat_probs[0, i].item()

                # Compute unary predictions
                if len(unary_keywords) > 0 and unary_keywords[0] != dummy_str:
                    unary_similarities = self.clip_sim(
                        self.clip_unary_model, unary_features, obj_features
                    )
                    unary_probs_tensor = F.softmax(unary_similarities, dim=-1)
                    for i, keyword in enumerate(unary_keywords):
                        if keyword != dummy_str:
                            unary_probs[(frame_id, obj_id, keyword)] = unary_probs_tensor[0, i].item()

        # Process binary relationships
        if len(binary_keywords) > 0 and binary_keywords[0] != dummy_str and len(object_pairs) > 0:
            for obj1_id, obj2_id in object_pairs:
                for frame_id, frame_masks in masks.items():
                    if frame_id >= len(video_frames):
                        continue
                    if (obj1_id in frame_masks and obj2_id in frame_masks and
                            obj1_id in bboxes.get(frame_id, {}) and obj2_id in bboxes.get(frame_id, {})):
                        frame = self._frame_to_numpy(video_frames[frame_id])
                        mask1 = frame_masks[obj1_id]
                        mask2 = frame_masks[obj2_id]
                        mask1_np = self._mask_to_numpy(mask1)
                        mask2_np = self._mask_to_numpy(mask2)

                        # Extract object pair image
                        pair_image = extract_object_subject(
                            frame, mask1_np[..., None], mask2_np[..., None],
                            alpha=self.config.alpha,
                            white_alpha=self.config.white_alpha
                        )

                        # Crop to contain both objects
                        bbox1 = bboxes[frame_id][obj1_id]
                        bbox2 = bboxes[frame_id][obj2_id]

                        # Skip pairs whose bounding boxes do not overlap
                        if bbox1[0] >= bbox2[2] or bbox2[1] >= bbox1[3] or \
                                bbox2[0] >= bbox1[2] or bbox1[1] >= bbox2[3]:
                            continue

                        cropped_image = crop_image_contain_bboxes(
                            pair_image, [bbox1, bbox2], f"frame_{frame_id}"
                        )

                        # Get image features
                        pair_features = self._extract_image_features(
                            self.clip_binary_model, cropped_image
                        )

                        # Compute similarities for binary classification
                        binary_similarities = self.clip_sim(
                            self.clip_binary_model, binary_features, pair_features
                        )
                        binary_probs_tensor = F.softmax(binary_similarities, dim=-1)
                        for i, keyword in enumerate(binary_keywords):
                            if keyword != dummy_str:
                                binary_probs[(frame_id, (obj1_id, obj2_id), keyword)] = binary_probs_tensor[0, i].item()

        # Calculate dummy probability (for compatibility)
        dummy_prob = 1.0 / max(len(categorical_keywords), len(unary_keywords), len(binary_keywords))

        result: Dict[str, Any] = {
            "categorical_probs": {0: categorical_probs},  # Video ID 0
            "unary_probs": {0: unary_probs},
            "binary_probs": [binary_probs],  # List format for compatibility
            "dummy_prob": dummy_prob
        }

        if return_flattened_segments or return_valid_pairs:
            flattened = flatten_segments_for_batch(
                video_id=0,
                segments=masks,
                bbox_min_dim=self.config.bbox_min_dim,
            )
            if return_flattened_segments:
                result["flattened_segments"] = flattened
            if return_valid_pairs:
                interested_pairs = interested_object_pairs if interested_object_pairs else None
                result["valid_pairs"] = extract_valid_object_pairs(
                    flattened["object_ids"],
                    interested_pairs,
                )
                if interested_pairs is None:
                    # Provide all generated pairs for clarity when auto-generated.
                    result["valid_pairs_metadata"] = {"pair_source": "all_pairs"}
                else:
                    result["valid_pairs_metadata"] = {"pair_source": "filtered", "requested_pairs": interested_pairs}
        return result
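
    # Output key layout of forward(), as produced above:
    #   result["categorical_probs"][0][(obj_id, keyword)]               -> float
    #   result["unary_probs"][0][(frame_id, obj_id, keyword)]           -> float
    #   result["binary_probs"][0][(frame_id, (obj1_id, obj2_id), kw)]   -> float
    # The leading 0 is the (single) video id; predict() below reshapes these
    # into per-object / per-frame top-k lists.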

    def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a frame tensor/array to a contiguous numpy array."""
        if torch.is_tensor(frame):
            frame_np = frame.detach().cpu().numpy()
        else:
            frame_np = np.asarray(frame)
        return np.ascontiguousarray(frame_np)

    def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a mask tensor/array to a 2D boolean numpy array."""
        if torch.is_tensor(mask):
            mask_np = mask.detach().cpu().numpy()
        else:
            mask_np = np.asarray(mask)
        if mask_np.ndim == 3:
            if mask_np.shape[0] == 1:
                mask_np = mask_np.squeeze(0)
            elif mask_np.shape[2] == 1:
                mask_np = mask_np.squeeze(2)
        if mask_np.ndim != 2:
            raise ValueError(f"Mask must be 2D after squeezing, got shape {mask_np.shape}")
        return mask_np.astype(bool, copy=False)

    def _extract_text_features(self, model, keywords):
        """Extract text features for given keywords."""
        tokens = self.clip_tokenizer(
            keywords,
            return_tensors="pt",
            max_length=75,
            truncation=True,
            padding='max_length'
        ).to(self._device)
        return self._text_features_checkpoint(model, tokens)

    def _extract_image_features(self, model, image):
        """Extract image features for given image."""
        # Ensure image is in correct format
        if isinstance(image, np.ndarray):
            if image.dtype != np.uint8:
                image = image.astype(np.uint8)
            # Convert BGR to RGB if needed
            if len(image.shape) == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Process image with CLIP processor
        inputs = self.clip_processor(
            images=image,
            return_tensors="pt"
        ).to(self._device)
        return self._image_features_checkpoint(model, inputs['pixel_values'])

    # TODO: return masks and bboxes and their corresponding index
    def predict(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_top_k: int = 3,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        High-level prediction method that returns formatted results.

        Args:
            video_frames: Tensor of shape (num_frames, height, width, 3)
            masks: Dict mapping frame_id -> object_id -> mask tensor
            bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2]
            categorical_keywords: List of category names
            unary_keywords: Optional list of unary predicates
            binary_keywords: Optional list of binary predicates
            object_pairs: Optional list of object pairs for binary relations
            return_top_k: Number of top predictions to return
            return_flattened_segments: Whether to include flattened mask/bbox tensors
            return_valid_pairs: Whether to compute valid object pairs per frame
            interested_object_pairs: Optional subset of object pairs to track
            debug_visualizations: Whether to emit debug visualizations

        Returns:
            Formatted prediction results
        """
        with torch.no_grad():
            outputs = self.forward(
                video_frames=video_frames,
                masks=masks,
                bboxes=bboxes,
                categorical_keywords=categorical_keywords,
                unary_keywords=unary_keywords,
                binary_keywords=binary_keywords,
                object_pairs=object_pairs,
                return_flattened_segments=return_flattened_segments,
                return_valid_pairs=return_valid_pairs,
                interested_object_pairs=interested_object_pairs,
                debug_visualizations=debug_visualizations,
            )

        # Format categorical results
        formatted_categorical = {}
        for (obj_id, category), prob in outputs["categorical_probs"][0].items():
            if obj_id not in formatted_categorical:
                formatted_categorical[obj_id] = []
            formatted_categorical[obj_id].append((prob, category))
        # Sort and take top-k for each object
        for obj_id in formatted_categorical:
            formatted_categorical[obj_id] = sorted(
                formatted_categorical[obj_id], reverse=True
            )[:return_top_k]

        # Format unary results
        formatted_unary = {}
        for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items():
            key = (frame_id, obj_id)
            if key not in formatted_unary:
                formatted_unary[key] = []
            formatted_unary[key].append((prob, predicate))
        # Sort and take top-k
        for key in formatted_unary:
            formatted_unary[key] = sorted(
                formatted_unary[key], reverse=True
            )[:return_top_k]

        # Format binary results
        formatted_binary = {}
        if len(outputs["binary_probs"]) > 0:
            for (frame_id, obj_pair, predicate), prob in outputs["binary_probs"][0].items():
                key = (frame_id, obj_pair)
                if key not in formatted_binary:
                    formatted_binary[key] = []
                formatted_binary[key].append((prob, predicate))
            # Sort and take top-k
            for key in formatted_binary:
                formatted_binary[key] = sorted(
                    formatted_binary[key], reverse=True
                )[:return_top_k]

        result: Dict[str, Any] = {
            "categorical_predictions": formatted_categorical,
            "unary_predictions": formatted_unary,
            "binary_predictions": formatted_binary,
            "confidence_scores": {
                "categorical": max([max([p for p, _ in preds], default=0.0)
                                    for preds in formatted_categorical.values()], default=0.0),
                "unary": max([max([p for p, _ in preds], default=0.0)
                              for preds in formatted_unary.values()], default=0.0),
                "binary": max([max([p for p, _ in preds], default=0.0)
                               for preds in formatted_binary.values()], default=0.0)
            }
        }
        if "flattened_segments" in outputs:
            result["flattened_segments"] = outputs["flattened_segments"]
        if "valid_pairs" in outputs:
            result["valid_pairs"] = outputs["valid_pairs"]
        if "valid_pairs_metadata" in outputs:
            result["valid_pairs_metadata"] = outputs["valid_pairs_metadata"]
        return result
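

if __name__ == "__main__":
    # Minimal usage sketch (not part of the module's API): builds dummy inputs
    # and runs predict() end to end. It assumes the "video-fm/vine_v0" repo
    # referenced above is reachable and that the default VineConfig supplies a
    # device and a CLIP checkpoint name; adapt the path and keywords to your
    # own data.
    model = VineModel.from_pretrained_vine("video-fm/vine_v0")

    # Two dummy RGB frames with a single object (id 0) covering a fixed region.
    frames = torch.randint(0, 256, (2, 224, 224, 3), dtype=torch.uint8)
    masks = {f: {0: torch.zeros(224, 224, dtype=torch.bool)} for f in range(2)}
    for f in range(2):
        masks[f][0][16:208, 16:208] = True
    bboxes = {f: {0: [16, 16, 208, 208]} for f in range(2)}

    preds = model.predict(
        video_frames=frames,
        masks=masks,
        bboxes=bboxes,
        categorical_keywords=["person", "dog", "car"],
        unary_keywords=["running", "sitting"],
    )
    print(preds["categorical_predictions"])
    print(preds["unary_predictions"])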