import os
import sys
from typing import Dict, List, Tuple, Optional, Any, Union

import cv2
import numpy as np
import torch
from safetensors.torch import load_file
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor
from huggingface_hub import snapshot_download

from .vine_config import VineConfig
from laser.models import llava_clip_model_v3

# Register the LASER module under its bare name so that checkpoints pickled
# against `llava_clip_model_v3` can be unpickled by torch.load.
sys.modules["llava_clip_model_v3"] = llava_clip_model_v3

from laser.models.model_utils import (
    extract_single_object,
    extract_object_subject,
    crop_image_contain_bboxes,
    segment_list,
)
from .flattening import (
    extract_valid_object_pairs,
    flatten_segments_for_batch,
)
from .vis_utils import save_mask_one_image

class VineModel(PreTrainedModel):
    """
    VINE (Video Understanding with Natural Language) Model.

    Internally, the core CLIP/text/image/pair logic mirrors
    llava_clip_model_v3.PredicateModel as closely as possible for a single video,
    with a small extension to re-normalize categorical probs after pooling.
    """

    config_class = VineConfig

    def __init__(self, config: VineConfig):
        super().__init__(config)
        self.config = config
        self.visualize = getattr(config, "visualize", False)
        self.visualization_dir = getattr(config, "visualization_dir", None)
        self.debug_visualizations = getattr(config, "debug_visualizations", False)
        self._device = getattr(config, "_device")

        # CLIP components
        self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if self.clip_tokenizer.pad_token is None:
            self.clip_tokenizer.pad_token = (
                self.clip_tokenizer.unk_token
                if self.clip_tokenizer.unk_token
                else self.clip_tokenizer.eos_token
            )
        self.clip_processor = AutoProcessor.from_pretrained(config.model_name)
        self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
        self.clip_unary_model = AutoModel.from_pretrained(config.model_name)
        self.clip_binary_model = AutoModel.from_pretrained(config.model_name)

        # Load fine-tuned weights if available
        if config.use_hf_repo:
            self._load_huggingface_vine_weights(config.model_repo, config.model_file)
        else:
            self._load_local_pretrained_vine_weights(
                config.local_dir, config.local_filename
            )

        # Optionally reset categorical model to base CLIP (ignore fine-tune)
        if not getattr(config, "use_pretrained_cate_weights", True):
            self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
            self.clip_cate_model.to(self._device)

        self.to(self._device)

    # ------------------------------------------------------------------ #
    # Weight loading
    # ------------------------------------------------------------------ #
    def _load_huggingface_vine_weights(
        self, model_repo: str, model_file: Optional[str] = None
    ):
        try:
            print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
            repo_path = snapshot_download(model_repo, revision=model_file or "main")
            weights = load_file(os.path.join(repo_path, "model.safetensors"))
            self.load_state_dict(weights, strict=False)
            print("✓ Successfully loaded VINE weights from HuggingFace Hub")
            return True
        except Exception as e:
            print(f"✗ Error loading VINE weights from HuggingFace Hub: {e}")
            print("Using base CLIP models instead")
            return False

    def _load_local_pretrained_vine_weights(
        self, local_dir: str, local_filename: Optional[str] = None, epoch: int = 0
    ):
        if local_dir is None and local_filename is None:
            return False
        full_path = (
            os.path.join(local_dir, local_filename) if local_filename else local_dir
        )

        # .pkl – usually pickled PredicateModel
        if isinstance(full_path, str) and full_path.endswith(".pkl"):
            print(f"Loading VINE weights from: {full_path}")
            loaded_vine_model = torch.load(
                full_path, map_location=self._device, weights_only=False
            )
            print(f"Loaded state type: {type(loaded_vine_model)}")
            if not isinstance(loaded_vine_model, dict):
                if hasattr(loaded_vine_model, "clip_tokenizer"):
                    self.clip_tokenizer = loaded_vine_model.clip_tokenizer
                if hasattr(loaded_vine_model, "clip_processor"):
                    self.clip_processor = loaded_vine_model.clip_processor
                if hasattr(loaded_vine_model, "clip_cate_model"):
                    self.clip_cate_model.load_state_dict(
                        loaded_vine_model.clip_cate_model.state_dict()
                    )
                if hasattr(loaded_vine_model, "clip_unary_model"):
                    self.clip_unary_model.load_state_dict(
                        loaded_vine_model.clip_unary_model.state_dict()
                    )
                if hasattr(loaded_vine_model, "clip_binary_model"):
                    self.clip_binary_model.load_state_dict(
                        loaded_vine_model.clip_binary_model.state_dict()
                    )
            print("✓ Loaded VINE weights from .pkl PredicateModel checkpoint")
            return True

        # .pt / .pth – plain state_dict
        elif isinstance(full_path, str) and (
            full_path.endswith(".pt") or full_path.endswith(".pth")
        ):
            print(f"Loading VINE weights from: {full_path}")
            state = torch.load(full_path, map_location=self._device, weights_only=True)
            print(f"Loaded state type: {type(state)}")
            self.load_state_dict(state, strict=False)
            print("✓ Loaded VINE weights from state_dict")
            return True

        # .model – full PredicateModel object
        elif isinstance(full_path, str) and full_path.endswith(".model"):
            print(f"Loading VINE weights from: {full_path}")
            pretrained_model = torch.load(
                full_path, map_location="cpu", weights_only=False
            )
            if hasattr(pretrained_model, "clip_tokenizer"):
                self.clip_tokenizer = pretrained_model.clip_tokenizer
            if hasattr(pretrained_model, "clip_processor"):
                self.clip_processor = pretrained_model.clip_processor
            if hasattr(pretrained_model, "clip_cate_model"):
                self.clip_cate_model.load_state_dict(
                    pretrained_model.clip_cate_model.state_dict()
                )
            if hasattr(pretrained_model, "clip_unary_model"):
                self.clip_unary_model.load_state_dict(
                    pretrained_model.clip_unary_model.state_dict()
                )
            if hasattr(pretrained_model, "clip_binary_model"):
                self.clip_binary_model.load_state_dict(
                    pretrained_model.clip_binary_model.state_dict()
                )
            print("✓ Loaded all sub-model weights from .model file")
            return True

        # directory of .model files
        if isinstance(full_path, str) and os.path.isdir(full_path):
            model_files = [
                f for f in os.listdir(full_path) if f.endswith(f".{epoch}.model")
            ]
            if model_files:
                model_file = os.path.join(full_path, model_files[0])
                print(f"Loading VINE weights from: {model_file}")
                pretrained_model = torch.load(model_file, map_location="cpu")
                if hasattr(pretrained_model, "clip_tokenizer"):
                    self.clip_tokenizer = pretrained_model.clip_tokenizer
                if hasattr(pretrained_model, "clip_processor"):
                    self.clip_processor = pretrained_model.clip_processor
                if hasattr(pretrained_model, "clip_cate_model"):
                    self.clip_cate_model.load_state_dict(
                        pretrained_model.clip_cate_model.state_dict()
                    )
                if hasattr(pretrained_model, "clip_unary_model"):
                    self.clip_unary_model.load_state_dict(
                        pretrained_model.clip_unary_model.state_dict()
                    )
                if hasattr(pretrained_model, "clip_binary_model"):
                    self.clip_binary_model.load_state_dict(
                        pretrained_model.clip_binary_model.state_dict()
                    )
                print("✓ Loaded all sub-model weights from ensemble format")
                return True
            else:
                print(f"No model file found for epoch {epoch} in {full_path}")
                return False

        print("Unsupported format for pretrained VINE path:", full_path)
        return False

    @classmethod
    def from_pretrained_vine(
        cls,
        model_path: str,
        config: Optional[VineConfig] = None,
        epoch: int = 0,
        **kwargs: Any,
    ):
        if config is None:
            if model_path and ("/" in model_path and not os.path.exists(model_path)):
                config = VineConfig(use_hf_repo=True, model_repo=model_path)
            else:
                if os.path.isdir(model_path):
                    config = VineConfig(use_hf_repo=False, local_dir=model_path)
                else:
                    config = VineConfig(
                        use_hf_repo=False,
                        local_dir=os.path.dirname(model_path) or None,
                        local_filename=os.path.basename(model_path) or None,
                    )
        else:
            if model_path and ("/" in model_path and not os.path.exists(model_path)):
                config.use_hf_repo = True
                config.model_repo = model_path
                config.model_file = None
                config.local_dir = None
                config.local_filename = None
            else:
                config.use_hf_repo = False
                if os.path.isdir(model_path):
                    config.local_dir = model_path
                    config.local_filename = None
                else:
                    config.local_dir = os.path.dirname(model_path) or None
                    config.local_filename = os.path.basename(model_path) or None
        model = cls(config, **kwargs)
        return model
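
    # Usage sketch (hypothetical paths/repo ids, for illustration only):
    #
    #   model = VineModel.from_pretrained_vine("org/vine-checkpoint")    # HF repo id
    #   model = VineModel.from_pretrained_vine("/path/to/checkpoint.pt") # local file
    #
    # The heuristic above treats any string that contains "/" but does not exist
    # on disk as a Hugging Face repo id; everything else is resolved as a local
    # directory or file and dispatched by _load_local_pretrained_vine_weights.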

    # ------------------------------------------------------------------ #
    # Gradient checkpoint helpers
    # ------------------------------------------------------------------ #
    def _text_features_checkpoint(self, model, token_dict):
        input_ids = token_dict["input_ids"]
        attention_mask = token_dict["attention_mask"]
        token_type_ids = token_dict.get("token_type_ids", None)

        if token_type_ids is not None:

            def forward_pass(input_ids, attention_mask, token_type_ids):
                return model.get_text_features(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )

            return cp.checkpoint(
                forward_pass,
                input_ids,
                attention_mask,
                token_type_ids,
                use_reentrant=False,
            )
        else:

            def forward_pass(input_ids, attention_mask):
                return model.get_text_features(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )

            return cp.checkpoint(
                forward_pass, input_ids, attention_mask, use_reentrant=False
            )

    def _image_features_checkpoint(self, model, pixel_values):
        def forward_pass(pixel_values):
            return model.get_image_features(pixel_values=pixel_values)

        return cp.checkpoint(forward_pass, pixel_values, use_reentrant=False)

    # ------------------------------------------------------------------ #
    # CLIP similarity
    # ------------------------------------------------------------------ #
    def clip_sim(self, model, nl_feat, img_feat):
        img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
        nl_feat = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
        logit_scale = getattr(model, "logit_scale", None)
        logits_per_text = torch.matmul(nl_feat, img_feat.t())
        if logit_scale is not None:
            logits_per_text = logits_per_text * logit_scale.exp()
        return logits_per_text

    # ------------------------------------------------------------------ #
    # Forward: single-video PredicateModel-style logic
    # ------------------------------------------------------------------ #
    def forward(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        if unary_keywords is None:
            unary_keywords = []
        if binary_keywords is None:
            binary_keywords = []
        if object_pairs is None:
            object_pairs = []
        if return_flattened_segments is None:
            return_flattened_segments = getattr(
                self.config, "return_flattened_segments", False
            )
        if return_valid_pairs is None:
            return_valid_pairs = getattr(self.config, "return_valid_pairs", False)
        if interested_object_pairs is None or len(interested_object_pairs) == 0:
            interested_object_pairs = (
                getattr(self.config, "interested_object_pairs", []) or []
            )
        if debug_visualizations is None:
            debug_visualizations = self.debug_visualizations

        alpha = getattr(self.config, "alpha", 0.5)
        white_alpha = getattr(self.config, "white_alpha", 0.8)
        topk_cate = kwargs.pop("topk_cate", getattr(self.config, "topk_cate", 3))
        dummy_str = kwargs.pop("dummy_str", getattr(self.config, "dummy_str", "$$$"))
        multi_class = kwargs.pop(
            "multi_class", getattr(self.config, "multi_class", False)
        )
        output_logit = kwargs.pop(
            "output_logit", getattr(self.config, "output_logit", False)
        )
        output_embeddings = kwargs.pop("output_embeddings", False)

        batched_video_ids = [0]
        num_frames = (
            video_frames.shape[0] if torch.is_tensor(video_frames) else len(video_frames)
        )
        batched_videos = [
            self._frame_to_numpy(video_frames[fid]) for fid in range(num_frames)
        ]

        batched_masks: List[np.ndarray] = []
        batched_bboxes: List[List[float]] = []
        batched_object_ids: List[Tuple[int, int, int]] = []
        for frame_id, frame_masks in masks.items():
            if frame_id >= num_frames:
                continue
            frame_boxes = bboxes.get(frame_id, {})
            for obj_id, mask in frame_masks.items():
                if obj_id not in frame_boxes:
                    continue
                bbox = frame_boxes[obj_id]
                batched_object_ids.append((0, frame_id, obj_id))
                batched_masks.append(self._mask_to_numpy(mask))
                batched_bboxes.append(bbox)

        batched_names = [list(categorical_keywords)]
        batched_unary_kws = [list(unary_keywords)]
        batched_binary_kws = [list(binary_keywords)]

        batched_obj_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
        if object_pairs:
            for frame_id, frame_masks in masks.items():
                if frame_id >= num_frames:
                    continue
                present_ids = set(frame_masks.keys())
                for (from_oid, to_oid) in object_pairs:
                    if from_oid in present_ids and to_oid in present_ids:
                        batched_obj_pairs.append((0, frame_id, (from_oid, to_oid)))

        batched_video_splits = [0]
        batched_binary_predicates = [None]

        def fill_empty(batched_kw):
            new_batched = []
            for kw_ls in batched_kw:
                if len(kw_ls) == 0:
                    new_batched.append([dummy_str])
                else:
                    new_batched.append(list(kw_ls))
            return new_batched

        batched_names = fill_empty(batched_names)
        batched_unary_kws = fill_empty(batched_unary_kws)
        batched_binary_kws = fill_empty(batched_binary_kws)

        dummy_prob = torch.tensor(0.0, device=self._device)

        batched_obj_name_features = []
        batched_unary_nl_features = []
        batched_binary_nl_features = []
        batched_object_ids_lookup: Dict[int, List[Tuple[int, int]]] = {0: []}
        batch_size = len(batched_video_ids)

        # Step 1: text features
        for object_names, unary_kws, binary_kws in zip(
            batched_names, batched_unary_kws, batched_binary_kws
        ):
            if len(object_names) == 0:
                batched_obj_name_features.append([])
            else:
                obj_tokens = self.clip_tokenizer(
                    object_names,
                    return_tensors="pt",
                    max_length=75,
                    truncation=True,
                    padding="max_length",
                ).to(self._device)
                obj_feats = self._text_features_checkpoint(
                    self.clip_cate_model, obj_tokens
                )
                batched_obj_name_features.append(obj_feats)

            if len(unary_kws) == 0:
                batched_unary_nl_features.append([])
            else:
                unary_tokens = self.clip_tokenizer(
                    list(unary_kws),
                    return_tensors="pt",
                    max_length=75,
                    truncation=True,
                    padding="max_length",
                ).to(self._device)
                unary_feats = self._text_features_checkpoint(
                    self.clip_unary_model, unary_tokens
                )
                batched_unary_nl_features.append(unary_feats)

            if len(binary_kws) == 0:
                batched_binary_nl_features.append([])
            else:
                binary_tokens = self.clip_tokenizer(
                    list(binary_kws),
                    return_tensors="pt",
                    max_length=75,
                    truncation=True,
                    padding="max_length",
                ).to(self._device)
                binary_feats = self._text_features_checkpoint(
                    self.clip_binary_model, binary_tokens
                )
                batched_binary_nl_features.append(binary_feats)

        # Step 2: crop objects
        batched_frame_masks: Dict[Tuple[int, int, int], np.ndarray] = {}
        batched_frame_bboxes: Dict[Tuple[int, int, int], List[float]] = {}
        batched_cropped_objs: Dict[int, List[np.ndarray]] = {
            vid: [] for vid in range(batch_size)
        }
        assert len(batched_object_ids) > 0, f"No object bbox: {batched_video_ids}"
        batched_video_splits = [0] + batched_video_splits

        for (video_id, frame_id, obj_id), mask, bbox in zip(
            batched_object_ids, batched_masks, batched_bboxes
        ):
            overall_frame_id = batched_video_splits[video_id] + frame_id
            object_img = extract_single_object(
                batched_videos[overall_frame_id], mask, white_alpha
            )
            cropped_object_img = crop_image_contain_bboxes(
                object_img, [bbox], batched_video_ids
            )
            if self.visualization_dir:
                debug_crop_dir = os.path.join(self.visualization_dir, "debug_crops")
                os.makedirs(debug_crop_dir, exist_ok=True)
                cv2.imwrite(
                    os.path.join(debug_crop_dir, f"frame_{frame_id}_obj_{obj_id}.jpg"),
                    cv2.cvtColor(cropped_object_img, cv2.COLOR_RGB2BGR),
                )
            batched_frame_masks[(video_id, frame_id, obj_id)] = mask
            batched_frame_bboxes[(video_id, frame_id, obj_id)] = bbox
            batched_object_ids_lookup[video_id].append((frame_id, obj_id))
            batched_cropped_objs[video_id].append(cropped_object_img)
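
        # Note: this port operates on a single video, so `video_id` is always 0
        # and `batched_video_splits` degenerates to [0, 0]; crops are still keyed
        # by (video_id, frame_id, obj_id) to stay shape-compatible with the
        # batched PredicateModel code this mirrors.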

        # Step 3: categorical + unary
        batched_image_unary_probs: Dict[int, Dict] = {}
        batched_image_cate_probs: Dict[int, Dict] = {}
        batched_obj_cate_features: Dict[int, Any] = {}
        batched_obj_unary_features: Dict[int, Any] = {}
        batched_obj_per_cate: Dict[int, Dict[str, List[Tuple[torch.Tensor, int]]]] = {}
        for vid in range(batch_size):
            batched_image_unary_probs[vid] = {}
            batched_image_cate_probs[vid] = {}
            batched_obj_cate_features[vid] = {}
            batched_obj_unary_features[vid] = {}
            batched_obj_per_cate[vid] = {}

        for vid_id, (
            unary_nl_feats,
            object_name_feats,
            cate,
            unary_pred,
            binary_predicates,
        ) in enumerate(
            zip(
                batched_unary_nl_features,
                batched_obj_name_features,
                batched_names,
                batched_unary_kws,
                batched_binary_predicates,
            )
        ):
            cropped_objs = batched_cropped_objs[vid_id]
            if len(cropped_objs) != 0:
                inputs = self.clip_processor(
                    images=cropped_objs, return_tensors="pt"
                ).to(self._device)
                cate_obj_clip_features = self._image_features_checkpoint(
                    self.clip_cate_model, inputs["pixel_values"]
                )
                unary_obj_clip_features = self._image_features_checkpoint(
                    self.clip_unary_model, inputs["pixel_values"]
                )
                batched_obj_unary_features[vid_id] = unary_obj_clip_features
                batched_obj_cate_features[vid_id] = cate_obj_clip_features
            else:
                batched_obj_cate_features[vid_id] = torch.tensor([])
                batched_obj_unary_features[vid_id] = torch.tensor([])

            object_ids = batched_object_ids_lookup[vid_id]

            # Categorical logits
            if (
                len(object_name_feats) == 0
                or len(object_ids) == 0
                or len(cropped_objs) == 0
            ):
                cate_logits_per_text = torch.tensor([])
            else:
                cate_logits_per_text = self.clip_sim(
                    self.clip_cate_model, object_name_feats, cate_obj_clip_features
                )
                if not output_logit:
                    cate_logits_per_text = cate_logits_per_text.softmax(dim=0)

            if not (
                len(object_ids) == 0
                or (
                    cate_logits_per_text.ndim == 2
                    and cate_logits_per_text.shape[1] == len(object_ids)
                )
                or len(object_name_feats) == 0
            ):
                print("Object cate shape mismatch here")
            assert (
                len(object_name_feats) == 0
                or len(object_ids) == 0
                or (
                    cate_logits_per_text.ndim == 2
                    and cate_logits_per_text.shape[1] == len(object_ids)
                )
            ), f"Mismatched object id and cate logic: {batched_video_ids}"

            # Aggregate per object id across frames
            cate_prob_per_obj: Dict[int, Dict[str, List[torch.Tensor]]] = {}
            for cate_name, probs in zip(cate, cate_logits_per_text):
                if cate_name == dummy_str:
                    dummy_prob += probs.sum()
                else:
                    for prob, (fid, oid) in zip(probs, object_ids):
                        cate_prob_per_obj.setdefault(oid, {})
                        cate_prob_per_obj[oid].setdefault(cate_name, []).append(prob)

            new_cate_prob_per_obj: Dict[Tuple[int, str], torch.Tensor] = {}
            obj_per_cate: Dict[str, List[Tuple[torch.Tensor, int]]] = {}
            for oid, object_cate_info in cate_prob_per_obj.items():
                # Pool across frames per category
                pooled: Dict[str, torch.Tensor] = {}
                for cate_name, prob_list in object_cate_info.items():
                    stacked = torch.stack(prob_list)
                    if getattr(self.config, "categorical_pool", "mean") == "mean":
                        pooled_prob = stacked.mean()
                    else:
                        pooled_prob = stacked.max()
                    pooled[cate_name] = pooled_prob
                if not pooled:
                    continue
                # Renormalize across categories so they sum to 1 per object
                probs_tensor = torch.stack(list(pooled.values()))
                denom = probs_tensor.sum()
                if denom.item() <= 0:
                    norm_tensor = torch.ones_like(probs_tensor) / len(pooled)
                else:
                    norm_tensor = probs_tensor / denom
                for (cate_name, _), norm_prob in zip(pooled.items(), norm_tensor):
                    obj_per_cate.setdefault(cate_name, []).append((norm_prob, oid))
                    new_cate_prob_per_obj[(oid, cate_name)] = norm_prob

            for cate_name in obj_per_cate:
                obj_per_cate[cate_name] = sorted(
                    obj_per_cate[cate_name], key=lambda x: x[0], reverse=True
                )

            # Unary
            if len(unary_nl_feats) == 0 or len(cropped_objs) == 0:
                unary_logits_per_text = torch.tensor([])
            else:
                unary_logits_per_text = self.clip_sim(
                    self.clip_unary_model, unary_nl_feats, unary_obj_clip_features
                )
                if not output_logit:
                    unary_logits_per_text = unary_logits_per_text.softmax(dim=0)

            unary_prob_per_obj: Dict[Tuple[int, int, str], torch.Tensor] = {}
            for unary_name, probs in zip(unary_pred, unary_logits_per_text):
                if unary_name == dummy_str:
                    dummy_prob += probs.sum()
                else:
                    for prob, (fid, oid) in zip(probs, object_ids):
                        unary_prob_per_obj[(fid, oid, unary_name)] = prob

            batched_image_cate_probs[vid_id] = new_cate_prob_per_obj
            batched_image_unary_probs[vid_id] = unary_prob_per_obj
            batched_obj_per_cate[vid_id] = obj_per_cate

        # Step 4: binary pairs
        batched_cropped_obj_pairs: Dict[int, List[np.ndarray]] = {}
        frame_splits: Dict[Tuple[int, int], Dict[str, int]] = {}
        current_info = (0, 0)
        frame_splits[current_info] = {"start": 0}

        batched_topk_cate_candidates: Dict[int, Dict[str, List[int]]] = {
            video_id: {} for video_id in range(batch_size)
        }
        for video_id, obj_per_cate in batched_obj_per_cate.items():
            topk_cate_candidates: Dict[str, List[int]] = {}
            for cate_name, pred_oid_ls in obj_per_cate.items():
                for _, oid in pred_oid_ls[:topk_cate]:
                    topk_cate_candidates.setdefault(cate_name, []).append(oid)
            batched_topk_cate_candidates[video_id] = topk_cate_candidates

        obj_pair_lookup: Dict[int, Dict[Tuple[int, int], List[int]]] = {
            video_id: {} for video_id in range(len(batched_video_ids))
        }
        for (vid, fid, (from_oid, to_oid)) in batched_obj_pairs:
            if (from_oid, to_oid) not in obj_pair_lookup[vid]:
                obj_pair_lookup[vid][(from_oid, to_oid)] = []
            obj_pair_lookup[vid][(from_oid, to_oid)].append(fid)

        selected_pairs = set()
        if batched_binary_predicates[0] is None:
            selected_pairs = set(batched_obj_pairs)
        else:
            for bp_vid, binary_predicates in enumerate(batched_binary_predicates):
                topk_cate_candidates = batched_topk_cate_candidates[bp_vid]
                for (rel_name, from_obj_name, to_obj_name) in binary_predicates:
                    if (
                        from_obj_name in topk_cate_candidates
                        and to_obj_name in topk_cate_candidates
                    ):
                        from_oids = topk_cate_candidates[from_obj_name]
                        to_oids = topk_cate_candidates[to_obj_name]
                        for from_oid in from_oids:
                            for to_oid in to_oids:
                                if (
                                    bp_vid in obj_pair_lookup
                                    and (from_oid, to_oid) in obj_pair_lookup[bp_vid]
                                ):
                                    for fid in obj_pair_lookup[bp_vid][
                                        (from_oid, to_oid)
                                    ]:
                                        selected_pairs.add(
                                            (bp_vid, fid, (from_oid, to_oid))
                                        )
        selected_pairs = list(selected_pairs)

        new_select_pairs: Dict[int, List[Tuple[int, int, Tuple[int, int]]]] = {
            video_id: [] for video_id in range(len(batched_video_ids))
        }
        for (vid, fid, (from_oid, to_oid)) in selected_pairs:
            new_select_pairs[vid].append((vid, fid, (from_oid, to_oid)))

        for vid in range(len(batched_video_ids)):
            batched_cropped_obj_pairs[vid] = []

        for (vid, fid, (from_id, to_id)) in selected_pairs:
            if (
                (vid, fid, from_id) not in batched_frame_masks
                or (vid, fid, to_id) not in batched_frame_masks
            ):
                continue
            if (
                (vid, fid, from_id) not in batched_frame_bboxes
                or (vid, fid, to_id) not in batched_frame_bboxes
            ):
                continue
            overall_frame_id = batched_video_splits[vid] + fid
            mask1 = batched_frame_masks[(vid, fid, from_id)]
            mask2 = batched_frame_masks[(vid, fid, to_id)]
            bbox1 = batched_frame_bboxes[(vid, fid, from_id)]
            bbox2 = batched_frame_bboxes[(vid, fid, to_id)]
            bb_pop_image = extract_object_subject(
                batched_videos[overall_frame_id],
                mask1,
                mask2,
                alpha=alpha,
                white_alpha=white_alpha,
            )
            cropped_bb_pop_image = crop_image_contain_bboxes(
                img=bb_pop_image,
                bbox_ls=[bbox1, bbox2],
                data_id=batched_video_ids,
            )
            batched_cropped_obj_pairs[vid].append(cropped_bb_pop_image)

        if len(selected_pairs) == 0:
            selected_pairs.append((0, -1, (-1, -1)))
            new_select_pairs[0] = [(0, -1, (-1, -1))]
            dummy_img = batched_videos[0]
            batched_cropped_obj_pairs[0] = [dummy_img]

        batched_image_binary_probs: List[
            Dict[Tuple[int, Tuple[int, int], str], torch.Tensor]
        ] = []
        batched_obj_pair_features: Dict[int, torch.Tensor] = {
            vid: torch.tensor([]) for vid in range(batch_size)
        }
        if len(batched_cropped_obj_pairs) == 0:
            batched_image_binary_probs.append({})
        else:
            for vid, binary_nl_features in enumerate(batched_binary_nl_features):
                if len(binary_nl_features) == 0:
                    batched_image_binary_probs.append({})
                    continue
                binary_kws = batched_binary_kws[vid]
                cropped_obj_pairs = batched_cropped_obj_pairs[vid]
                if len(cropped_obj_pairs) == 0:
                    batched_image_binary_probs.append({})
                    continue
                inputs = self.clip_processor(
                    images=cropped_obj_pairs, return_tensors="pt"
                ).to(self._device)
                obj_features = self._image_features_checkpoint(
                    self.clip_binary_model, inputs["pixel_values"]
                )
                batched_obj_pair_features[vid] = obj_features
                obj_clip_features = obj_features / obj_features.norm(
                    p=2, dim=-1, keepdim=True
                )
                binary_nl_features = binary_nl_features / binary_nl_features.norm(
                    p=2, dim=-1, keepdim=True
                )
                logit_scale = self.clip_binary_model.logit_scale
                binary_logits_per_text = torch.matmul(
                    binary_nl_features, obj_clip_features.t()
                ) * logit_scale.exp()
                if not output_logit:
                    if not multi_class:
                        binary_logits_per_text = binary_logits_per_text.softmax(dim=0)
                    else:
                        binary_logits_per_text = binary_logits_per_text.sigmoid()

                binary_prob_per_obj: Dict[
                    Tuple[int, Tuple[int, int], str], torch.Tensor
                ] = {}
                for binary_name, probs in zip(binary_kws, binary_logits_per_text):
                    if binary_name == dummy_str:
                        dummy_prob += probs.sum()
                    else:
                        for prob, (vid_, fid, obj_pair) in zip(
                            probs, new_select_pairs[vid]
                        ):
                            if fid == -1:
                                dummy_prob += prob
                            else:
                                binary_prob_per_obj[(fid, obj_pair, binary_name)] = prob
                batched_image_binary_probs.append(binary_prob_per_obj)

        result: Dict[str, Any] = {
            "categorical_probs": batched_image_cate_probs,
            "unary_probs": batched_image_unary_probs,
            "binary_probs": batched_image_binary_probs,
            "dummy_prob": dummy_prob,
        }

        if output_embeddings:
            embeddings_dict = {
                "cate_obj_clip_features": batched_obj_cate_features,
                "cate_object_ids": batched_object_ids_lookup,
                "unary_obj_clip_features": batched_obj_unary_features,
                "unary_object_ids": batched_object_ids_lookup,
                "binary_obj_pair_features": batched_obj_pair_features,
                "binary_object_pairs": new_select_pairs,
            }
            result["embeddings"] = embeddings_dict

        if return_flattened_segments or return_valid_pairs:
            flattened = flatten_segments_for_batch(
                video_id=0,
                segments=masks,
                bbox_min_dim=self.config.bbox_min_dim,
            )
            if return_flattened_segments:
                result["flattened_segments"] = flattened
            if return_valid_pairs:
                interested_pairs = (
                    interested_object_pairs if interested_object_pairs else None
                )
                result["valid_pairs"] = extract_valid_object_pairs(
                    flattened["object_ids"],
                    interested_pairs,
                )
                if interested_pairs is None:
                    result["valid_pairs_metadata"] = {"pair_source": "all_pairs"}
                else:
                    result["valid_pairs_metadata"] = {
                        "pair_source": "filtered",
                        "requested_pairs": interested_object_pairs,
                    }
        return result
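
    # `forward` returns a dict with:
    #   categorical_probs: {video_id: {(obj_id, category): prob}}
    #   unary_probs:       {video_id: {(frame_id, obj_id, predicate): prob}}
    #   binary_probs:      [{(frame_id, (from_id, to_id), predicate): prob}, ...]
    #   dummy_prob:        probability mass accumulated on the dummy keyword
    # plus optional "embeddings", "flattened_segments", "valid_pairs", and
    # "valid_pairs_metadata" entries.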

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        if torch.is_tensor(frame):
            frame_np = frame.detach().cpu().numpy()
        else:
            frame_np = np.asarray(frame)
        return np.ascontiguousarray(frame_np)

    def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        if torch.is_tensor(mask):
            mask_np = mask.detach().cpu().numpy()
        else:
            mask_np = np.asarray(mask)
        if mask_np.ndim == 3:
            if mask_np.shape[0] == 1:
                mask_np = mask_np.squeeze(0)
            elif mask_np.shape[2] == 1:
                mask_np = mask_np.squeeze(2)
        if mask_np.ndim != 2:
            raise ValueError(
                f"Mask must be 2D after squeezing, got shape {mask_np.shape}"
            )
        return mask_np.astype(bool, copy=False)

    def _extract_text_features(self, model, keywords: List[str]):
        tokens = self.clip_tokenizer(
            keywords,
            return_tensors="pt",
            max_length=75,
            truncation=True,
            padding="max_length",
        ).to(self._device)
        return self._text_features_checkpoint(model, tokens)

    def _extract_image_features(self, model, image):
        if torch.is_tensor(image):
            image = image.detach().cpu().numpy()
        inputs = self.clip_processor(images=image, return_tensors="pt").to(self._device)
        return self._image_features_checkpoint(model, inputs["pixel_values"])

    # ------------------------------------------------------------------ #
    # High-level predict API
    # ------------------------------------------------------------------ #
    def predict(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_top_k: int = 3,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
    ) -> Dict[str, Any]:
        with torch.no_grad():
            outputs = self.forward(
                video_frames=video_frames,
                masks=masks,
                bboxes=bboxes,
                categorical_keywords=categorical_keywords,
                unary_keywords=unary_keywords,
                binary_keywords=binary_keywords,
                object_pairs=object_pairs,
                return_flattened_segments=return_flattened_segments,
                return_valid_pairs=return_valid_pairs,
                interested_object_pairs=interested_object_pairs,
                debug_visualizations=debug_visualizations,
            )

        formatted_categorical: Dict[int, List[Tuple[float, str]]] = {}
        for (obj_id, category), prob in outputs["categorical_probs"][0].items():
            if obj_id not in formatted_categorical:
                formatted_categorical[obj_id] = []
            prob_val = (
                float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
            )
            formatted_categorical[obj_id].append((prob_val, category))
        for obj_id in formatted_categorical:
            formatted_categorical[obj_id] = sorted(
                formatted_categorical[obj_id], reverse=True
            )[:return_top_k]

        formatted_unary: Dict[Tuple[int, int], List[Tuple[float, str]]] = {}
        for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items():
            key = (frame_id, obj_id)
            if key not in formatted_unary:
                formatted_unary[key] = []
            prob_val = (
                float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
            )
            formatted_unary[key].append((prob_val, predicate))
        for key in formatted_unary:
            formatted_unary[key] = sorted(
                formatted_unary[key], reverse=True
            )[:return_top_k]

        formatted_binary: Dict[Tuple[int, Tuple[int, int]], List[Tuple[float, str]]] = {}
        if len(outputs["binary_probs"]) > 0:
            for (frame_id, obj_pair, predicate), prob in outputs["binary_probs"][0].items():
                key = (frame_id, obj_pair)
                if key not in formatted_binary:
                    formatted_binary[key] = []
                prob_val = (
                    float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
                )
                formatted_binary[key].append((prob_val, predicate))
        for key in formatted_binary:
            formatted_binary[key] = sorted(
                formatted_binary[key], reverse=True
            )[:return_top_k]

        def max_conf(d: Dict[Any, List[Tuple[float, str]]]) -> float:
            if not d:
                return 0.0
            return max(
                (max((p for p, _ in preds), default=0.0) for preds in d.values()),
                default=0.0,
            )

        result: Dict[str, Any] = {
            "categorical_predictions": formatted_categorical,
            "unary_predictions": formatted_unary,
            "binary_predictions": formatted_binary,
            "confidence_scores": {
                "categorical": max_conf(formatted_categorical),
                "unary": max_conf(formatted_unary),
                "binary": max_conf(formatted_binary),
            },
        }
        if "flattened_segments" in outputs:
            result["flattened_segments"] = outputs["flattened_segments"]
        if "valid_pairs" in outputs:
            result["valid_pairs"] = outputs["valid_pairs"]
        if "valid_pairs_metadata" in outputs:
            result["valid_pairs_metadata"] = outputs["valid_pairs_metadata"]
        return result
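

# ---------------------------------------------------------------------- #
# Minimal usage sketch (not part of the model): illustrates the expected
# input layout for `predict`. The checkpoint path, bbox format, and
# frame/mask shapes below are assumptions for illustration, not values
# taken from this repository.
# ---------------------------------------------------------------------- #
if __name__ == "__main__":
    # One 2-frame RGB video of size 64x64; frames are HWC uint8.
    frames = torch.zeros(2, 64, 64, 3, dtype=torch.uint8)

    # Two objects (ids 0 and 1) with 2D boolean masks per frame and matching
    # boxes (assumed [x1, y1, x2, y2]); masks are {frame_id: {obj_id: mask}}.
    masks = {
        f: {0: torch.ones(64, 64, dtype=torch.bool),
            1: torch.ones(64, 64, dtype=torch.bool)}
        for f in range(2)
    }
    bboxes = {f: {0: [0, 0, 31, 63], 1: [32, 0, 63, 63]} for f in range(2)}

    # Hypothetical checkpoint location; see from_pretrained_vine for the
    # supported formats (HF repo id, .pt/.pth, .pkl, .model, or a directory).
    model = VineModel.from_pretrained_vine("path/to/vine_checkpoint.pt")

    preds = model.predict(
        video_frames=frames,
        masks=masks,
        bboxes=bboxes,
        categorical_keywords=["person", "dog"],
        unary_keywords=["running"],
        binary_keywords=["next to"],
        object_pairs=[(0, 1)],
    )
    print(preds["categorical_predictions"])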