import os
import sys
from typing import Dict, List, Tuple, Optional, Any, Union

import cv2
import numpy as np
import torch
from safetensors.torch import load_file
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor
from huggingface_hub import snapshot_download

from .vine_config import VineConfig
from laser.models import llava_clip_model_v3

sys.modules["llava_clip_model_v3"] = llava_clip_model_v3

from laser.models.model_utils import (
    extract_single_object,
    extract_object_subject,
    crop_image_contain_bboxes,
    segment_list,
)
from .flattening import (
    extract_valid_object_pairs,
    flatten_segments_for_batch,
)
from .vis_utils import save_mask_one_image


class VineModel(PreTrainedModel):
    """
    VINE (Video Understanding with Natural Language) Model.

    Internally, the core CLIP/text/image/pair logic mirrors
    llava_clip_model_v3.PredicateModel as closely as possible for a single
    video, with a small extension to re-normalize categorical probs after
    pooling.
    """

    config_class = VineConfig

    def __init__(self, config: VineConfig):
        super().__init__(config)
        self.config = config
        self.visualize = getattr(config, "visualize", False)
        self.visualization_dir = getattr(config, "visualization_dir", None)
        self.debug_visualizations = getattr(config, "debug_visualizations", False)
        self._device = getattr(config, "_device")

        # CLIP components
        self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if self.clip_tokenizer.pad_token is None:
            self.clip_tokenizer.pad_token = (
                self.clip_tokenizer.unk_token
                if self.clip_tokenizer.unk_token
                else self.clip_tokenizer.eos_token
            )
        self.clip_processor = AutoProcessor.from_pretrained(config.model_name)
        self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
        self.clip_unary_model = AutoModel.from_pretrained(config.model_name)
        self.clip_binary_model = AutoModel.from_pretrained(config.model_name)

        # Load fine-tuned weights if available
        if config.use_hf_repo:
            self._load_huggingface_vine_weights(config.model_repo, config.model_file)
        else:
            self._load_local_pretrained_vine_weights(
                config.local_dir, config.local_filename
            )

        # Optionally reset categorical model to base CLIP (ignore fine-tune)
        if not getattr(config, "use_pretrained_cate_weights", True):
            self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
            self.clip_cate_model.to(self._device)

        self.to(self._device)

    # ------------------------------------------------------------------ #
    # Weight loading
    # ------------------------------------------------------------------ #
    def _load_huggingface_vine_weights(
        self, model_repo: str, model_file: Optional[str] = None
    ):
        try:
            print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
            repo_path = snapshot_download(model_repo, revision=model_file or "main")
            weights = load_file(os.path.join(repo_path, "model.safetensors"))
            self.load_state_dict(weights, strict=False)
            print("✓ Successfully loaded VINE weights from HuggingFace Hub")
            return True
        except Exception as e:
            print(f"✗ Error loading VINE weights from HuggingFace Hub: {e}")
            print("Using base CLIP models instead")
            return False

    def _load_local_pretrained_vine_weights(
        self, local_dir: str, local_filename: Optional[str] = None, epoch: int = 0
    ):
        if local_dir is None and local_filename is None:
            return False

        full_path = (
            os.path.join(local_dir, local_filename) if local_filename else local_dir
        )

        # .pkl – usually pickled PredicateModel
        if isinstance(full_path, str) and full_path.endswith(".pkl"):
print(f"Loading VINE weights from: {full_path}") loaded_vine_model = torch.load( full_path, map_location=self._device, weights_only=False ) print(f"Loaded state type: {type(loaded_vine_model)}") if not isinstance(loaded_vine_model, dict): if hasattr(loaded_vine_model, "clip_tokenizer"): self.clip_tokenizer = loaded_vine_model.clip_tokenizer if hasattr(loaded_vine_model, "clip_processor"): self.clip_processor = loaded_vine_model.clip_processor if hasattr(loaded_vine_model, "clip_cate_model"): self.clip_cate_model.load_state_dict( loaded_vine_model.clip_cate_model.state_dict() ) if hasattr(loaded_vine_model, "clip_unary_model"): self.clip_unary_model.load_state_dict( loaded_vine_model.clip_unary_model.state_dict() ) if hasattr(loaded_vine_model, "clip_binary_model"): self.clip_binary_model.load_state_dict( loaded_vine_model.clip_binary_model.state_dict() ) print("✓ Loaded VINE weights from .pkl PredicateModel checkpoint") return True # .pt / .pth – plain state_dict elif isinstance(full_path, str) and ( full_path.endswith(".pt") or full_path.endswith(".pth") ): print(f"Loading VINE weights from: {full_path}") state = torch.load(full_path, map_location=self._device, weights_only=True) print(f"Loaded state type: {type(state)}") self.load_state_dict(state, strict=False) print("✓ Loaded VINE weights from state_dict") return True # .model – full PredicateModel object elif isinstance(full_path, str) and full_path.endswith(".model"): print(f"Loading VINE weights from: {full_path}") pretrained_model = torch.load( full_path, map_location="cpu", weights_only=False ) if hasattr(pretrained_model, "clip_tokenizer"): self.clip_tokenizer = pretrained_model.clip_tokenizer if hasattr(pretrained_model, "clip_processor"): self.clip_processor = pretrained_model.clip_processor if hasattr(pretrained_model, "clip_cate_model"): self.clip_cate_model.load_state_dict( pretrained_model.clip_cate_model.state_dict() ) if hasattr(pretrained_model, "clip_unary_model"): self.clip_unary_model.load_state_dict( pretrained_model.clip_unary_model.state_dict() ) if hasattr(pretrained_model, "clip_binary_model"): self.clip_binary_model.load_state_dict( pretrained_model.clip_binary_model.state_dict() ) print("✓ Loaded all sub-model weights from .model file") return True # directory of .model files if isinstance(full_path, str) and os.path.isdir(full_path): model_files = [ f for f in os.listdir(full_path) if f.endswith(f".{epoch}.model") ] if model_files: model_file = os.path.join(full_path, model_files[0]) print(f"Loading VINE weights from: {model_file}") pretrained_model = torch.load(model_file, map_location="cpu") if hasattr(pretrained_model, "clip_tokenizer"): self.clip_tokenizer = pretrained_model.clip_tokenizer if hasattr(pretrained_model, "clip_processor"): self.clip_processor = pretrained_model.clip_processor if hasattr(pretrained_model, "clip_cate_model"): self.clip_cate_model.load_state_dict( pretrained_model.clip_cate_model.state_dict() ) if hasattr(pretrained_model, "clip_unary_model"): self.clip_unary_model.load_state_dict( pretrained_model.clip_unary_model.state_dict() ) if hasattr(pretrained_model, "clip_binary_model"): self.clip_binary_model.load_state_dict( pretrained_model.clip_binary_model.state_dict() ) print("✓ Loaded all sub-model weights from ensemble format") return True else: print(f"No model file found for epoch {epoch} in {full_path}") return False print("Unsupported format for pretrained VINE path:", full_path) return False @classmethod def from_pretrained_vine( cls, model_path: str, config: 
        config: Optional[VineConfig] = None,
        epoch: int = 0,
        **kwargs: Any,
    ):
        if config is None:
            if model_path and ("/" in model_path and not os.path.exists(model_path)):
                config = VineConfig(use_hf_repo=True, model_repo=model_path)
            else:
                if os.path.isdir(model_path):
                    config = VineConfig(use_hf_repo=False, local_dir=model_path)
                else:
                    config = VineConfig(
                        use_hf_repo=False,
                        local_dir=os.path.dirname(model_path) or None,
                        local_filename=os.path.basename(model_path) or None,
                    )
        else:
            if model_path and ("/" in model_path and not os.path.exists(model_path)):
                config.use_hf_repo = True
                config.model_repo = model_path
                config.model_file = None
                config.local_dir = None
                config.local_filename = None
            else:
                config.use_hf_repo = False
                if os.path.isdir(model_path):
                    config.local_dir = model_path
                    config.local_filename = None
                else:
                    config.local_dir = os.path.dirname(model_path) or None
                    config.local_filename = os.path.basename(model_path) or None

        model = cls(config, **kwargs)
        return model

    # ------------------------------------------------------------------ #
    # Gradient checkpoint helpers
    # ------------------------------------------------------------------ #
    def _text_features_checkpoint(self, model, token_dict):
        # Run get_text_features under gradient checkpointing to reduce
        # activation memory during training.
        input_ids = token_dict["input_ids"]
        attention_mask = token_dict["attention_mask"]
        token_type_ids = token_dict.get("token_type_ids", None)

        if token_type_ids is not None:

            def forward_pass(input_ids, attention_mask, token_type_ids):
                return model.get_text_features(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )

            return cp.checkpoint(
                forward_pass,
                input_ids,
                attention_mask,
                token_type_ids,
                use_reentrant=False,
            )
        else:

            def forward_pass(input_ids, attention_mask):
                return model.get_text_features(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )

            return cp.checkpoint(
                forward_pass, input_ids, attention_mask, use_reentrant=False
            )

    def _image_features_checkpoint(self, model, pixel_values):
        # Same idea as above, for get_image_features.
        def forward_pass(pixel_values):
            return model.get_image_features(pixel_values=pixel_values)

        return cp.checkpoint(forward_pass, pixel_values, use_reentrant=False)

    # ------------------------------------------------------------------ #
    # CLIP similarity
    # ------------------------------------------------------------------ #
    def clip_sim(self, model, nl_feat, img_feat):
        # CLIP-style similarity: L2-normalize both sides, then take a scaled
        # dot product between text and image features.
        img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
        nl_feat = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
        logit_scale = getattr(model, "logit_scale", None)
        logits_per_text = torch.matmul(nl_feat, img_feat.t())
        if logit_scale is not None:
            logits_per_text = logits_per_text * logit_scale.exp()
        return logits_per_text

    # ------------------------------------------------------------------ #
    # Forward: single-video PredicateModel-style logic
    # ------------------------------------------------------------------ #
    def forward(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        if unary_keywords is None:
            unary_keywords = []
        if binary_keywords is None:
            binary_keywords = []
        if object_pairs is None:
            object_pairs = []
        if return_flattened_segments is None:
            return_flattened_segments = getattr(
                self.config, "return_flattened_segments", False
            )
        if return_valid_pairs is None:
            return_valid_pairs = getattr(self.config, "return_valid_pairs", False)
        if interested_object_pairs is None or len(interested_object_pairs) == 0:
            interested_object_pairs = (
                getattr(self.config, "interested_object_pairs", []) or []
            )
        if debug_visualizations is None:
            debug_visualizations = self.debug_visualizations

        # Per-call keyword overrides fall back to the corresponding config defaults.
        alpha = getattr(self.config, "alpha", 0.5)
        white_alpha = getattr(self.config, "white_alpha", 0.8)
        topk_cate = kwargs.pop("topk_cate", getattr(self.config, "topk_cate", 3))
        dummy_str = kwargs.pop("dummy_str", getattr(self.config, "dummy_str", "$$$"))
        multi_class = kwargs.pop(
            "multi_class", getattr(self.config, "multi_class", False)
        )
        output_logit = kwargs.pop(
            "output_logit", getattr(self.config, "output_logit", False)
        )
        output_embeddings = kwargs.pop("output_embeddings", False)

        # Single video (id 0): convert every frame to a contiguous numpy array.
        batched_video_ids = [0]
        num_frames = (
            video_frames.shape[0]
            if torch.is_tensor(video_frames)
            else len(video_frames)
        )
        batched_videos = [
            self._frame_to_numpy(video_frames[fid]) for fid in range(num_frames)
        ]

        batched_masks: List[np.ndarray] = []
        batched_bboxes: List[List[float]] = []
        batched_object_ids: List[Tuple[int, int, int]] = []
        for frame_id, frame_masks in masks.items():
            if frame_id >= num_frames:
                continue
            frame_boxes = bboxes.get(frame_id, {})
            for obj_id, mask in frame_masks.items():
                if obj_id not in frame_boxes:
                    continue
                bbox = frame_boxes[obj_id]
                batched_object_ids.append((0, frame_id, obj_id))
                batched_masks.append(self._mask_to_numpy(mask))
                batched_bboxes.append(bbox)

        batched_names = [list(categorical_keywords)]
        batched_unary_kws = [list(unary_keywords)]
        batched_binary_kws = [list(binary_keywords)]

        batched_obj_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
        if object_pairs:
            for frame_id, frame_masks in masks.items():
                if frame_id >= num_frames:
                    continue
                present_ids = set(frame_masks.keys())
                for (from_oid, to_oid) in object_pairs:
                    if from_oid in present_ids and to_oid in present_ids:
                        batched_obj_pairs.append((0, frame_id, (from_oid, to_oid)))

        batched_video_splits = [0]
        batched_binary_predicates = [None]

        def fill_empty(batched_kw):
            new_batched = []
            for kw_ls in batched_kw:
                if len(kw_ls) == 0:
                    new_batched.append([dummy_str])
                else:
                    new_batched.append(list(kw_ls))
            return new_batched

        batched_names = fill_empty(batched_names)
        batched_unary_kws = fill_empty(batched_unary_kws)
        batched_binary_kws = fill_empty(batched_binary_kws)

        dummy_prob = torch.tensor(0.0, device=self._device)
        batched_obj_name_features = []
        batched_unary_nl_features = []
        batched_binary_nl_features = []
        batched_object_ids_lookup: Dict[int, List[Tuple[int, int]]] = {0: []}
        batch_size = len(batched_video_ids)

        # Step 1: text features
        for object_names, unary_kws, binary_kws in zip(
            batched_names, batched_unary_kws, batched_binary_kws
        ):
            if len(object_names) == 0:
                batched_obj_name_features.append([])
            else:
                obj_tokens = self.clip_tokenizer(
                    object_names,
                    return_tensors="pt",
                    max_length=75,
                    truncation=True,
                    padding="max_length",
                ).to(self._device)
                obj_feats = self._text_features_checkpoint(
                    self.clip_cate_model, obj_tokens
                )
                batched_obj_name_features.append(obj_feats)

            if len(unary_kws) == 0:
                batched_unary_nl_features.append([])
            else:
                unary_tokens = self.clip_tokenizer(
                    list(unary_kws),
                    return_tensors="pt",
                    max_length=75,
                    truncation=True,
                    padding="max_length",
                ).to(self._device)
                unary_feats = self._text_features_checkpoint(
                    self.clip_unary_model, unary_tokens
                )
                batched_unary_nl_features.append(unary_feats)

            if len(binary_kws) == 0:
                batched_binary_nl_features.append([])
            else:
                binary_tokens = self.clip_tokenizer(
                    list(binary_kws),
                    return_tensors="pt",
                    max_length=75,
                    truncation=True,
                    padding="max_length",
                ).to(self._device)
                binary_feats = self._text_features_checkpoint(
                    self.clip_binary_model, binary_tokens
                )
                batched_binary_nl_features.append(binary_feats)

        # Step 2: crop objects
        batched_frame_masks: Dict[Tuple[int, int, int], np.ndarray] = {}
        batched_frame_bboxes: Dict[Tuple[int, int, int], List[float]] = {}
        batched_cropped_objs: Dict[int, List[np.ndarray]] = {
            vid: [] for vid in range(batch_size)
        }
        assert len(batched_object_ids) > 0, f"No object bbox: {batched_video_ids}"

        batched_video_splits = [0] + batched_video_splits
        for (video_id, frame_id, obj_id), mask, bbox in zip(
            batched_object_ids, batched_masks, batched_bboxes
        ):
            overall_frame_id = batched_video_splits[video_id] + frame_id
            object_img = extract_single_object(
                batched_videos[overall_frame_id], mask, white_alpha
            )
            cropped_object_img = crop_image_contain_bboxes(
                object_img, [bbox], batched_video_ids
            )
            if self.visualization_dir:
                debug_crop_dir = os.path.join(self.visualization_dir, "debug_crops")
                os.makedirs(debug_crop_dir, exist_ok=True)
                cv2.imwrite(
                    os.path.join(debug_crop_dir, f"frame_{frame_id}_obj_{obj_id}.jpg"),
                    cv2.cvtColor(cropped_object_img, cv2.COLOR_RGB2BGR),
                )
            batched_frame_masks[(video_id, frame_id, obj_id)] = mask
            batched_frame_bboxes[(video_id, frame_id, obj_id)] = bbox
            batched_object_ids_lookup[video_id].append((frame_id, obj_id))
            batched_cropped_objs[video_id].append(cropped_object_img)

        # Step 3: categorical + unary
        batched_image_unary_probs: Dict[int, Dict] = {}
        batched_image_cate_probs: Dict[int, Dict] = {}
        batched_obj_cate_features: Dict[int, Any] = {}
        batched_obj_unary_features: Dict[int, Any] = {}
        batched_obj_per_cate: Dict[int, Dict[str, List[Tuple[torch.Tensor, int]]]] = {}
        for vid in range(batch_size):
            batched_image_unary_probs[vid] = {}
            batched_image_cate_probs[vid] = {}
            batched_obj_cate_features[vid] = {}
            batched_obj_unary_features[vid] = {}
            batched_obj_per_cate[vid] = {}

        for vid_id, (
            unary_nl_feats,
            object_name_feats,
            cate,
            unary_pred,
            binary_predicates,
        ) in enumerate(
            zip(
                batched_unary_nl_features,
                batched_obj_name_features,
                batched_names,
                batched_unary_kws,
                batched_binary_predicates,
            )
        ):
            cropped_objs = batched_cropped_objs[vid_id]
            if len(cropped_objs) != 0:
                inputs = self.clip_processor(
                    images=cropped_objs, return_tensors="pt"
                ).to(self._device)
                cate_obj_clip_features = self._image_features_checkpoint(
                    self.clip_cate_model, inputs["pixel_values"]
                )
                unary_obj_clip_features = self._image_features_checkpoint(
                    self.clip_unary_model, inputs["pixel_values"]
                )
                batched_obj_unary_features[vid_id] = unary_obj_clip_features
                batched_obj_cate_features[vid_id] = cate_obj_clip_features
            else:
                batched_obj_cate_features[vid_id] = torch.tensor([])
                batched_obj_unary_features[vid_id] = torch.tensor([])

            object_ids = batched_object_ids_lookup[vid_id]

            # Categorical logits
            if (
                len(object_name_feats) == 0
                or len(object_ids) == 0
                or len(cropped_objs) == 0
            ):
                cate_logits_per_text = torch.tensor([])
            else:
                cate_logits_per_text = self.clip_sim(
                    self.clip_cate_model, object_name_feats, cate_obj_clip_features
                )
                if not output_logit:
                    cate_logits_per_text = cate_logits_per_text.softmax(dim=0)

            if not (
                len(object_ids) == 0
                or (
                    cate_logits_per_text.ndim == 2
                    and cate_logits_per_text.shape[1] == len(object_ids)
                )
                or len(object_name_feats) == 0
            ):
print("Object cate shape mismatch here") assert ( len(object_name_feats) == 0 or len(object_ids) == 0 or ( cate_logits_per_text.ndim == 2 and cate_logits_per_text.shape[1] == len(object_ids) ) ), f"Mismatched object id and cate logic: {batched_video_ids}" # Aggregate per object id across frames cate_prob_per_obj: Dict[int, Dict[str, List[torch.Tensor]]] = {} for cate_name, probs in zip(cate, cate_logits_per_text): if cate_name == dummy_str: dummy_prob += probs.sum() else: for prob, (fid, oid) in zip(probs, object_ids): cate_prob_per_obj.setdefault(oid, {}) cate_prob_per_obj[oid].setdefault(cate_name, []).append(prob) new_cate_prob_per_obj: Dict[Tuple[int, str], torch.Tensor] = {} obj_per_cate: Dict[str, List[Tuple[torch.Tensor, int]]] = {} for oid, object_cate_info in cate_prob_per_obj.items(): # Pool across frames per category pooled: Dict[str, torch.Tensor] = {} for cate_name, prob_list in object_cate_info.items(): stacked = torch.stack(prob_list) if getattr(self.config, "categorical_pool", "mean") == "mean": pooled_prob = stacked.mean() else: pooled_prob = stacked.max() pooled[cate_name] = pooled_prob if not pooled: continue # Renormalize across categories so they sum to 1 per object probs_tensor = torch.stack(list(pooled.values())) denom = probs_tensor.sum() if denom.item() <= 0: norm_tensor = torch.ones_like(probs_tensor) / len(pooled) else: norm_tensor = probs_tensor / denom for (cate_name, _), norm_prob in zip(pooled.items(), norm_tensor): obj_per_cate.setdefault(cate_name, []).append((norm_prob, oid)) new_cate_prob_per_obj[(oid, cate_name)] = norm_prob for cate_name in obj_per_cate: obj_per_cate[cate_name] = sorted( obj_per_cate[cate_name], key=lambda x: x[0], reverse=True ) # Unary if len(unary_nl_feats) == 0 or len(cropped_objs) == 0: unary_logits_per_text = torch.tensor([]) else: unary_logits_per_text = self.clip_sim( self.clip_unary_model, unary_nl_feats, unary_obj_clip_features ) if not output_logit: unary_logits_per_text = unary_logits_per_text.softmax(dim=0) unary_prob_per_obj: Dict[Tuple[int, int, str], torch.Tensor] = {} for unary_name, probs in zip(unary_pred, unary_logits_per_text): if unary_name == dummy_str: dummy_prob += probs.sum() else: for prob, (fid, oid) in zip(probs, object_ids): unary_prob_per_obj[(fid, oid, unary_name)] = prob batched_image_cate_probs[vid_id] = new_cate_prob_per_obj batched_image_unary_probs[vid_id] = unary_prob_per_obj batched_obj_per_cate[vid_id] = obj_per_cate # Step 4: binary pairs batched_cropped_obj_pairs: Dict[int, List[np.ndarray]] = {} frame_splits: Dict[Tuple[int, int], Dict[str, int]] = {} current_info = (0, 0) frame_splits[current_info] = {"start": 0} batched_topk_cate_candidates: Dict[int, Dict[str, List[int]]] = { video_id: {} for video_id in range(batch_size) } for video_id, obj_per_cate in batched_obj_per_cate.items(): topk_cate_candidates: Dict[str, List[int]] = {} for cate_name, pred_oid_ls in obj_per_cate.items(): for _, oid in pred_oid_ls[:topk_cate]: topk_cate_candidates.setdefault(cate_name, []).append(oid) batched_topk_cate_candidates[video_id] = topk_cate_candidates obj_pair_lookup: Dict[int, Dict[Tuple[int, int], List[int]]] = { video_id: {} for video_id in range(len(batched_video_ids)) } for (vid, fid, (from_oid, to_oid)) in batched_obj_pairs: if (from_oid, to_oid) not in obj_pair_lookup[vid]: obj_pair_lookup[vid][(from_oid, to_oid)] = [] obj_pair_lookup[vid][(from_oid, to_oid)].append(fid) selected_pairs = set() if batched_binary_predicates[0] is None: selected_pairs = set(batched_obj_pairs) else: for bp_vid, 
            for bp_vid, binary_predicates in enumerate(batched_binary_predicates):
                topk_cate_candidates = batched_topk_cate_candidates[bp_vid]
                for (rel_name, from_obj_name, to_obj_name) in binary_predicates:
                    if (
                        from_obj_name in topk_cate_candidates
                        and to_obj_name in topk_cate_candidates
                    ):
                        from_oids = topk_cate_candidates[from_obj_name]
                        to_oids = topk_cate_candidates[to_obj_name]
                        for from_oid in from_oids:
                            for to_oid in to_oids:
                                if (
                                    bp_vid in obj_pair_lookup
                                    and (from_oid, to_oid) in obj_pair_lookup[bp_vid]
                                ):
                                    for fid in obj_pair_lookup[bp_vid][
                                        (from_oid, to_oid)
                                    ]:
                                        selected_pairs.add(
                                            (bp_vid, fid, (from_oid, to_oid))
                                        )

        selected_pairs = list(selected_pairs)
        new_select_pairs: Dict[int, List[Tuple[int, int, Tuple[int, int]]]] = {
            video_id: [] for video_id in range(len(batched_video_ids))
        }
        for (vid, fid, (from_oid, to_oid)) in selected_pairs:
            new_select_pairs[vid].append((vid, fid, (from_oid, to_oid)))

        for vid in range(len(batched_video_ids)):
            batched_cropped_obj_pairs[vid] = []

        for (vid, fid, (from_id, to_id)) in selected_pairs:
            if (vid, fid, from_id) not in batched_frame_masks or (
                vid,
                fid,
                to_id,
            ) not in batched_frame_masks:
                continue
            if (vid, fid, from_id) not in batched_frame_bboxes or (
                vid,
                fid,
                to_id,
            ) not in batched_frame_bboxes:
                continue
            overall_frame_id = batched_video_splits[vid] + fid
            mask1 = batched_frame_masks[(vid, fid, from_id)]
            mask2 = batched_frame_masks[(vid, fid, to_id)]
            bbox1 = batched_frame_bboxes[(vid, fid, from_id)]
            bbox2 = batched_frame_bboxes[(vid, fid, to_id)]
            bb_pop_image = extract_object_subject(
                batched_videos[overall_frame_id],
                mask1,
                mask2,
                alpha=alpha,
                white_alpha=white_alpha,
            )
            cropped_bb_pop_image = crop_image_contain_bboxes(
                img=bb_pop_image,
                bbox_ls=[bbox1, bbox2],
                data_id=batched_video_ids,
            )
            batched_cropped_obj_pairs[vid].append(cropped_bb_pop_image)

        # No valid pairs survived filtering: insert a placeholder pair and image
        # so the binary branch still has an input; its probability is routed to
        # dummy_prob below.
        if len(selected_pairs) == 0:
            selected_pairs.append((0, -1, (-1, -1)))
            new_select_pairs[0] = [(0, -1, (-1, -1))]
            dummy_img = batched_videos[0]
            batched_cropped_obj_pairs[0] = [dummy_img]

        batched_image_binary_probs: List[
            Dict[Tuple[int, Tuple[int, int], str], torch.Tensor]
        ] = []
        batched_obj_pair_features: Dict[int, torch.Tensor] = {
            vid: torch.tensor([]) for vid in range(batch_size)
        }
        if len(batched_cropped_obj_pairs) == 0:
            batched_image_binary_probs.append({})
        else:
            for vid, binary_nl_features in enumerate(batched_binary_nl_features):
                if len(binary_nl_features) == 0:
                    batched_image_binary_probs.append({})
                    continue
                binary_kws = batched_binary_kws[vid]
                cropped_obj_pairs = batched_cropped_obj_pairs[vid]
                if len(cropped_obj_pairs) == 0:
                    batched_image_binary_probs.append({})
                    continue
                inputs = self.clip_processor(
                    images=cropped_obj_pairs, return_tensors="pt"
                ).to(self._device)
                obj_features = self._image_features_checkpoint(
                    self.clip_binary_model, inputs["pixel_values"]
                )
                batched_obj_pair_features[vid] = obj_features
                obj_clip_features = obj_features / obj_features.norm(
                    p=2, dim=-1, keepdim=True
                )
                binary_nl_features = binary_nl_features / binary_nl_features.norm(
                    p=2, dim=-1, keepdim=True
                )
                logit_scale = self.clip_binary_model.logit_scale
                binary_logits_per_text = (
                    torch.matmul(binary_nl_features, obj_clip_features.t())
                    * logit_scale.exp()
                )
                if not output_logit:
                    if not multi_class:
                        binary_logits_per_text = binary_logits_per_text.softmax(dim=0)
                    else:
                        binary_logits_per_text = binary_logits_per_text.sigmoid()

                binary_prob_per_obj: Dict[
                    Tuple[int, Tuple[int, int], str], torch.Tensor
                ] = {}
                for binary_name, probs in zip(binary_kws, binary_logits_per_text):
                    if binary_name == dummy_str:
                        dummy_prob += probs.sum()
                    else:
                        for prob, (vid_, fid, obj_pair) in zip(
                            probs, new_select_pairs[vid]
                        ):
                            if fid == -1:
                                dummy_prob += prob
                            else:
                                binary_prob_per_obj[
                                    (fid, obj_pair, binary_name)
                                ] = prob

                batched_image_binary_probs.append(binary_prob_per_obj)

        result: Dict[str, Any] = {
            "categorical_probs": batched_image_cate_probs,
            "unary_probs": batched_image_unary_probs,
            "binary_probs": batched_image_binary_probs,
            "dummy_prob": dummy_prob,
        }

        if output_embeddings:
            embeddings_dict = {
                "cate_obj_clip_features": batched_obj_cate_features,
                "cate_object_ids": batched_object_ids_lookup,
                "unary_obj_clip_features": batched_obj_unary_features,
                "unary_object_ids": batched_object_ids_lookup,
                "binary_obj_pair_features": batched_obj_pair_features,
                "binary_object_pairs": new_select_pairs,
            }
            result["embeddings"] = embeddings_dict

        if return_flattened_segments or return_valid_pairs:
            flattened = flatten_segments_for_batch(
                video_id=0,
                segments=masks,
                bbox_min_dim=self.config.bbox_min_dim,
            )
            if return_flattened_segments:
                result["flattened_segments"] = flattened
            if return_valid_pairs:
                interested_pairs = (
                    interested_object_pairs if interested_object_pairs else None
                )
                result["valid_pairs"] = extract_valid_object_pairs(
                    flattened["object_ids"],
                    interested_pairs,
                )
                if interested_pairs is None:
                    result["valid_pairs_metadata"] = {"pair_source": "all_pairs"}
                else:
                    result["valid_pairs_metadata"] = {
                        "pair_source": "filtered",
                        "requested_pairs": interested_object_pairs,
                    }

        return result

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        if torch.is_tensor(frame):
            frame_np = frame.detach().cpu().numpy()
        else:
            frame_np = np.asarray(frame)
        return np.ascontiguousarray(frame_np)

    def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        if torch.is_tensor(mask):
            mask_np = mask.detach().cpu().numpy()
        else:
            mask_np = np.asarray(mask)
        if mask_np.ndim == 3:
            if mask_np.shape[0] == 1:
                mask_np = mask_np.squeeze(0)
            elif mask_np.shape[2] == 1:
                mask_np = mask_np.squeeze(2)
        if mask_np.ndim != 2:
            raise ValueError(
                f"Mask must be 2D after squeezing, got shape {mask_np.shape}"
            )
        return mask_np.astype(bool, copy=False)

    def _extract_text_features(self, model, keywords: List[str]):
        tokens = self.clip_tokenizer(
            keywords,
            return_tensors="pt",
            max_length=75,
            truncation=True,
            padding="max_length",
        ).to(self._device)
        return self._text_features_checkpoint(model, tokens)

    def _extract_image_features(self, model, image):
        if torch.is_tensor(image):
            image = image.detach().cpu().numpy()
        inputs = self.clip_processor(images=image, return_tensors="pt").to(self._device)
        return self._image_features_checkpoint(model, inputs["pixel_values"])

    # ------------------------------------------------------------------ #
    # High-level predict API
    # ------------------------------------------------------------------ #
    def predict(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_top_k: int = 3,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
    ) -> Dict[str, Any]:
        with torch.no_grad():
            outputs = self.forward(
                video_frames=video_frames,
                masks=masks,
                bboxes=bboxes,
                categorical_keywords=categorical_keywords,
                unary_keywords=unary_keywords,
                binary_keywords=binary_keywords,
                object_pairs=object_pairs,
                return_flattened_segments=return_flattened_segments,
                return_valid_pairs=return_valid_pairs,
                interested_object_pairs=interested_object_pairs,
                debug_visualizations=debug_visualizations,
            )

        formatted_categorical: Dict[int, List[Tuple[float, str]]] = {}
        for (obj_id, category), prob in outputs["categorical_probs"][0].items():
            if obj_id not in formatted_categorical:
                formatted_categorical[obj_id] = []
            prob_val = (
                float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
            )
            formatted_categorical[obj_id].append((prob_val, category))
        for obj_id in formatted_categorical:
            formatted_categorical[obj_id] = sorted(
                formatted_categorical[obj_id], reverse=True
            )[:return_top_k]

        formatted_unary: Dict[Tuple[int, int], List[Tuple[float, str]]] = {}
        for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items():
            key = (frame_id, obj_id)
            if key not in formatted_unary:
                formatted_unary[key] = []
            prob_val = (
                float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
            )
            formatted_unary[key].append((prob_val, predicate))
        for key in formatted_unary:
            formatted_unary[key] = sorted(
                formatted_unary[key], reverse=True
            )[:return_top_k]

        formatted_binary: Dict[Tuple[int, Tuple[int, int]], List[Tuple[float, str]]] = {}
        if len(outputs["binary_probs"]) > 0:
            for (frame_id, obj_pair, predicate), prob in outputs["binary_probs"][0].items():
                key = (frame_id, obj_pair)
                if key not in formatted_binary:
                    formatted_binary[key] = []
                prob_val = (
                    float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
                )
                formatted_binary[key].append((prob_val, predicate))
            for key in formatted_binary:
                formatted_binary[key] = sorted(
                    formatted_binary[key], reverse=True
                )[:return_top_k]

        def max_conf(d: Dict[Any, List[Tuple[float, str]]]) -> float:
            if not d:
                return 0.0
            return max(
                (max((p for p, _ in preds), default=0.0) for preds in d.values()),
                default=0.0,
            )

        result: Dict[str, Any] = {
            "categorical_predictions": formatted_categorical,
            "unary_predictions": formatted_unary,
            "binary_predictions": formatted_binary,
            "confidence_scores": {
                "categorical": max_conf(formatted_categorical),
                "unary": max_conf(formatted_unary),
                "binary": max_conf(formatted_binary),
            },
        }
        if "flattened_segments" in outputs:
            result["flattened_segments"] = outputs["flattened_segments"]
        if "valid_pairs" in outputs:
            result["valid_pairs"] = outputs["valid_pairs"]
        if "valid_pairs_metadata" in outputs:
            result["valid_pairs_metadata"] = outputs["valid_pairs_metadata"]
        return result
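

# ---------------------------------------------------------------------- #
# Usage sketch (illustrative only).
#
# A minimal, hedged example of how predict() is meant to be driven. It
# assumes VineConfig accepts the keyword arguments shown below
# (model_name, use_hf_repo, _device) and defaults the local checkpoint
# paths to None, so the base CLIP checkpoint named here is used without
# fine-tuned weights. The frames, masks, and boxes are synthetic
# placeholders, not a reference to any real checkpoint or dataset; the
# masks/bboxes follow the {frame_id: {object_id: ...}} layout that
# forward()/predict() expect.
# ---------------------------------------------------------------------- #
if __name__ == "__main__":
    # Hypothetical config values; adjust model_name and _device to your setup.
    example_config = VineConfig(
        model_name="openai/clip-vit-base-patch32",
        use_hf_repo=False,
        _device="cuda" if torch.cuda.is_available() else "cpu",
    )
    example_model = VineModel(example_config)

    # Two synthetic 224x224 RGB frames (uint8, HWC).
    frames = torch.randint(0, 256, (2, 224, 224, 3), dtype=torch.uint8)

    # One object (id 0) visible in both frames, with a square mask and its
    # [x1, y1, x2, y2] bounding box.
    square_mask = np.zeros((224, 224), dtype=bool)
    square_mask[50:150, 60:160] = True
    example_masks = {0: {0: square_mask}, 1: {0: square_mask}}
    example_bboxes = {0: {0: [60, 50, 160, 150]}, 1: {0: [60, 50, 160, 150]}}

    preds = example_model.predict(
        video_frames=frames,
        masks=example_masks,
        bboxes=example_bboxes,
        categorical_keywords=["person", "dog", "car"],
        unary_keywords=["running", "sitting"],
    )
    print(preds["categorical_predictions"])
    print(preds["confidence_scores"])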