# LASER / vine_hf / vine_model.py
import os
import sys
from typing import Dict, List, Tuple, Optional, Any, Union
import cv2
import numpy as np
import torch
from safetensors.torch import load_file
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor
from huggingface_hub import snapshot_download
from .vine_config import VineConfig
from laser.models import llava_clip_model_v3
sys.modules["llava_clip_model_v3"] = llava_clip_model_v3
from laser.models.model_utils import (
extract_single_object,
extract_object_subject,
crop_image_contain_bboxes,
segment_list,
)
from .flattening import (
extract_valid_object_pairs,
flatten_segments_for_batch,
)
from .vis_utils import save_mask_one_image
class VineModel(PreTrainedModel):
"""
    VINE (Video Understanding with Natural Language) model.

    Internally, the core CLIP text/image/pair logic mirrors
    ``llava_clip_model_v3.PredicateModel`` as closely as possible for a single
    video, with a small extension that re-normalizes categorical probabilities
    after pooling across frames.
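
    Illustrative usage (a minimal sketch; the frame/mask shapes, keyword lists,
    and default-constructed ``VineConfig`` are assumptions -- see ``predict``
    below for the full signature)::

        config = VineConfig()              # assumes the config supplies _device and weight paths
        model = VineModel(config)
        outputs = model.predict(
            video_frames=frames,                 # (T, H, W, 3) RGB frames
            masks={0: {1: mask}},                # frame_id -> object_id -> 2D mask
            bboxes={0: {1: [x1, y1, x2, y2]}},   # frame_id -> object_id -> bbox
            categorical_keywords=["person", "dog"],
            unary_keywords=["running"],
            binary_keywords=["next to"],
            object_pairs=[(1, 2)],
        )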
"""
config_class = VineConfig
def __init__(self, config: VineConfig):
super().__init__(config)
self.config = config
self.visualize = getattr(config, "visualize", False)
self.visualization_dir = getattr(config, "visualization_dir", None)
self.debug_visualizations = getattr(config, "debug_visualizations", False)
self._device = getattr(config, "_device")
# CLIP components
self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
if self.clip_tokenizer.pad_token is None:
self.clip_tokenizer.pad_token = (
self.clip_tokenizer.unk_token
if self.clip_tokenizer.unk_token
else self.clip_tokenizer.eos_token
)
self.clip_processor = AutoProcessor.from_pretrained(config.model_name)
self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
self.clip_unary_model = AutoModel.from_pretrained(config.model_name)
self.clip_binary_model = AutoModel.from_pretrained(config.model_name)
# Load fine-tuned weights if available
if config.use_hf_repo:
self._load_huggingface_vine_weights(config.model_repo, config.model_file)
else:
self._load_local_pretrained_vine_weights(
config.local_dir, config.local_filename
)
# Optionally reset categorical model to base CLIP (ignore fine-tune)
if not getattr(config, "use_pretrained_cate_weights", True):
self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
self.clip_cate_model.to(self._device)
self.to(self._device)
# ------------------------------------------------------------------ #
# Weight loading
# ------------------------------------------------------------------ #
def _load_huggingface_vine_weights(
self, model_repo: str, model_file: Optional[str] = None
):
try:
print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
repo_path = snapshot_download(model_repo, revision=model_file or "main")
weights = load_file(os.path.join(repo_path, "model.safetensors"))
self.load_state_dict(weights, strict=False)
print("✓ Successfully loaded VINE weights from HuggingFace Hub")
return True
except Exception as e:
print(f"✗ Error loading VINE weights from HuggingFace Hub: {e}")
print("Using base CLIP models instead")
return False
def _load_local_pretrained_vine_weights(
self, local_dir: str, local_filename: Optional[str] = None, epoch: int = 0
):
if local_dir is None and local_filename is None:
return False
full_path = (
os.path.join(local_dir, local_filename) if local_filename else local_dir
)
# .pkl – usually pickled PredicateModel
if isinstance(full_path, str) and full_path.endswith(".pkl"):
print(f"Loading VINE weights from: {full_path}")
loaded_vine_model = torch.load(
full_path, map_location=self._device, weights_only=False
)
print(f"Loaded state type: {type(loaded_vine_model)}")
if not isinstance(loaded_vine_model, dict):
if hasattr(loaded_vine_model, "clip_tokenizer"):
self.clip_tokenizer = loaded_vine_model.clip_tokenizer
if hasattr(loaded_vine_model, "clip_processor"):
self.clip_processor = loaded_vine_model.clip_processor
if hasattr(loaded_vine_model, "clip_cate_model"):
self.clip_cate_model.load_state_dict(
loaded_vine_model.clip_cate_model.state_dict()
)
if hasattr(loaded_vine_model, "clip_unary_model"):
self.clip_unary_model.load_state_dict(
loaded_vine_model.clip_unary_model.state_dict()
)
if hasattr(loaded_vine_model, "clip_binary_model"):
self.clip_binary_model.load_state_dict(
loaded_vine_model.clip_binary_model.state_dict()
)
print("✓ Loaded VINE weights from .pkl PredicateModel checkpoint")
return True
# .pt / .pth – plain state_dict
elif isinstance(full_path, str) and (
full_path.endswith(".pt") or full_path.endswith(".pth")
):
print(f"Loading VINE weights from: {full_path}")
state = torch.load(full_path, map_location=self._device, weights_only=True)
print(f"Loaded state type: {type(state)}")
self.load_state_dict(state, strict=False)
print("✓ Loaded VINE weights from state_dict")
return True
# .model – full PredicateModel object
elif isinstance(full_path, str) and full_path.endswith(".model"):
print(f"Loading VINE weights from: {full_path}")
pretrained_model = torch.load(
full_path, map_location="cpu", weights_only=False
)
if hasattr(pretrained_model, "clip_tokenizer"):
self.clip_tokenizer = pretrained_model.clip_tokenizer
if hasattr(pretrained_model, "clip_processor"):
self.clip_processor = pretrained_model.clip_processor
if hasattr(pretrained_model, "clip_cate_model"):
self.clip_cate_model.load_state_dict(
pretrained_model.clip_cate_model.state_dict()
)
if hasattr(pretrained_model, "clip_unary_model"):
self.clip_unary_model.load_state_dict(
pretrained_model.clip_unary_model.state_dict()
)
if hasattr(pretrained_model, "clip_binary_model"):
self.clip_binary_model.load_state_dict(
pretrained_model.clip_binary_model.state_dict()
)
print("✓ Loaded all sub-model weights from .model file")
return True
# directory of .model files
if isinstance(full_path, str) and os.path.isdir(full_path):
model_files = [
f for f in os.listdir(full_path) if f.endswith(f".{epoch}.model")
]
if model_files:
model_file = os.path.join(full_path, model_files[0])
print(f"Loading VINE weights from: {model_file}")
pretrained_model = torch.load(model_file, map_location="cpu")
if hasattr(pretrained_model, "clip_tokenizer"):
self.clip_tokenizer = pretrained_model.clip_tokenizer
if hasattr(pretrained_model, "clip_processor"):
self.clip_processor = pretrained_model.clip_processor
if hasattr(pretrained_model, "clip_cate_model"):
self.clip_cate_model.load_state_dict(
pretrained_model.clip_cate_model.state_dict()
)
if hasattr(pretrained_model, "clip_unary_model"):
self.clip_unary_model.load_state_dict(
pretrained_model.clip_unary_model.state_dict()
)
if hasattr(pretrained_model, "clip_binary_model"):
self.clip_binary_model.load_state_dict(
pretrained_model.clip_binary_model.state_dict()
)
print("✓ Loaded all sub-model weights from ensemble format")
return True
else:
print(f"No model file found for epoch {epoch} in {full_path}")
return False
print("Unsupported format for pretrained VINE path:", full_path)
return False
@classmethod
def from_pretrained_vine(
cls,
model_path: str,
config: Optional[VineConfig] = None,
epoch: int = 0,
**kwargs: Any,
):
if config is None:
if model_path and ("/" in model_path and not os.path.exists(model_path)):
config = VineConfig(use_hf_repo=True, model_repo=model_path)
else:
if os.path.isdir(model_path):
config = VineConfig(use_hf_repo=False, local_dir=model_path)
else:
config = VineConfig(
use_hf_repo=False,
local_dir=os.path.dirname(model_path) or None,
local_filename=os.path.basename(model_path) or None,
)
else:
if model_path and ("/" in model_path and not os.path.exists(model_path)):
config.use_hf_repo = True
config.model_repo = model_path
config.model_file = None
config.local_dir = None
config.local_filename = None
else:
config.use_hf_repo = False
if os.path.isdir(model_path):
config.local_dir = model_path
config.local_filename = None
else:
config.local_dir = os.path.dirname(model_path) or None
config.local_filename = os.path.basename(model_path) or None
model = cls(config, **kwargs)
return model
# ------------------------------------------------------------------ #
# Gradient checkpoint helpers
# ------------------------------------------------------------------ #
def _text_features_checkpoint(self, model, token_dict):
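        """Compute text features with gradient checkpointing to trade compute for activation memory."""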
input_ids = token_dict["input_ids"]
attention_mask = token_dict["attention_mask"]
token_type_ids = token_dict.get("token_type_ids", None)
if token_type_ids is not None:
def forward_pass(input_ids, attention_mask, token_type_ids):
return model.get_text_features(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
return cp.checkpoint(
forward_pass,
input_ids,
attention_mask,
token_type_ids,
use_reentrant=False,
)
else:
def forward_pass(input_ids, attention_mask):
return model.get_text_features(
input_ids=input_ids,
attention_mask=attention_mask,
)
return cp.checkpoint(
forward_pass, input_ids, attention_mask, use_reentrant=False
)
def _image_features_checkpoint(self, model, pixel_values):
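        """Compute image features with gradient checkpointing (activations are recomputed during backward)."""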
def forward_pass(pixel_values):
return model.get_image_features(pixel_values=pixel_values)
return cp.checkpoint(forward_pass, pixel_values, use_reentrant=False)
# ------------------------------------------------------------------ #
# CLIP similarity
# ------------------------------------------------------------------ #
def clip_sim(self, model, nl_feat, img_feat):
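        """L2-normalize both feature sets and return a (num_texts, num_images)
        similarity matrix, scaled by exp(logit_scale) when the model exposes one."""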
img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
nl_feat = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
logit_scale = getattr(model, "logit_scale", None)
logits_per_text = torch.matmul(nl_feat, img_feat.t())
if logit_scale is not None:
logits_per_text = logits_per_text * logit_scale.exp()
return logits_per_text
# ------------------------------------------------------------------ #
# Forward: single-video PredicateModel-style logic
# ------------------------------------------------------------------ #
def forward(
self,
video_frames: torch.Tensor,
masks: Dict[int, Dict[int, torch.Tensor]],
bboxes: Dict[int, Dict[int, List]],
categorical_keywords: List[str],
unary_keywords: Optional[List[str]] = None,
binary_keywords: Optional[List[str]] = None,
object_pairs: Optional[List[Tuple[int, int]]] = None,
return_flattened_segments: Optional[bool] = None,
return_valid_pairs: Optional[bool] = None,
interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
debug_visualizations: Optional[bool] = None,
**kwargs: Any,
) -> Dict[str, Any]:
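        """
        Run the single-video PredicateModel-style pipeline.

        ``masks`` and ``bboxes`` map frame_id -> object_id -> mask / bbox, and
        the keyword lists define the categorical, unary, and binary
        vocabularies scored against CLIP.

        Returns a dict with ``categorical_probs``, ``unary_probs``,
        ``binary_probs``, and ``dummy_prob``, plus ``embeddings``,
        ``flattened_segments``, and ``valid_pairs`` when requested.
        """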
if unary_keywords is None:
unary_keywords = []
if binary_keywords is None:
binary_keywords = []
if object_pairs is None:
object_pairs = []
if return_flattened_segments is None:
return_flattened_segments = getattr(
self.config, "return_flattened_segments", False
)
if return_valid_pairs is None:
return_valid_pairs = getattr(self.config, "return_valid_pairs", False)
if interested_object_pairs is None or len(interested_object_pairs) == 0:
interested_object_pairs = (
getattr(self.config, "interested_object_pairs", []) or []
)
if debug_visualizations is None:
debug_visualizations = self.debug_visualizations
alpha = getattr(self.config, "alpha", 0.5)
white_alpha = getattr(self.config, "white_alpha", 0.8)
topk_cate = kwargs.pop("topk_cate", getattr(self.config, "topk_cate", 3))
dummy_str = kwargs.pop("dummy_str", getattr(self.config, "dummy_str", "$$$"))
multi_class = kwargs.pop("multi_class", getattr(self.config, "multi_class", False))
output_logit = kwargs.pop("output_logit", getattr(self.config, "output_logit", False))
output_embeddings = kwargs.pop("output_embeddings", False)
batched_video_ids = [0]
if torch.is_tensor(video_frames):
num_frames = video_frames.shape[0]
batched_videos = [
self._frame_to_numpy(video_frames[fid]) for fid in range(num_frames)
]
else:
num_frames = len(video_frames)
batched_videos = [
self._frame_to_numpy(video_frames[fid]) for fid in range(num_frames)
]
batched_masks: List[np.ndarray] = []
batched_bboxes: List[List[float]] = []
batched_object_ids: List[Tuple[int, int, int]] = []
for frame_id, frame_masks in masks.items():
if frame_id >= num_frames:
continue
frame_boxes = bboxes.get(frame_id, {})
for obj_id, mask in frame_masks.items():
if obj_id not in frame_boxes:
continue
bbox = frame_boxes[obj_id]
batched_object_ids.append((0, frame_id, obj_id))
batched_masks.append(self._mask_to_numpy(mask))
batched_bboxes.append(bbox)
batched_names = [list(categorical_keywords)]
batched_unary_kws = [list(unary_keywords)]
batched_binary_kws = [list(binary_keywords)]
batched_obj_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
if object_pairs:
for frame_id, frame_masks in masks.items():
if frame_id >= num_frames:
continue
present_ids = set(frame_masks.keys())
for (from_oid, to_oid) in object_pairs:
if from_oid in present_ids and to_oid in present_ids:
batched_obj_pairs.append((0, frame_id, (from_oid, to_oid)))
batched_video_splits = [0]
batched_binary_predicates = [None]
def fill_empty(batched_kw):
new_batched = []
for kw_ls in batched_kw:
if len(kw_ls) == 0:
new_batched.append([dummy_str])
else:
new_batched.append(list(kw_ls))
return new_batched
batched_names = fill_empty(batched_names)
batched_unary_kws = fill_empty(batched_unary_kws)
batched_binary_kws = fill_empty(batched_binary_kws)
dummy_prob = torch.tensor(0.0, device=self._device)
batched_obj_name_features = []
batched_unary_nl_features = []
batched_binary_nl_features = []
batched_object_ids_lookup: Dict[int, List[Tuple[int, int]]] = {0: []}
batch_size = len(batched_video_ids)
# Step 1: text features
for object_names, unary_kws, binary_kws in zip(
batched_names, batched_unary_kws, batched_binary_kws
):
if len(object_names) == 0:
batched_obj_name_features.append([])
else:
obj_tokens = self.clip_tokenizer(
object_names,
return_tensors="pt",
max_length=75,
truncation=True,
padding="max_length",
).to(self._device)
obj_feats = self._text_features_checkpoint(
self.clip_cate_model, obj_tokens
)
batched_obj_name_features.append(obj_feats)
if len(unary_kws) == 0:
batched_unary_nl_features.append([])
else:
unary_tokens = self.clip_tokenizer(
list(unary_kws),
return_tensors="pt",
max_length=75,
truncation=True,
padding="max_length",
).to(self._device)
unary_feats = self._text_features_checkpoint(
self.clip_unary_model, unary_tokens
)
batched_unary_nl_features.append(unary_feats)
if len(binary_kws) == 0:
batched_binary_nl_features.append([])
else:
binary_tokens = self.clip_tokenizer(
list(binary_kws),
return_tensors="pt",
max_length=75,
truncation=True,
padding="max_length",
).to(self._device)
binary_feats = self._text_features_checkpoint(
self.clip_binary_model, binary_tokens
)
batched_binary_nl_features.append(binary_feats)
# Step 2: crop objects
batched_frame_masks: Dict[Tuple[int, int, int], np.ndarray] = {}
batched_frame_bboxes: Dict[Tuple[int, int, int], List[float]] = {}
batched_cropped_objs: Dict[int, List[np.ndarray]] = {
vid: [] for vid in range(batch_size)
}
assert len(batched_object_ids) > 0, f"No object bbox: {batched_video_ids}"
batched_video_splits = [0] + batched_video_splits
for (video_id, frame_id, obj_id), mask, bbox in zip(
batched_object_ids, batched_masks, batched_bboxes
):
overall_frame_id = batched_video_splits[video_id] + frame_id
object_img = extract_single_object(
batched_videos[overall_frame_id], mask, white_alpha
)
cropped_object_img = crop_image_contain_bboxes(
object_img, [bbox], batched_video_ids
)
if self.visualization_dir:
debug_crop_dir = os.path.join(self.visualization_dir, "debug_crops")
os.makedirs(debug_crop_dir, exist_ok=True)
cv2.imwrite(
os.path.join(debug_crop_dir, f"frame_{frame_id}_obj_{obj_id}.jpg"),
cv2.cvtColor(cropped_object_img, cv2.COLOR_RGB2BGR),
)
batched_frame_masks[(video_id, frame_id, obj_id)] = mask
batched_frame_bboxes[(video_id, frame_id, obj_id)] = bbox
batched_object_ids_lookup[video_id].append((frame_id, obj_id))
batched_cropped_objs[video_id].append(cropped_object_img)
# Step 3: categorical + unary
batched_image_unary_probs: Dict[int, Dict] = {}
batched_image_cate_probs: Dict[int, Dict] = {}
batched_obj_cate_features: Dict[int, Any] = {}
batched_obj_unary_features: Dict[int, Any] = {}
batched_obj_per_cate: Dict[int, Dict[str, List[Tuple[torch.Tensor, int]]]] = {}
for vid in range(batch_size):
batched_image_unary_probs[vid] = {}
batched_image_cate_probs[vid] = {}
batched_obj_cate_features[vid] = {}
batched_obj_unary_features[vid] = {}
batched_obj_per_cate[vid] = {}
for vid_id, (
unary_nl_feats,
object_name_feats,
cate,
unary_pred,
binary_predicates,
) in enumerate(
zip(
batched_unary_nl_features,
batched_obj_name_features,
batched_names,
batched_unary_kws,
batched_binary_predicates,
)
):
cropped_objs = batched_cropped_objs[vid_id]
if len(cropped_objs) != 0:
inputs = self.clip_processor(
images=cropped_objs, return_tensors="pt"
).to(self._device)
cate_obj_clip_features = self._image_features_checkpoint(
self.clip_cate_model, inputs["pixel_values"]
)
unary_obj_clip_features = self._image_features_checkpoint(
self.clip_unary_model, inputs["pixel_values"]
)
batched_obj_unary_features[vid_id] = unary_obj_clip_features
batched_obj_cate_features[vid_id] = cate_obj_clip_features
else:
batched_obj_cate_features[vid_id] = torch.tensor([])
batched_obj_unary_features[vid_id] = torch.tensor([])
object_ids = batched_object_ids_lookup[vid_id]
# Categorical logits
if (
len(object_name_feats) == 0
or len(object_ids) == 0
or len(cropped_objs) == 0
):
cate_logits_per_text = torch.tensor([])
else:
cate_logits_per_text = self.clip_sim(
self.clip_cate_model, object_name_feats, cate_obj_clip_features
)
if not output_logit:
cate_logits_per_text = cate_logits_per_text.softmax(dim=0)
if not (
len(object_ids) == 0
or (
cate_logits_per_text.ndim == 2
and cate_logits_per_text.shape[1] == len(object_ids)
)
or len(object_name_feats) == 0
):
print("Object cate shape mismatch here")
assert (
len(object_name_feats) == 0
or len(object_ids) == 0
or (
cate_logits_per_text.ndim == 2
and cate_logits_per_text.shape[1] == len(object_ids)
)
), f"Mismatched object id and cate logic: {batched_video_ids}"
# Aggregate per object id across frames
cate_prob_per_obj: Dict[int, Dict[str, List[torch.Tensor]]] = {}
for cate_name, probs in zip(cate, cate_logits_per_text):
if cate_name == dummy_str:
dummy_prob += probs.sum()
else:
for prob, (fid, oid) in zip(probs, object_ids):
cate_prob_per_obj.setdefault(oid, {})
cate_prob_per_obj[oid].setdefault(cate_name, []).append(prob)
new_cate_prob_per_obj: Dict[Tuple[int, str], torch.Tensor] = {}
obj_per_cate: Dict[str, List[Tuple[torch.Tensor, int]]] = {}
for oid, object_cate_info in cate_prob_per_obj.items():
# Pool across frames per category
pooled: Dict[str, torch.Tensor] = {}
for cate_name, prob_list in object_cate_info.items():
stacked = torch.stack(prob_list)
if getattr(self.config, "categorical_pool", "mean") == "mean":
pooled_prob = stacked.mean()
else:
pooled_prob = stacked.max()
pooled[cate_name] = pooled_prob
if not pooled:
continue
# Renormalize across categories so they sum to 1 per object
probs_tensor = torch.stack(list(pooled.values()))
denom = probs_tensor.sum()
if denom.item() <= 0:
norm_tensor = torch.ones_like(probs_tensor) / len(pooled)
else:
norm_tensor = probs_tensor / denom
for (cate_name, _), norm_prob in zip(pooled.items(), norm_tensor):
obj_per_cate.setdefault(cate_name, []).append((norm_prob, oid))
new_cate_prob_per_obj[(oid, cate_name)] = norm_prob
for cate_name in obj_per_cate:
obj_per_cate[cate_name] = sorted(
obj_per_cate[cate_name], key=lambda x: x[0], reverse=True
)
# Unary
if len(unary_nl_feats) == 0 or len(cropped_objs) == 0:
unary_logits_per_text = torch.tensor([])
else:
unary_logits_per_text = self.clip_sim(
self.clip_unary_model, unary_nl_feats, unary_obj_clip_features
)
if not output_logit:
unary_logits_per_text = unary_logits_per_text.softmax(dim=0)
unary_prob_per_obj: Dict[Tuple[int, int, str], torch.Tensor] = {}
for unary_name, probs in zip(unary_pred, unary_logits_per_text):
if unary_name == dummy_str:
dummy_prob += probs.sum()
else:
for prob, (fid, oid) in zip(probs, object_ids):
unary_prob_per_obj[(fid, oid, unary_name)] = prob
batched_image_cate_probs[vid_id] = new_cate_prob_per_obj
batched_image_unary_probs[vid_id] = unary_prob_per_obj
batched_obj_per_cate[vid_id] = obj_per_cate
# Step 4: binary pairs
batched_cropped_obj_pairs: Dict[int, List[np.ndarray]] = {}
frame_splits: Dict[Tuple[int, int], Dict[str, int]] = {}
current_info = (0, 0)
frame_splits[current_info] = {"start": 0}
batched_topk_cate_candidates: Dict[int, Dict[str, List[int]]] = {
video_id: {} for video_id in range(batch_size)
}
for video_id, obj_per_cate in batched_obj_per_cate.items():
topk_cate_candidates: Dict[str, List[int]] = {}
for cate_name, pred_oid_ls in obj_per_cate.items():
for _, oid in pred_oid_ls[:topk_cate]:
topk_cate_candidates.setdefault(cate_name, []).append(oid)
batched_topk_cate_candidates[video_id] = topk_cate_candidates
obj_pair_lookup: Dict[int, Dict[Tuple[int, int], List[int]]] = {
video_id: {} for video_id in range(len(batched_video_ids))
}
for (vid, fid, (from_oid, to_oid)) in batched_obj_pairs:
if (from_oid, to_oid) not in obj_pair_lookup[vid]:
obj_pair_lookup[vid][(from_oid, to_oid)] = []
obj_pair_lookup[vid][(from_oid, to_oid)].append(fid)
selected_pairs = set()
if batched_binary_predicates[0] is None:
selected_pairs = set(batched_obj_pairs)
else:
for bp_vid, binary_predicates in enumerate(batched_binary_predicates):
topk_cate_candidates = batched_topk_cate_candidates[bp_vid]
for (rel_name, from_obj_name, to_obj_name) in binary_predicates:
if (
from_obj_name in topk_cate_candidates
and to_obj_name in topk_cate_candidates
):
from_oids = topk_cate_candidates[from_obj_name]
to_oids = topk_cate_candidates[to_obj_name]
for from_oid in from_oids:
for to_oid in to_oids:
if (
bp_vid in obj_pair_lookup
and (from_oid, to_oid) in obj_pair_lookup[bp_vid]
):
for fid in obj_pair_lookup[bp_vid][
(from_oid, to_oid)
]:
selected_pairs.add(
(bp_vid, fid, (from_oid, to_oid))
)
selected_pairs = list(selected_pairs)
new_select_pairs: Dict[int, List[Tuple[int, int, Tuple[int, int]]]] = {
video_id: [] for video_id in range(len(batched_video_ids))
}
for (vid, fid, (from_oid, to_oid)) in selected_pairs:
new_select_pairs[vid].append((vid, fid, (from_oid, to_oid)))
for vid in range(len(batched_video_ids)):
batched_cropped_obj_pairs[vid] = []
for (vid, fid, (from_id, to_id)) in selected_pairs:
if (vid, fid, from_id) not in batched_frame_masks or (
vid,
fid,
to_id,
) not in batched_frame_masks:
continue
if (vid, fid, from_id) not in batched_frame_bboxes or (
vid,
fid,
to_id,
) not in batched_frame_bboxes:
continue
overall_frame_id = batched_video_splits[vid] + fid
mask1 = batched_frame_masks[(vid, fid, from_id)]
mask2 = batched_frame_masks[(vid, fid, to_id)]
bbox1 = batched_frame_bboxes[(vid, fid, from_id)]
bbox2 = batched_frame_bboxes[(vid, fid, to_id)]
bb_pop_image = extract_object_subject(
batched_videos[overall_frame_id],
mask1,
mask2,
alpha=alpha,
white_alpha=white_alpha,
)
cropped_bb_pop_image = crop_image_contain_bboxes(
img=bb_pop_image,
bbox_ls=[bbox1, bbox2],
data_id=batched_video_ids,
)
batched_cropped_obj_pairs[vid].append(cropped_bb_pop_image)
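        # If no valid pair survived the lookups above, insert a placeholder pair
        # (frame -1) so the binary branch still runs; its probability mass is
        # folded into dummy_prob further below.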
if len(selected_pairs) == 0:
selected_pairs.append((0, -1, (-1, -1)))
new_select_pairs[0] = [(0, -1, (-1, -1))]
dummy_img = batched_videos[0]
batched_cropped_obj_pairs[0] = [dummy_img]
batched_image_binary_probs: List[
Dict[Tuple[int, Tuple[int, int], str], torch.Tensor]
] = []
batched_obj_pair_features: Dict[int, torch.Tensor] = {
vid: torch.tensor([]) for vid in range(batch_size)
}
if len(batched_cropped_obj_pairs) == 0:
batched_image_binary_probs.append({})
else:
for vid, binary_nl_features in enumerate(batched_binary_nl_features):
if len(binary_nl_features) == 0:
batched_image_binary_probs.append({})
continue
binary_kws = batched_binary_kws[vid]
cropped_obj_pairs = batched_cropped_obj_pairs[vid]
if len(cropped_obj_pairs) == 0:
batched_image_binary_probs.append({})
continue
inputs = self.clip_processor(
images=cropped_obj_pairs, return_tensors="pt"
).to(self._device)
obj_features = self._image_features_checkpoint(
self.clip_binary_model, inputs["pixel_values"]
)
batched_obj_pair_features[vid] = obj_features
obj_clip_features = obj_features / obj_features.norm(
p=2, dim=-1, keepdim=True
)
binary_nl_features = binary_nl_features / binary_nl_features.norm(
p=2, dim=-1, keepdim=True
)
logit_scale = self.clip_binary_model.logit_scale
binary_logits_per_text = torch.matmul(
binary_nl_features, obj_clip_features.t()
) * logit_scale.exp()
if not output_logit:
if not multi_class:
binary_logits_per_text = binary_logits_per_text.softmax(dim=0)
else:
binary_logits_per_text = binary_logits_per_text.sigmoid()
binary_prob_per_obj: Dict[
Tuple[int, Tuple[int, int], str], torch.Tensor
] = {}
for binary_name, probs in zip(binary_kws, binary_logits_per_text):
if binary_name == dummy_str:
dummy_prob += probs.sum()
else:
for prob, (vid_, fid, obj_pair) in zip(
probs, new_select_pairs[vid]
):
if fid == -1:
dummy_prob += prob
else:
binary_prob_per_obj[(fid, obj_pair, binary_name)] = prob
batched_image_binary_probs.append(binary_prob_per_obj)
result: Dict[str, Any] = {
"categorical_probs": batched_image_cate_probs,
"unary_probs": batched_image_unary_probs,
"binary_probs": batched_image_binary_probs,
"dummy_prob": dummy_prob,
}
if output_embeddings:
embeddings_dict = {
"cate_obj_clip_features": batched_obj_cate_features,
"cate_object_ids": batched_object_ids_lookup,
"unary_obj_clip_features": batched_obj_unary_features,
"unary_object_ids": batched_object_ids_lookup,
"binary_obj_pair_features": batched_obj_pair_features,
"binary_object_pairs": new_select_pairs,
}
result["embeddings"] = embeddings_dict
if return_flattened_segments or return_valid_pairs:
flattened = flatten_segments_for_batch(
video_id=0,
segments=masks,
bbox_min_dim=self.config.bbox_min_dim,
)
if return_flattened_segments:
result["flattened_segments"] = flattened
if return_valid_pairs:
interested_pairs = (
interested_object_pairs if interested_object_pairs else None
)
result["valid_pairs"] = extract_valid_object_pairs(
flattened["object_ids"],
interested_pairs,
)
if interested_pairs is None:
result["valid_pairs_metadata"] = {"pair_source": "all_pairs"}
else:
result["valid_pairs_metadata"] = {
"pair_source": "filtered",
"requested_pairs": interested_object_pairs,
}
return result
# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
if torch.is_tensor(frame):
frame_np = frame.detach().cpu().numpy()
else:
frame_np = np.asarray(frame)
return np.ascontiguousarray(frame_np)
def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
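        """Convert a mask tensor/array to a 2D boolean numpy array, squeezing a singleton channel dim if present."""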
if torch.is_tensor(mask):
mask_np = mask.detach().cpu().numpy()
else:
mask_np = np.asarray(mask)
if mask_np.ndim == 3:
if mask_np.shape[0] == 1:
mask_np = mask_np.squeeze(0)
elif mask_np.shape[2] == 1:
mask_np = mask_np.squeeze(2)
if mask_np.ndim != 2:
raise ValueError(f"Mask must be 2D after squeezing, got shape {mask_np.shape}")
return mask_np.astype(bool, copy=False)
def _extract_text_features(self, model, keywords: List[str]):
tokens = self.clip_tokenizer(
keywords,
return_tensors="pt",
max_length=75,
truncation=True,
padding="max_length",
).to(self._device)
return self._text_features_checkpoint(model, tokens)
def _extract_image_features(self, model, image):
if torch.is_tensor(image):
image = image.detach().cpu().numpy()
elif isinstance(image, np.ndarray):
pass
inputs = self.clip_processor(images=image, return_tensors="pt").to(self._device)
return self._image_features_checkpoint(model, inputs["pixel_values"])
# ------------------------------------------------------------------ #
# High-level predict API
# ------------------------------------------------------------------ #
def predict(
self,
video_frames: torch.Tensor,
masks: Dict[int, Dict[int, torch.Tensor]],
bboxes: Dict[int, Dict[int, List]],
categorical_keywords: List[str],
unary_keywords: Optional[List[str]] = None,
binary_keywords: Optional[List[str]] = None,
object_pairs: Optional[List[Tuple[int, int]]] = None,
return_top_k: int = 3,
return_flattened_segments: Optional[bool] = None,
return_valid_pairs: Optional[bool] = None,
interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
debug_visualizations: Optional[bool] = None,
) -> Dict[str, Any]:
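        """
        Inference-only wrapper around ``forward`` that converts the raw
        probability dicts into sorted top-k lists per object / frame / pair
        and attaches per-head confidence scores.
        """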
with torch.no_grad():
outputs = self.forward(
video_frames=video_frames,
masks=masks,
bboxes=bboxes,
categorical_keywords=categorical_keywords,
unary_keywords=unary_keywords,
binary_keywords=binary_keywords,
object_pairs=object_pairs,
return_flattened_segments=return_flattened_segments,
return_valid_pairs=return_valid_pairs,
interested_object_pairs=interested_object_pairs,
debug_visualizations=debug_visualizations,
)
formatted_categorical: Dict[int, List[Tuple[float, str]]] = {}
for (obj_id, category), prob in outputs["categorical_probs"][0].items():
if obj_id not in formatted_categorical:
formatted_categorical[obj_id] = []
prob_val = float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
formatted_categorical[obj_id].append((prob_val, category))
for obj_id in formatted_categorical:
formatted_categorical[obj_id] = sorted(
formatted_categorical[obj_id], reverse=True
)[:return_top_k]
formatted_unary: Dict[Tuple[int, int], List[Tuple[float, str]]] = {}
for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items():
key = (frame_id, obj_id)
if key not in formatted_unary:
formatted_unary[key] = []
prob_val = float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
formatted_unary[key].append((prob_val, predicate))
for key in formatted_unary:
formatted_unary[key] = sorted(
formatted_unary[key], reverse=True
)[:return_top_k]
formatted_binary: Dict[Tuple[int, Tuple[int, int]], List[Tuple[float, str]]] = {}
if len(outputs["binary_probs"]) > 0:
for (frame_id, obj_pair, predicate), prob in outputs["binary_probs"][0].items():
key = (frame_id, obj_pair)
if key not in formatted_binary:
formatted_binary[key] = []
prob_val = float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
formatted_binary[key].append((prob_val, predicate))
for key in formatted_binary:
formatted_binary[key] = sorted(
formatted_binary[key], reverse=True
)[:return_top_k]
def max_conf(d: Dict[Any, List[Tuple[float, str]]]) -> float:
if not d:
return 0.0
return max(
(max((p for p, _ in preds), default=0.0) for preds in d.values()),
default=0.0,
)
result: Dict[str, Any] = {
"categorical_predictions": formatted_categorical,
"unary_predictions": formatted_unary,
"binary_predictions": formatted_binary,
"confidence_scores": {
"categorical": max_conf(formatted_categorical),
"unary": max_conf(formatted_unary),
"binary": max_conf(formatted_binary),
},
}
if "flattened_segments" in outputs:
result["flattened_segments"] = outputs["flattened_segments"]
if "valid_pairs" in outputs:
result["valid_pairs"] = outputs["valid_pairs"]
if "valid_pairs_metadata" in outputs:
result["valid_pairs_metadata"] = outputs["valid_pairs_metadata"]
return result