Spaces:
Running
on
T4
Running
on
T4
Commit
·
f9a6349
0
Parent(s):
updates
Browse files- .gitattributes +35 -0
- GroundingDINO_SwinT_OGC.py +43 -0
- README.md +12 -0
- app.py +288 -0
- groundingdino_swint_ogc.pth +3 -0
- requirements.txt +24 -0
- sam2_hiera_t.yaml +118 -0
- sam2_hiera_tiny.pt +3 -0
- vine_hf/OVERVIEW.md +218 -0
- vine_hf/README.md +355 -0
- vine_hf/README_HF.md +345 -0
- vine_hf/__init__.py +23 -0
- vine_hf/convert_inference.py +288 -0
- vine_hf/example_ensemble_weights.py +333 -0
- vine_hf/example_sam2_masks.py +331 -0
- vine_hf/example_usage.ipynb +310 -0
- vine_hf/example_usage.py +283 -0
- vine_hf/example_visualization.py +146 -0
- vine_hf/example_with_pretrained_vine.py +287 -0
- vine_hf/flattening.py +124 -0
- vine_hf/push_to_hub.py +232 -0
- vine_hf/push_to_video_fm.py +274 -0
- vine_hf/setup.py +73 -0
- vine_hf/vine_config.py +86 -0
- vine_hf/vine_model.py +1001 -0
- vine_hf/vine_pipeline.py +923 -0
- vine_hf/vis_utils.py +941 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
GroundingDINO_SwinT_OGC.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
batch_size = 1
|
| 2 |
+
modelname = "groundingdino"
|
| 3 |
+
backbone = "swin_T_224_1k"
|
| 4 |
+
position_embedding = "sine"
|
| 5 |
+
pe_temperatureH = 20
|
| 6 |
+
pe_temperatureW = 20
|
| 7 |
+
return_interm_indices = [1, 2, 3]
|
| 8 |
+
backbone_freeze_keywords = None
|
| 9 |
+
enc_layers = 6
|
| 10 |
+
dec_layers = 6
|
| 11 |
+
pre_norm = False
|
| 12 |
+
dim_feedforward = 2048
|
| 13 |
+
hidden_dim = 256
|
| 14 |
+
dropout = 0.0
|
| 15 |
+
nheads = 8
|
| 16 |
+
num_queries = 900
|
| 17 |
+
query_dim = 4
|
| 18 |
+
num_patterns = 0
|
| 19 |
+
num_feature_levels = 4
|
| 20 |
+
enc_n_points = 4
|
| 21 |
+
dec_n_points = 4
|
| 22 |
+
two_stage_type = "standard"
|
| 23 |
+
two_stage_bbox_embed_share = False
|
| 24 |
+
two_stage_class_embed_share = False
|
| 25 |
+
transformer_activation = "relu"
|
| 26 |
+
dec_pred_bbox_embed_share = True
|
| 27 |
+
dn_box_noise_scale = 1.0
|
| 28 |
+
dn_label_noise_ratio = 0.5
|
| 29 |
+
dn_label_coef = 1.0
|
| 30 |
+
dn_bbox_coef = 1.0
|
| 31 |
+
embed_init_tgt = True
|
| 32 |
+
dn_labelbook_size = 2000
|
| 33 |
+
max_text_len = 256
|
| 34 |
+
text_encoder_type = "bert-base-uncased"
|
| 35 |
+
use_text_enhancer = True
|
| 36 |
+
use_fusion_layer = True
|
| 37 |
+
use_checkpoint = True
|
| 38 |
+
use_transformer_ckpt = True
|
| 39 |
+
use_text_cross_attention = True
|
| 40 |
+
text_dropout = 0.0
|
| 41 |
+
fusion_dropout = 0.0
|
| 42 |
+
fusion_droppath = 0.1
|
| 43 |
+
sub_sentence_present = True
|
README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: LASER Demo
|
| 3 |
+
emoji: 🐠
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 6.0.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from collections.abc import Mapping, Sequence
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
import inspect
|
| 5 |
+
import shutil
|
| 6 |
+
import tempfile
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
import spaces # <-- ZeroGPU integration
|
| 11 |
+
import gradio as gr
|
| 12 |
+
import torch
|
| 13 |
+
from transformers import pipeline # not strictly necessary, but fine
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# -----------------------------
|
| 17 |
+
# Environment / diagnostics
|
| 18 |
+
# -----------------------------
|
| 19 |
+
os.environ["GRADIO_TEMP_DIR"] = str(Path(__file__).parent / "gradio_temp")
|
| 20 |
+
os.environ["OPENAI_API_KEY"] = "test"
|
| 21 |
+
os.environ["OMP_NUM_THREADS"] = "4"
|
| 22 |
+
|
| 23 |
+
print("All imports finished")
|
| 24 |
+
print(f"Python version: {sys.version}")
|
| 25 |
+
print(f"PyTorch version: {torch.__version__}")
|
| 26 |
+
print(f"CUDA available: {torch.cuda.is_available()}")
|
| 27 |
+
print(f"CUDA version: {torch.version.cuda}")
|
| 28 |
+
print(f"cuDNN version: {torch.backends.cudnn.version()}")
|
| 29 |
+
print(f"Number of GPUs: {torch.cuda.device_count()}")
|
| 30 |
+
|
| 31 |
+
if torch.cuda.is_available():
|
| 32 |
+
for i in range(torch.cuda.device_count()):
|
| 33 |
+
print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
|
| 34 |
+
print(
|
| 35 |
+
f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
| 39 |
+
torch.backends.cudnn.allow_tf32 = False
|
| 40 |
+
os.environ["TORCH_DTYPE"] = "float32"
|
| 41 |
+
torch.set_default_dtype(torch.float32)
|
| 42 |
+
|
| 43 |
+
current_dir = Path(__file__).resolve().parent
|
| 44 |
+
# For Spaces, assume checkpoints live alongside app.py or in a "checkpoints" subdir.
|
| 45 |
+
# If you keep them next to app.py locally, this still works.
|
| 46 |
+
sam_config_path = str(current_dir / "sam2_hiera_t.yaml")
|
| 47 |
+
sam_checkpoint_path = str(current_dir / "sam2_hiera_tiny.pt")
|
| 48 |
+
gd_config_path = str(current_dir / "GroundingDINO_SwinT_OGC.py")
|
| 49 |
+
gd_checkpoint_path = str(current_dir / "groundingdino_swint_ogc.pth")
|
| 50 |
+
visualization_dir = str(current_dir / "outputs")
|
| 51 |
+
print(
|
| 52 |
+
f"Setting up paths: {sam_config_path}, {sam_checkpoint_path}, {gd_config_path}, {gd_checkpoint_path}"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@lru_cache(maxsize=1)
|
| 57 |
+
def _load_vine_pipeline():
|
| 58 |
+
"""
|
| 59 |
+
Lazy-load and cache the Vine pipeline so we don't re-download/rebuild it on every request.
|
| 60 |
+
"""
|
| 61 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 62 |
+
|
| 63 |
+
config = VineConfig(
|
| 64 |
+
segmentation_method="grounding_dino_sam2",
|
| 65 |
+
model_name="openai/clip-vit-base-patch32",
|
| 66 |
+
use_hf_repo=True,
|
| 67 |
+
model_repo="KevinX-Penn28/testing",
|
| 68 |
+
box_threshold=0.35,
|
| 69 |
+
text_threshold=0.25,
|
| 70 |
+
target_fps=1, # default 1 FPS
|
| 71 |
+
topk_cate=5,
|
| 72 |
+
white_alpha=0.3,
|
| 73 |
+
visualization_dir=visualization_dir,
|
| 74 |
+
visualize=True,
|
| 75 |
+
debug_visualizations=False,
|
| 76 |
+
device="cuda",
|
| 77 |
+
categorical_pool="max",
|
| 78 |
+
)
|
| 79 |
+
model = VineModel(config)
|
| 80 |
+
return VinePipeline(
|
| 81 |
+
model=model,
|
| 82 |
+
tokenizer=None,
|
| 83 |
+
sam_config_path=sam_config_path,
|
| 84 |
+
sam_checkpoint_path=sam_checkpoint_path,
|
| 85 |
+
gd_config_path=gd_config_path,
|
| 86 |
+
gd_checkpoint_path=gd_checkpoint_path,
|
| 87 |
+
device="cuda",
|
| 88 |
+
trust_remote_code=True,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@spaces.GPU(duration=300) # Up to ~5 minutes of H200 ZeroGPU time per call
|
| 93 |
+
def process_video(
|
| 94 |
+
video_file,
|
| 95 |
+
categorical_keywords,
|
| 96 |
+
unary_keywords,
|
| 97 |
+
binary_keywords,
|
| 98 |
+
object_pairs,
|
| 99 |
+
output_fps,
|
| 100 |
+
box_threshold,
|
| 101 |
+
text_threshold,
|
| 102 |
+
):
|
| 103 |
+
vine_pipe = _load_vine_pipeline()
|
| 104 |
+
|
| 105 |
+
# Normalize incoming video input to a file path
|
| 106 |
+
if isinstance(video_file, dict):
|
| 107 |
+
video_file = (
|
| 108 |
+
video_file.get("name")
|
| 109 |
+
or video_file.get("filepath")
|
| 110 |
+
or video_file.get("data")
|
| 111 |
+
)
|
| 112 |
+
if not isinstance(video_file, (str, Path)):
|
| 113 |
+
raise ValueError(f"Unsupported video input type: {type(video_file)}")
|
| 114 |
+
|
| 115 |
+
categorical_keywords = (
|
| 116 |
+
[kw.strip() for kw in categorical_keywords.split(",")]
|
| 117 |
+
if categorical_keywords
|
| 118 |
+
else []
|
| 119 |
+
)
|
| 120 |
+
unary_keywords = (
|
| 121 |
+
[kw.strip() for kw in unary_keywords.split(",")] if unary_keywords else []
|
| 122 |
+
)
|
| 123 |
+
binary_keywords = (
|
| 124 |
+
[kw.strip() for kw in binary_keywords.split(",")] if binary_keywords else []
|
| 125 |
+
)
|
| 126 |
+
object_pairs = (
|
| 127 |
+
[tuple(map(int, pair.split("-"))) for pair in object_pairs.split(",")]
|
| 128 |
+
if object_pairs
|
| 129 |
+
else []
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
results = vine_pipe(
|
| 133 |
+
inputs=video_file,
|
| 134 |
+
categorical_keywords=categorical_keywords,
|
| 135 |
+
unary_keywords=unary_keywords,
|
| 136 |
+
binary_keywords=binary_keywords,
|
| 137 |
+
object_pairs=object_pairs,
|
| 138 |
+
segmentation_method="grounding_dino_sam2",
|
| 139 |
+
return_top_k=5,
|
| 140 |
+
include_visualizations=True,
|
| 141 |
+
debug_visualizations=False,
|
| 142 |
+
device="cuda",
|
| 143 |
+
box_threshold=box_threshold,
|
| 144 |
+
text_threshold=text_threshold,
|
| 145 |
+
target_fps=output_fps,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
vine_pipe.box_threshold = box_threshold
|
| 149 |
+
vine_pipe.text_threshold = text_threshold
|
| 150 |
+
vine_pipe.target_fps = output_fps
|
| 151 |
+
|
| 152 |
+
if isinstance(results, Mapping):
|
| 153 |
+
results_dict = results
|
| 154 |
+
elif isinstance(results, Sequence) and results and isinstance(results[0], Mapping):
|
| 155 |
+
results_dict = results[0]
|
| 156 |
+
else:
|
| 157 |
+
results_dict = {}
|
| 158 |
+
|
| 159 |
+
visualizations = results_dict.get("visualizations") or {}
|
| 160 |
+
vine = visualizations.get("vine") or {}
|
| 161 |
+
all_vis = vine.get("all") or {}
|
| 162 |
+
result_video_path = all_vis.get("video_path")
|
| 163 |
+
if not result_video_path:
|
| 164 |
+
candidates = sorted(
|
| 165 |
+
Path(visualization_dir).rglob("*.mp4"),
|
| 166 |
+
key=lambda p: p.stat().st_mtime,
|
| 167 |
+
reverse=True,
|
| 168 |
+
)
|
| 169 |
+
result_video_path = str(candidates[0]) if candidates else None
|
| 170 |
+
summary = results_dict.get("summary") or {}
|
| 171 |
+
|
| 172 |
+
if result_video_path and os.path.exists(result_video_path):
|
| 173 |
+
gradio_tmp = Path(
|
| 174 |
+
os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
|
| 175 |
+
) / "vine_outputs"
|
| 176 |
+
gradio_tmp.mkdir(parents=True, exist_ok=True)
|
| 177 |
+
dest_path = gradio_tmp / Path(result_video_path).name
|
| 178 |
+
try:
|
| 179 |
+
shutil.copyfile(result_video_path, dest_path)
|
| 180 |
+
video_path_for_ui = str(dest_path)
|
| 181 |
+
except Exception as e:
|
| 182 |
+
print(f"Warning: failed to copy video to Gradio temp dir: {e}")
|
| 183 |
+
video_path_for_ui = str(result_video_path)
|
| 184 |
+
else:
|
| 185 |
+
video_path_for_ui = None
|
| 186 |
+
print(
|
| 187 |
+
"Warning: annotated video not found or empty; check visualization settings."
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
return video_path_for_ui, summary
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _video_component(label: str, *, is_output: bool = False):
|
| 194 |
+
"""
|
| 195 |
+
Build a Gradio Video component that is compatible with older Gradio versions
|
| 196 |
+
(no `type`/`sources`/`format` kwargs) and newer ones when available.
|
| 197 |
+
"""
|
| 198 |
+
kwargs = {"label": label}
|
| 199 |
+
sig = inspect.signature(gr.Video.__init__)
|
| 200 |
+
|
| 201 |
+
# Only set format for OUTPUT components
|
| 202 |
+
if is_output and "format" in sig.parameters:
|
| 203 |
+
kwargs["format"] = "mp4"
|
| 204 |
+
|
| 205 |
+
if not is_output:
|
| 206 |
+
if "type" in sig.parameters:
|
| 207 |
+
kwargs["type"] = "filepath"
|
| 208 |
+
if "sources" in sig.parameters:
|
| 209 |
+
kwargs["sources"] = ["upload"]
|
| 210 |
+
|
| 211 |
+
if is_output and "autoplay" in sig.parameters:
|
| 212 |
+
kwargs["autoplay"] = True
|
| 213 |
+
|
| 214 |
+
return gr.Video(**kwargs)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _create_blocks():
|
| 218 |
+
"""
|
| 219 |
+
Build a Blocks context that works across Gradio versions.
|
| 220 |
+
"""
|
| 221 |
+
blocks_kwargs = {"title": "VINE Demo"}
|
| 222 |
+
soft_theme = None
|
| 223 |
+
|
| 224 |
+
if hasattr(gr, "themes") and hasattr(gr.themes, "Soft"):
|
| 225 |
+
try:
|
| 226 |
+
soft_theme = gr.themes.Soft()
|
| 227 |
+
except Exception:
|
| 228 |
+
soft_theme = None
|
| 229 |
+
|
| 230 |
+
if "theme" in inspect.signature(gr.Blocks).parameters and soft_theme is not None:
|
| 231 |
+
blocks_kwargs["theme"] = soft_theme
|
| 232 |
+
|
| 233 |
+
return gr.Blocks(**blocks_kwargs)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# Create Gradio interface
|
| 237 |
+
with _create_blocks() as demo:
|
| 238 |
+
video_input = _video_component("Upload Video", is_output=False)
|
| 239 |
+
categorical_input = gr.Textbox(
|
| 240 |
+
label="Categorical Keywords (comma-separated)",
|
| 241 |
+
value="person, car, tree, background",
|
| 242 |
+
)
|
| 243 |
+
unary_input = gr.Textbox(
|
| 244 |
+
label="Unary Keywords (comma-separated)", value="walking, running, standing"
|
| 245 |
+
)
|
| 246 |
+
binary_input = gr.Textbox(
|
| 247 |
+
label="Binary Keywords (comma-separated)",
|
| 248 |
+
placeholder="e.g., chasing, carrying",
|
| 249 |
+
)
|
| 250 |
+
pairs_input = gr.Textbox(
|
| 251 |
+
label="Object Pairs (comma-separated indices)",
|
| 252 |
+
placeholder="e.g., 0-1,0-2 for pairs of objects",
|
| 253 |
+
)
|
| 254 |
+
fps_input = gr.Number(
|
| 255 |
+
label="Output FPS (affects processing speed)", value=1 # default 1 FPS
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 259 |
+
box_threshold_input = gr.Slider(
|
| 260 |
+
label="Box Threshold", minimum=0.1, maximum=0.9, value=0.35, step=0.05
|
| 261 |
+
)
|
| 262 |
+
text_threshold_input = gr.Slider(
|
| 263 |
+
label="Text Threshold", minimum=0.1, maximum=0.9, value=0.25, step=0.05
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
submit_btn = gr.Button("Process Video", variant="primary")
|
| 267 |
+
|
| 268 |
+
video_output = _video_component("Output Video with Annotations", is_output=True)
|
| 269 |
+
json_output = gr.JSON(label="Summary of Detected Events")
|
| 270 |
+
|
| 271 |
+
submit_btn.click(
|
| 272 |
+
fn=process_video,
|
| 273 |
+
inputs=[
|
| 274 |
+
video_input,
|
| 275 |
+
categorical_input,
|
| 276 |
+
unary_input,
|
| 277 |
+
binary_input,
|
| 278 |
+
pairs_input,
|
| 279 |
+
fps_input,
|
| 280 |
+
box_threshold_input,
|
| 281 |
+
text_threshold_input,
|
| 282 |
+
],
|
| 283 |
+
outputs=[video_output, json_output],
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
if __name__ == "__main__":
|
| 287 |
+
print("Got to main")
|
| 288 |
+
demo.launch(share=True, debug=True)
|
groundingdino_swint_ogc.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b3ca2563c77c69f651d7bd133e97139c186df06231157a64c507099c52bc799
|
| 3 |
+
size 693997677
|
requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0.0
|
| 2 |
+
spaces>=0.24.0
|
| 3 |
+
|
| 4 |
+
transformers>=4.40.0
|
| 5 |
+
huggingface-hub>=0.23.0
|
| 6 |
+
safetensors>=0.4.2
|
| 7 |
+
accelerate>=0.30.0
|
| 8 |
+
|
| 9 |
+
--extra-index-url https://download.pytorch.org/whl/cu121
|
| 10 |
+
torch==2.2.1
|
| 11 |
+
torchvision==0.17.1
|
| 12 |
+
|
| 13 |
+
numpy
|
| 14 |
+
opencv-python
|
| 15 |
+
pillow
|
| 16 |
+
matplotlib
|
| 17 |
+
seaborn
|
| 18 |
+
pandas
|
| 19 |
+
tqdm
|
| 20 |
+
scikit-learn
|
| 21 |
+
|
| 22 |
+
-e git+https://github.com/video-fm/video-sam2.git#egg=video_sam2
|
| 23 |
+
-e git+https://github.com/IDEA-Research/GroundingDINO.git#egg=GroundingDINO
|
| 24 |
+
-e git+https://github.com/kevinxuez/LASER.git#egg=laser
|
sam2_hiera_t.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 7, 2]
|
| 14 |
+
global_att_blocks: [5, 7, 9]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [64, 64]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [64, 64]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
# SAM decoder
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 106 |
+
# object occlusion prediction
|
| 107 |
+
pred_obj_scores: true
|
| 108 |
+
pred_obj_scores_mlp: true
|
| 109 |
+
fixed_no_obj_ptr: true
|
| 110 |
+
# multimask tracking settings
|
| 111 |
+
multimask_output_for_tracking: true
|
| 112 |
+
use_multimask_token_for_obj_ptr: true
|
| 113 |
+
multimask_min_pt_num: 0
|
| 114 |
+
multimask_max_pt_num: 1
|
| 115 |
+
use_mlp_for_obj_ptr_proj: true
|
| 116 |
+
# Compilation flag
|
| 117 |
+
# HieraT does not currently support compilation, should always be set to False
|
| 118 |
+
compile_image_encoder: False
|
sam2_hiera_tiny.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:65b50056e05bcb13694174f51bb6da89c894b57b75ccdf0ba6352c597c5d1125
|
| 3 |
+
size 155906050
|
vine_hf/OVERVIEW.md
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VINE HuggingFace Interface - Complete Overview
|
| 2 |
+
|
| 3 |
+
This directory contains a complete HuggingFace-compatible interface for the VINE (Video Understanding with Natural Language) model. The interface allows you to easily use, share, and deploy your VINE model through the HuggingFace ecosystem.
|
| 4 |
+
|
| 5 |
+
## 📁 Directory Structure
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
vine_hf/
|
| 9 |
+
├── __init__.py # Package initialization and exports
|
| 10 |
+
├── vine_config.py # VineConfig class (PretrainedConfig)
|
| 11 |
+
├── vine_model.py # VineModel class (PreTrainedModel)
|
| 12 |
+
├── vine_pipeline.py # VinePipeline class (Pipeline)
|
| 13 |
+
├── example_usage.py # Comprehensive usage examples
|
| 14 |
+
├── convert_inference.py # Migration guide from inference.py
|
| 15 |
+
├── push_to_hub.py # Script to push model to HF Hub
|
| 16 |
+
├── setup.py # Package setup configuration
|
| 17 |
+
├── README.md # Detailed documentation
|
| 18 |
+
└── OVERVIEW.md # This file
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## 🏗️ Architecture Components
|
| 22 |
+
|
| 23 |
+
### 1. VineConfig (`vine_config.py`)
|
| 24 |
+
- Inherits from `PretrainedConfig`
|
| 25 |
+
- Configures model parameters, segmentation methods, and processing options
|
| 26 |
+
- Compatible with HuggingFace configuration system
|
| 27 |
+
|
| 28 |
+
### 2. VineModel (`vine_model.py`)
|
| 29 |
+
- Inherits from `PreTrainedModel`
|
| 30 |
+
- Implements the core VINE model with three CLIP backbones
|
| 31 |
+
- Supports categorical, unary, and binary predictions
|
| 32 |
+
- Provides both `forward()` and `predict()` methods
|
| 33 |
+
|
| 34 |
+
### 3. VinePipeline (`vine_pipeline.py`)
|
| 35 |
+
- Inherits from `Pipeline`
|
| 36 |
+
- Handles end-to-end video processing workflow
|
| 37 |
+
- Integrates segmentation (SAM2, Grounding DINO + SAM2)
|
| 38 |
+
- Provides user-friendly interface for video understanding
|
| 39 |
+
|
| 40 |
+
## 🚀 Key Features
|
| 41 |
+
|
| 42 |
+
✅ **Full HuggingFace Compatibility**
|
| 43 |
+
- Compatible with `transformers` library
|
| 44 |
+
- Supports `AutoModel` and `pipeline` interfaces
|
| 45 |
+
- Can be pushed to and loaded from HuggingFace Hub
|
| 46 |
+
|
| 47 |
+
✅ **Flexible Segmentation**
|
| 48 |
+
- Support for SAM2 automatic segmentation
|
| 49 |
+
- Support for Grounding DINO + SAM2 text-guided segmentation
|
| 50 |
+
- Configurable thresholds and parameters
|
| 51 |
+
|
| 52 |
+
✅ **Multi-Modal Understanding**
|
| 53 |
+
- Categorical classification (object types)
|
| 54 |
+
- Unary predicates (single object actions)
|
| 55 |
+
- Binary relations (object-object relationships)
|
| 56 |
+
|
| 57 |
+
✅ **Easy Integration**
|
| 58 |
+
- Simple pipeline interface for end users
|
| 59 |
+
- Direct model access for researchers
|
| 60 |
+
- Comprehensive configuration options
|
| 61 |
+
|
| 62 |
+
## 📖 Usage Examples
|
| 63 |
+
|
| 64 |
+
### Quick Start with Pipeline
|
| 65 |
+
```python
|
| 66 |
+
from transformers import pipeline
|
| 67 |
+
from vine_hf import VineModel, VinePipeline
|
| 68 |
+
|
| 69 |
+
# Create pipeline
|
| 70 |
+
vine_pipeline = pipeline(
|
| 71 |
+
"vine-video-understanding",
|
| 72 |
+
model="your-username/vine-model",
|
| 73 |
+
trust_remote_code=True
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Process video
|
| 77 |
+
results = vine_pipeline(
|
| 78 |
+
"video.mp4",
|
| 79 |
+
categorical_keywords=['human', 'dog', 'frisbee'],
|
| 80 |
+
unary_keywords=['running', 'jumping'],
|
| 81 |
+
binary_keywords=['chasing', 'behind']
|
| 82 |
+
)
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
### Direct Model Usage
|
| 86 |
+
```python
|
| 87 |
+
from vine_hf import VineConfig, VineModel
|
| 88 |
+
|
| 89 |
+
config = VineConfig(segmentation_method="grounding_dino_sam2")
|
| 90 |
+
model = VineModel(config)
|
| 91 |
+
|
| 92 |
+
results = model.predict(
|
| 93 |
+
video_frames=video_tensor,
|
| 94 |
+
masks=masks_dict,
|
| 95 |
+
bboxes=bboxes_dict,
|
| 96 |
+
categorical_keywords=['human', 'dog'],
|
| 97 |
+
unary_keywords=['running', 'sitting'],
|
| 98 |
+
binary_keywords=['chasing', 'near']
|
| 99 |
+
)
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
## 🔧 Migration from Original Code
|
| 103 |
+
|
| 104 |
+
The `convert_inference.py` script shows how to migrate from the original `inference.py` workflow:
|
| 105 |
+
|
| 106 |
+
**Original Approach:**
|
| 107 |
+
- Manual model loading and configuration
|
| 108 |
+
- Direct handling of segmentation pipeline
|
| 109 |
+
- Custom result processing
|
| 110 |
+
- Complex setup requirements
|
| 111 |
+
|
| 112 |
+
**New HuggingFace Interface:**
|
| 113 |
+
- Standardized model configuration
|
| 114 |
+
- Automatic preprocessing/postprocessing
|
| 115 |
+
- Simple pipeline interface
|
| 116 |
+
- Easy sharing via HuggingFace Hub
|
| 117 |
+
|
| 118 |
+
## 📤 Sharing Your Model
|
| 119 |
+
|
| 120 |
+
Use the `push_to_hub.py` script to share your trained model:
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
python vine_hf/push_to_hub.py \
|
| 124 |
+
--weights path/to/your/model.pth \
|
| 125 |
+
--repo your-username/vine-model \
|
| 126 |
+
--login
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
## 🛠️ Installation & Setup
|
| 130 |
+
|
| 131 |
+
1. **Install Dependencies:**
|
| 132 |
+
```bash
|
| 133 |
+
pip install transformers torch torchvision opencv-python pillow numpy
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
2. **Install Segmentation Models (Optional):**
|
| 137 |
+
- SAM2: https://github.com/facebookresearch/sam2
|
| 138 |
+
- Grounding DINO: https://github.com/IDEA-Research/GroundingDINO
|
| 139 |
+
|
| 140 |
+
3. **Install VINE HF Interface:**
|
| 141 |
+
```bash
|
| 142 |
+
cd vine_hf
|
| 143 |
+
pip install -e .
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
## 🎯 Configuration Options
|
| 147 |
+
|
| 148 |
+
The `VineConfig` class supports extensive configuration:
|
| 149 |
+
|
| 150 |
+
- **Model Settings:** CLIP backbone, hidden dimensions
|
| 151 |
+
- **Segmentation:** Method, thresholds, target FPS
|
| 152 |
+
- **Processing:** Alpha values, top-k results, video length limits
|
| 153 |
+
- **Performance:** Multi-class mode, output format options
|
| 154 |
+
|
| 155 |
+
## 📊 Output Format
|
| 156 |
+
|
| 157 |
+
The interface returns structured predictions:
|
| 158 |
+
|
| 159 |
+
```python
|
| 160 |
+
{
|
| 161 |
+
"categorical_predictions": {obj_id: [(prob, category), ...]},
|
| 162 |
+
"unary_predictions": {(frame, obj): [(prob, action), ...]},
|
| 163 |
+
"binary_predictions": {(frame, pair): [(prob, relation), ...]},
|
| 164 |
+
"confidence_scores": {"categorical": float, "unary": float, "binary": float},
|
| 165 |
+
"summary": {
|
| 166 |
+
"num_objects_detected": int,
|
| 167 |
+
"top_categories": [(category, prob), ...],
|
| 168 |
+
"top_actions": [(action, prob), ...],
|
| 169 |
+
"top_relations": [(relation, prob), ...]
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
## 🔍 Testing & Validation
|
| 175 |
+
|
| 176 |
+
Run the example scripts to test your setup:
|
| 177 |
+
|
| 178 |
+
```bash
|
| 179 |
+
# Test basic functionality
|
| 180 |
+
python vine_hf/example_usage.py
|
| 181 |
+
|
| 182 |
+
# Test migration from original code
|
| 183 |
+
python vine_hf/convert_inference.py
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
## 🤝 Contributing
|
| 187 |
+
|
| 188 |
+
To contribute or customize:
|
| 189 |
+
|
| 190 |
+
1. **Modify Configuration:** Edit `vine_config.py` for new parameters
|
| 191 |
+
2. **Extend Model:** Add functionality to `vine_model.py`
|
| 192 |
+
3. **Enhance Pipeline:** Improve preprocessing/postprocessing in `vine_pipeline.py`
|
| 193 |
+
4. **Add Features:** Create additional utility scripts
|
| 194 |
+
|
| 195 |
+
## 📝 Next Steps
|
| 196 |
+
|
| 197 |
+
1. **Load Your Weights:** Use your trained VINE model weights
|
| 198 |
+
2. **Test Segmentation:** Set up Grounding DINO and SAM2 models
|
| 199 |
+
3. **Validate Results:** Compare with original inference.py output
|
| 200 |
+
4. **Share Model:** Push to HuggingFace Hub for community use
|
| 201 |
+
5. **Deploy:** Use in applications, demos, or research projects
|
| 202 |
+
|
| 203 |
+
## 🐛 Troubleshooting
|
| 204 |
+
|
| 205 |
+
**Common Issues:**
|
| 206 |
+
- **Import Errors:** Check PYTHONPATH and package installation
|
| 207 |
+
- **Segmentation Failures:** Verify Grounding DINO/SAM2 setup
|
| 208 |
+
- **Weight Loading:** Adjust weight loading logic in `convert_inference.py`
|
| 209 |
+
- **CUDA Issues:** Check GPU availability and PyTorch installation
|
| 210 |
+
|
| 211 |
+
**Support:**
|
| 212 |
+
- Check the README.md for detailed documentation
|
| 213 |
+
- Review example_usage.py for working code examples
|
| 214 |
+
- Examine convert_inference.py for migration guidance
|
| 215 |
+
|
| 216 |
+
---
|
| 217 |
+
|
| 218 |
+
This HuggingFace interface makes VINE accessible to the broader ML community while maintaining all the powerful video understanding capabilities of the original model. The standardized interface enables easy sharing, deployment, and integration with existing HuggingFace workflows.
|
vine_hf/README.md
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VINE HuggingFace Interface
|
| 2 |
+
|
| 3 |
+
VINE (Video Understanding with Natural Language) is a model that processes videos along with categorical, unary, and binary keywords to return probability distributions over those keywords for detected objects and their relationships.
|
| 4 |
+
|
| 5 |
+
This package provides a HuggingFace-compatible interface for the VINE model, making it easy to use for video understanding tasks.
|
| 6 |
+
|
| 7 |
+
## Features
|
| 8 |
+
|
| 9 |
+
- **Categorical Classification**: Classify objects in videos (e.g., "human", "dog", "frisbee")
|
| 10 |
+
- **Unary Predicates**: Detect actions on single objects (e.g., "running", "jumping", "sitting")
|
| 11 |
+
- **Binary Relations**: Detect relationships between object pairs (e.g., "behind", "in front of", "chasing")
|
| 12 |
+
- **Multiple Segmentation Methods**: Support for SAM2 and Grounding DINO + SAM2
|
| 13 |
+
- **HuggingFace Integration**: Full compatibility with HuggingFace transformers and pipelines
|
| 14 |
+
- **Visualization Hooks**: Optional high-level visualizations plus lightweight debug mask dumps for quick sanity checks
|
| 15 |
+
|
| 16 |
+
## Installation
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
# Install the package (assuming it's in your Python path)
|
| 20 |
+
pip install transformers torch torchvision
|
| 21 |
+
pip install opencv-python pillow numpy
|
| 22 |
+
|
| 23 |
+
# For segmentation functionality, you'll also need:
|
| 24 |
+
# - SAM2: https://github.com/facebookresearch/sam2
|
| 25 |
+
# - Grounding DINO: https://github.com/IDEA-Research/GroundingDINO
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Segmentation Model Configuration
|
| 29 |
+
|
| 30 |
+
`VinePipeline` lazily brings up the segmentation stack the first time a call needs masks. Thresholds, FPS, visualization toggles, and device selection live in `VineConfig`; the pipeline constructor tells it where to fetch SAM2 / GroundingDINO weights or lets you inject already-instantiated modules.
|
| 31 |
+
|
| 32 |
+
### Provide file paths at construction (most common)
|
| 33 |
+
|
| 34 |
+
```python
|
| 35 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 36 |
+
|
| 37 |
+
vine_config = VineConfig(
|
| 38 |
+
segmentation_method="grounding_dino_sam2", # or "sam2"
|
| 39 |
+
box_threshold=0.35,
|
| 40 |
+
text_threshold=0.25,
|
| 41 |
+
target_fps=5,
|
| 42 |
+
visualization_dir="output/visualizations", # where to write visualizations (and debug visualizations if enabled)
|
| 43 |
+
debug_visualizations=True, # Write videos of the groundingDINO/SAM2/Binary/Unary, etc... outputs
|
| 44 |
+
pretrained_vine_path="/abs/path/to/laser_model_v1.pkl",
|
| 45 |
+
device="cuda:0", # accepts int, str, or torch.device
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
vine_model = VineModel(vine_config)
|
| 49 |
+
|
| 50 |
+
vine_pipeline = VinePipeline(
|
| 51 |
+
model=vine_model,
|
| 52 |
+
tokenizer=None,
|
| 53 |
+
sam_config_path="/abs/path/to/sam2/sam2.1_hiera_t.yaml",
|
| 54 |
+
sam_checkpoint_path="/abs/path/to/sam2/sam2_hiera_tiny.pt",
|
| 55 |
+
gd_config_path="/abs/path/to/groundingdino/config/GroundingDINO_SwinT_OGC.py",
|
| 56 |
+
gd_checkpoint_path="/abs/path/to/groundingdino/weights/groundingdino_swint_ogc.pth",
|
| 57 |
+
device=vine_config._device,
|
| 58 |
+
)
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
When `segmentation_method="grounding_dino_sam2"`, both SAM2 and GroundingDINO must be reachable. The pipeline validates the paths; missing files raise a `ValueError`. If you pick `"sam2"`, only the SAM2 config and checkpoint are required.
|
| 62 |
+
|
| 63 |
+
### Reuse pre-initialized segmentation modules
|
| 64 |
+
|
| 65 |
+
If you build the segmentation stack elsewhere, inject the components with `set_segmentation_models` before running the pipeline:
|
| 66 |
+
|
| 67 |
+
```python
|
| 68 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
| 69 |
+
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 70 |
+
from groundingdino.util.inference import Model as GroundingDINOModel
|
| 71 |
+
|
| 72 |
+
sam_predictor = build_sam2_video_predictor(..., device=vine_config._device)
|
| 73 |
+
mask_generator = SAM2AutomaticMaskGenerator(build_sam2(..., device=vine_config._device))
|
| 74 |
+
grounding_model = GroundingDINOModel(..., device=vine_config._device)
|
| 75 |
+
|
| 76 |
+
vine_pipeline.set_segmentation_models(
|
| 77 |
+
sam_predictor=sam_predictor,
|
| 78 |
+
mask_generator=mask_generator,
|
| 79 |
+
grounding_model=grounding_model,
|
| 80 |
+
)
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
Any argument left as `None` is initialized lazily from the file paths when the pipeline first needs that backend.
|
| 84 |
+
|
| 85 |
+
## Quick Start
|
| 86 |
+
|
| 87 |
+
## Requirements
|
| 88 |
+
-torch
|
| 89 |
+
-torchvision
|
| 90 |
+
-transformers
|
| 91 |
+
-opencv-python
|
| 92 |
+
-matplotlib
|
| 93 |
+
-seaborn
|
| 94 |
+
-pandas
|
| 95 |
+
-numpy
|
| 96 |
+
-ipywidgets
|
| 97 |
+
-tqdm
|
| 98 |
+
-scikit-learn
|
| 99 |
+
-sam2 (from Facebook Research) "https://github.com/video-fm/video-sam2"
|
| 100 |
+
-sam2 weights (downloaded separately. EX: https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)
|
| 101 |
+
-groundingdino (from IDEA Research)
|
| 102 |
+
-groundingdino weights (downloaded separately. EX:https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth)
|
| 103 |
+
-spacy-fastlang
|
| 104 |
+
-en-core-web-sm (for spacy-fastlang)
|
| 105 |
+
-ffmpeg (for video processing)
|
| 106 |
+
-(optional) laser weights/full model checkpoint (downloaded separately. EX: https://huggingface.co/video-fm/vine_v0)
|
| 107 |
+
|
| 108 |
+
Usually, by running the laser/environments/laser_env.yml from the LASER repo, most dependencies will be installed. You will need to manually install sam2 and groundingdino as per their instructions.
|
| 109 |
+
|
| 110 |
+
### Using the Pipeline (Recommended)
|
| 111 |
+
```python
|
| 112 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
| 113 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 114 |
+
|
| 115 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 116 |
+
"vine-video-understanding",
|
| 117 |
+
pipeline_class=VinePipeline,
|
| 118 |
+
pt_model=VineModel,
|
| 119 |
+
type="multimodal",
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
config = VineConfig(
|
| 123 |
+
segmentation_method="grounding_dino_sam2",
|
| 124 |
+
pretrained_vine_path="/abs/path/to/laser_model_v1.pkl",
|
| 125 |
+
visualization_dir="output",
|
| 126 |
+
visualize=True,
|
| 127 |
+
device="cuda:0",
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
model = VineModel(config)
|
| 131 |
+
|
| 132 |
+
vine_pipeline = VinePipeline(
|
| 133 |
+
model=model,
|
| 134 |
+
tokenizer=None,
|
| 135 |
+
sam_config_path="/abs/path/to/sam2/sam2.1_hiera_t.yaml",
|
| 136 |
+
sam_checkpoint_path="/abs/path/to/sam2/sam2_hiera_tiny.pt",
|
| 137 |
+
gd_config_path="/abs/path/to/groundingdino/config/GroundingDINO_SwinT_OGC.py",
|
| 138 |
+
gd_checkpoint_path="/abs/path/to/groundingdino/weights/groundingdino_swint_ogc.pth",
|
| 139 |
+
device=config._device,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
results = vine_pipeline(
|
| 143 |
+
"/path/to/video.mp4",
|
| 144 |
+
categorical_keywords=["dog", "human"],
|
| 145 |
+
unary_keywords=["running"],
|
| 146 |
+
binary_keywords=["chasing"],
|
| 147 |
+
object_pairs=[(0, 1)],
|
| 148 |
+
return_top_k=3,
|
| 149 |
+
include_visualizations=True,
|
| 150 |
+
)
|
| 151 |
+
print(results["summary"])
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
### Using the Model Directly (Advanced)
|
| 155 |
+
|
| 156 |
+
For advanced users who want to provide their own segmentation:
|
| 157 |
+
|
| 158 |
+
```python
|
| 159 |
+
from vine_hf import VineConfig, VineModel
|
| 160 |
+
import torch
|
| 161 |
+
|
| 162 |
+
# Create configuration
|
| 163 |
+
config = VineConfig(
|
| 164 |
+
pretrained_vine_path="/path/to/your/vine/weights" # Optional: your fine-tuned weights
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Initialize model
|
| 168 |
+
model = VineModel(config)
|
| 169 |
+
|
| 170 |
+
# If you have your own video frames, masks, and bboxes from external segmentation
|
| 171 |
+
video_frames = torch.randn(3, 224, 224, 3) * 255 # Your video frames
|
| 172 |
+
masks = {0: {1: torch.ones(224, 224, 1)}} # Your segmentation masks
|
| 173 |
+
bboxes = {0: {1: [50, 50, 150, 150]}} # Your bounding boxes
|
| 174 |
+
|
| 175 |
+
# Run prediction
|
| 176 |
+
results = model.predict(
|
| 177 |
+
video_frames=video_frames,
|
| 178 |
+
masks=masks,
|
| 179 |
+
bboxes=bboxes,
|
| 180 |
+
categorical_keywords=['human', 'dog', 'frisbee'],
|
| 181 |
+
unary_keywords=['running', 'jumping'],
|
| 182 |
+
binary_keywords=['chasing', 'following'],
|
| 183 |
+
object_pairs=[(1, 2)],
|
| 184 |
+
return_top_k=3
|
| 185 |
+
)
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
**Note**: For most users, the pipeline approach above is recommended as it handles video loading and segmentation automatically.
|
| 189 |
+
|
| 190 |
+
## Configuration Options
|
| 191 |
+
|
| 192 |
+
The `VineConfig` class supports the following parameters (non-exhaustive):
|
| 193 |
+
|
| 194 |
+
- `model_name`: CLIP model backbone (default: `"openai/clip-vit-large-patch14-336"`)
|
| 195 |
+
- `pretrained_vine_path`: Optional path or Hugging Face repo with pretrained VINE weights
|
| 196 |
+
- `segmentation_method`: `"sam2"` or `"grounding_dino_sam2"` (default: `"grounding_dino_sam2"`)
|
| 197 |
+
- `box_threshold` / `text_threshold`: Grounding DINO thresholds
|
| 198 |
+
- `target_fps`: Target FPS for video processing (default: `1`)
|
| 199 |
+
- `alpha`, `white_alpha`: Rendering parameters used when extracting masked crops
|
| 200 |
+
- `topk_cate`: Top-k categories to return per object (default: `3`)
|
| 201 |
+
- `max_video_length`: Maximum frames to process (default: `100`)
|
| 202 |
+
- `visualize`: When `True`, pipeline post-processing attempts to create stitched visualizations
|
| 203 |
+
- `visualization_dir`: Optional base directory where visualization assets are written
|
| 204 |
+
- `debug_visualizations`: When `True`, the model saves a single first-frame mask composite for quick inspection
|
| 205 |
+
- `debug_visualization_path`: Target filepath for the debug mask composite (must point to a writable file)
|
| 206 |
+
- `return_flattened_segments`, `return_valid_pairs`, `interested_object_pairs`: Advanced geometry outputs for downstream consumers
|
| 207 |
+
|
| 208 |
+
## Output Format
|
| 209 |
+
|
| 210 |
+
The model returns a dictionary with the following structure:
|
| 211 |
+
|
| 212 |
+
```python
|
| 213 |
+
{
|
| 214 |
+
"masks" : {},
|
| 215 |
+
|
| 216 |
+
"boxes" : {},
|
| 217 |
+
|
| 218 |
+
"categorical_predictions": {
|
| 219 |
+
object_id: [(probability, category), ...]
|
| 220 |
+
},
|
| 221 |
+
"unary_predictions": {
|
| 222 |
+
(frame_id, object_id): [(probability, action), ...]
|
| 223 |
+
},
|
| 224 |
+
"binary_predictions": {
|
| 225 |
+
(frame_id, (obj1_id, obj2_id)): [(probability, relation), ...]
|
| 226 |
+
},
|
| 227 |
+
"confidence_scores": {
|
| 228 |
+
"categorical": max_categorical_confidence,
|
| 229 |
+
"unary": max_unary_confidence,
|
| 230 |
+
"binary": max_binary_confidence
|
| 231 |
+
},
|
| 232 |
+
"summary": {
|
| 233 |
+
"num_objects_detected": int,
|
| 234 |
+
"top_categories": [(category, probability), ...],
|
| 235 |
+
"top_actions": [(action, probability), ...],
|
| 236 |
+
"top_relations": [(relation, probability), ...]
|
| 237 |
+
}
|
| 238 |
+
}
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
## Visualization & Debugging
|
| 242 |
+
|
| 243 |
+
There are two complementary visualization layers:
|
| 244 |
+
|
| 245 |
+
- **Post-process visualizations** (`include_visualizations=True` in the pipeline call) produces a high-level stitched video summarizing detections, actions, and relations over time.
|
| 246 |
+
|
| 247 |
+
- **Debug visualizations** (`debug_visualizations=True` in `VineConfig`) dumps videos of intermediate segmentation masks and outputs from GroundingDINO, SAM2, Unary, Binary, etc. for quick sanity checks.
|
| 248 |
+
|
| 249 |
+
If you plan to enable either option, ensure the relevant output directories exist before running the pipeline.
|
| 250 |
+
|
| 251 |
+
## Segmentation Methods
|
| 252 |
+
|
| 253 |
+
### Grounding DINO + SAM2 (Recommended)
|
| 254 |
+
|
| 255 |
+
Uses Grounding DINO for object detection based on text prompts, then SAM2 for precise segmentation.
|
| 256 |
+
|
| 257 |
+
Requirements:
|
| 258 |
+
- Grounding DINO model and weights
|
| 259 |
+
- SAM2 model and weights
|
| 260 |
+
- Properly configured paths to model checkpoints
|
| 261 |
+
|
| 262 |
+
### SAM2 Only
|
| 263 |
+
|
| 264 |
+
Uses SAM2's automatic mask generation without text-based object detection.
|
| 265 |
+
|
| 266 |
+
Requirements:
|
| 267 |
+
- SAM2 model and weights
|
| 268 |
+
|
| 269 |
+
## Model Architecture
|
| 270 |
+
|
| 271 |
+
VINE is built on top of CLIP and uses three separate CLIP models for different tasks:
|
| 272 |
+
- **Categorical Model**: For object classification
|
| 273 |
+
- **Unary Model**: For single-object action recognition
|
| 274 |
+
- **Binary Model**: For relationship detection between object pairs
|
| 275 |
+
|
| 276 |
+
Each model processes both visual and textual features to compute similarity scores and probability distributions.
|
| 277 |
+
|
| 278 |
+
## Pushing to HuggingFace Hub
|
| 279 |
+
|
| 280 |
+
```python
|
| 281 |
+
from vine_hf import VineConfig, VineModel
|
| 282 |
+
|
| 283 |
+
# Create and configure your model
|
| 284 |
+
config = VineConfig()
|
| 285 |
+
model = VineModel(config)
|
| 286 |
+
|
| 287 |
+
# Load your pretrained weights
|
| 288 |
+
# model.load_state_dict(torch.load('path/to/your/weights.pth'))
|
| 289 |
+
|
| 290 |
+
# Register for auto classes
|
| 291 |
+
config.register_for_auto_class()
|
| 292 |
+
model.register_for_auto_class("AutoModel")
|
| 293 |
+
|
| 294 |
+
# Push to Hub
|
| 295 |
+
config.push_to_hub('your-username/vine-model')
|
| 296 |
+
model.push_to_hub('your-username/vine-model')
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
## Loading from HuggingFace Hub
|
| 300 |
+
|
| 301 |
+
```python
|
| 302 |
+
from transformers import AutoModel, pipeline
|
| 303 |
+
|
| 304 |
+
# Load model
|
| 305 |
+
model = AutoModel.from_pretrained('your-username/vine-model', trust_remote_code=True)
|
| 306 |
+
|
| 307 |
+
# Or use with pipeline
|
| 308 |
+
vine_pipeline = pipeline(
|
| 309 |
+
'vine-video-understanding',
|
| 310 |
+
model='your-username/vine-model',
|
| 311 |
+
trust_remote_code=True
|
| 312 |
+
)
|
| 313 |
+
```
|
| 314 |
+
|
| 315 |
+
## Examples
|
| 316 |
+
|
| 317 |
+
See `example_usage.py` for comprehensive examples including:
|
| 318 |
+
- Direct model usage
|
| 319 |
+
- Pipeline usage
|
| 320 |
+
- HuggingFace Hub integration
|
| 321 |
+
- Real video processing
|
| 322 |
+
|
| 323 |
+
## Requirements
|
| 324 |
+
|
| 325 |
+
- Python 3.7+
|
| 326 |
+
- PyTorch 1.9+
|
| 327 |
+
- transformers 4.20+
|
| 328 |
+
- OpenCV
|
| 329 |
+
- PIL/Pillow
|
| 330 |
+
- NumPy
|
| 331 |
+
|
| 332 |
+
For segmentation:
|
| 333 |
+
- SAM2 (Facebook Research)
|
| 334 |
+
- Grounding DINO (IDEA Research)
|
| 335 |
+
|
| 336 |
+
## Citation
|
| 337 |
+
|
| 338 |
+
If you use VINE in your research, please cite:
|
| 339 |
+
|
| 340 |
+
```bibtex
|
| 341 |
+
@article{vine2024,
|
| 342 |
+
title={VINE: Video Understanding with Natural Language},
|
| 343 |
+
author={Your Authors},
|
| 344 |
+
journal={Your Journal},
|
| 345 |
+
year={2024}
|
| 346 |
+
}
|
| 347 |
+
```
|
| 348 |
+
|
| 349 |
+
## License
|
| 350 |
+
|
| 351 |
+
[Your License Here]
|
| 352 |
+
|
| 353 |
+
## Contact
|
| 354 |
+
|
| 355 |
+
[Your Contact Information Here]
|
vine_hf/README_HF.md
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VINE: Video Understanding with Natural Language
|
| 2 |
+
|
| 3 |
+
[](https://huggingface.co/video-fm/vine)
|
| 4 |
+
[](https://github.com/kevinxuez/LASER)
|
| 5 |
+
|
| 6 |
+
VINE is a video understanding model that processes videos along with categorical, unary, and binary keywords to return probability distributions over those keywords for detected objects and their relationships.
|
| 7 |
+
|
| 8 |
+
## Quick Start
|
| 9 |
+
|
| 10 |
+
```python
|
| 11 |
+
from transformers import AutoModel
|
| 12 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 13 |
+
|
| 14 |
+
# Load VINE model from HuggingFace
|
| 15 |
+
model = AutoModel.from_pretrained('video-fm/vine', trust_remote_code=True)
|
| 16 |
+
|
| 17 |
+
# Create pipeline with your checkpoint paths
|
| 18 |
+
vine_pipeline = VinePipeline(
|
| 19 |
+
model=model,
|
| 20 |
+
tokenizer=None,
|
| 21 |
+
sam_config_path="/path/to/sam2_config.yaml",
|
| 22 |
+
sam_checkpoint_path="/path/to/sam2_checkpoint.pt",
|
| 23 |
+
gd_config_path="/path/to/grounding_dino_config.py",
|
| 24 |
+
gd_checkpoint_path="/path/to/grounding_dino_checkpoint.pth",
|
| 25 |
+
device="cuda",
|
| 26 |
+
trust_remote_code=True
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# Process a video
|
| 30 |
+
results = vine_pipeline(
|
| 31 |
+
'path/to/video.mp4',
|
| 32 |
+
categorical_keywords=['human', 'dog', 'frisbee'],
|
| 33 |
+
unary_keywords=['running', 'jumping'],
|
| 34 |
+
binary_keywords=['chasing', 'behind'],
|
| 35 |
+
return_top_k=3
|
| 36 |
+
)
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Installation
|
| 40 |
+
|
| 41 |
+
### Option 1: Automated Setup (Recommended)
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
# Download the setup script
|
| 45 |
+
wget https://raw.githubusercontent.com/kevinxuez/vine_hf/main/setup_vine_demo.sh
|
| 46 |
+
|
| 47 |
+
# Run the setup
|
| 48 |
+
bash setup_vine_demo.sh
|
| 49 |
+
|
| 50 |
+
# Activate environment
|
| 51 |
+
conda activate vine_demo
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
### Option 2: Manual Installation
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
# 1. Create conda environment
|
| 58 |
+
conda create -n vine_demo python=3.10 -y
|
| 59 |
+
conda activate vine_demo
|
| 60 |
+
|
| 61 |
+
# 2. Install PyTorch with CUDA support
|
| 62 |
+
pip install torch==2.7.1 torchvision==0.22.1 --index-url https://download.pytorch.org/whl/cu126
|
| 63 |
+
|
| 64 |
+
# 3. Install core dependencies
|
| 65 |
+
pip install transformers huggingface-hub safetensors
|
| 66 |
+
|
| 67 |
+
# 4. Clone and install required repositories
|
| 68 |
+
git clone https://github.com/video-fm/video-sam2.git
|
| 69 |
+
git clone https://github.com/video-fm/GroundingDINO.git
|
| 70 |
+
git clone https://github.com/kevinxuez/LASER.git
|
| 71 |
+
git clone https://github.com/kevinxuez/vine_hf.git
|
| 72 |
+
|
| 73 |
+
# Install in editable mode
|
| 74 |
+
pip install -e ./video-sam2
|
| 75 |
+
pip install -e ./GroundingDINO
|
| 76 |
+
pip install -e ./LASER
|
| 77 |
+
pip install -e ./vine_hf
|
| 78 |
+
|
| 79 |
+
# Build GroundingDINO extensions
|
| 80 |
+
cd GroundingDINO && python setup.py build_ext --force --inplace && cd ..
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
## Required Checkpoints
|
| 84 |
+
|
| 85 |
+
VINE requires SAM2 and GroundingDINO checkpoints for segmentation. Download these separately:
|
| 86 |
+
|
| 87 |
+
### SAM2 Checkpoint
|
| 88 |
+
```bash
|
| 89 |
+
wget https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt
|
| 90 |
+
wget https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
### GroundingDINO Checkpoint
|
| 94 |
+
```bash
|
| 95 |
+
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
|
| 96 |
+
wget https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
## Architecture
|
| 100 |
+
|
| 101 |
+
```
|
| 102 |
+
video-fm/vine (HuggingFace Hub)
|
| 103 |
+
├── VINE Model Weights (~1.8GB)
|
| 104 |
+
│ ├── Categorical CLIP model (fine-tuned)
|
| 105 |
+
│ ├── Unary CLIP model (fine-tuned)
|
| 106 |
+
│ └── Binary CLIP model (fine-tuned)
|
| 107 |
+
└── Architecture Files
|
| 108 |
+
├── vine_config.py
|
| 109 |
+
├── vine_model.py
|
| 110 |
+
├── vine_pipeline.py
|
| 111 |
+
└── utilities
|
| 112 |
+
|
| 113 |
+
User Provides:
|
| 114 |
+
├── Dependencies (via pip/conda)
|
| 115 |
+
│ ├── laser (video processing utilities)
|
| 116 |
+
│ ├── sam2 (segmentation)
|
| 117 |
+
│ └── groundingdino (object detection)
|
| 118 |
+
└── Checkpoints (downloaded separately)
|
| 119 |
+
├── SAM2 model files
|
| 120 |
+
└── GroundingDINO model files
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
## Why This Architecture?
|
| 124 |
+
|
| 125 |
+
This separation of concerns provides several benefits:
|
| 126 |
+
|
| 127 |
+
1. **Lightweight Distribution**: Only VINE-specific weights (~1.8GB) are on HuggingFace
|
| 128 |
+
2. **Version Control**: Users can choose their preferred SAM2/GroundingDINO versions
|
| 129 |
+
3. **Licensing**: Keeps different model licenses separate
|
| 130 |
+
4. **Flexibility**: Easy to swap segmentation backends
|
| 131 |
+
5. **Standard Practice**: Similar to models like LLaVA, BLIP-2, etc.
|
| 132 |
+
|
| 133 |
+
## Full Usage Example
|
| 134 |
+
|
| 135 |
+
```python
|
| 136 |
+
import os
|
| 137 |
+
from pathlib import Path
|
| 138 |
+
from transformers import AutoModel
|
| 139 |
+
from vine_hf import VinePipeline
|
| 140 |
+
|
| 141 |
+
# Set up paths
|
| 142 |
+
checkpoint_dir = Path("/path/to/checkpoints")
|
| 143 |
+
sam_config = checkpoint_dir / "sam2_hiera_t.yaml"
|
| 144 |
+
sam_checkpoint = checkpoint_dir / "sam2_hiera_tiny.pt"
|
| 145 |
+
gd_config = checkpoint_dir / "GroundingDINO_SwinT_OGC.py"
|
| 146 |
+
gd_checkpoint = checkpoint_dir / "groundingdino_swint_ogc.pth"
|
| 147 |
+
|
| 148 |
+
# Load VINE from HuggingFace
|
| 149 |
+
model = AutoModel.from_pretrained('video-fm/vine', trust_remote_code=True)
|
| 150 |
+
|
| 151 |
+
# Create pipeline
|
| 152 |
+
vine_pipeline = VinePipeline(
|
| 153 |
+
model=model,
|
| 154 |
+
tokenizer=None,
|
| 155 |
+
sam_config_path=str(sam_config),
|
| 156 |
+
sam_checkpoint_path=str(sam_checkpoint),
|
| 157 |
+
gd_config_path=str(gd_config),
|
| 158 |
+
gd_checkpoint_path=str(gd_checkpoint),
|
| 159 |
+
device="cuda:0",
|
| 160 |
+
trust_remote_code=True
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# Process video
|
| 164 |
+
results = vine_pipeline(
|
| 165 |
+
"path/to/video.mp4",
|
| 166 |
+
categorical_keywords=['person', 'dog', 'ball'],
|
| 167 |
+
unary_keywords=['running', 'jumping', 'sitting'],
|
| 168 |
+
binary_keywords=['chasing', 'next to', 'holding'],
|
| 169 |
+
object_pairs=[(0, 1), (0, 2)], # person-dog, person-ball
|
| 170 |
+
return_top_k=5,
|
| 171 |
+
include_visualizations=True
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# Access results
|
| 175 |
+
print(f"Detected {results['summary']['num_objects_detected']} objects")
|
| 176 |
+
print(f"Top categories: {results['summary']['top_categories']}")
|
| 177 |
+
print(f"Top actions: {results['summary']['top_actions']}")
|
| 178 |
+
print(f"Top relations: {results['summary']['top_relations']}")
|
| 179 |
+
|
| 180 |
+
# Access detailed predictions
|
| 181 |
+
for obj_id, predictions in results['categorical_predictions'].items():
|
| 182 |
+
print(f"\nObject {obj_id}:")
|
| 183 |
+
for prob, category in predictions:
|
| 184 |
+
print(f" {category}: {prob:.3f}")
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
## Output Format
|
| 188 |
+
|
| 189 |
+
```python
|
| 190 |
+
{
|
| 191 |
+
"categorical_predictions": {
|
| 192 |
+
object_id: [(probability, category), ...]
|
| 193 |
+
},
|
| 194 |
+
"unary_predictions": {
|
| 195 |
+
(frame_id, object_id): [(probability, action), ...]
|
| 196 |
+
},
|
| 197 |
+
"binary_predictions": {
|
| 198 |
+
(frame_id, (obj1_id, obj2_id)): [(probability, relation), ...]
|
| 199 |
+
},
|
| 200 |
+
"confidence_scores": {
|
| 201 |
+
"categorical": float,
|
| 202 |
+
"unary": float,
|
| 203 |
+
"binary": float
|
| 204 |
+
},
|
| 205 |
+
"summary": {
|
| 206 |
+
"num_objects_detected": int,
|
| 207 |
+
"top_categories": [(category, probability), ...],
|
| 208 |
+
"top_actions": [(action, probability), ...],
|
| 209 |
+
"top_relations": [(relation, probability), ...]
|
| 210 |
+
},
|
| 211 |
+
"visualizations": { # if include_visualizations=True
|
| 212 |
+
"vine": {
|
| 213 |
+
"all": {"frames": [...], "video_path": "..."},
|
| 214 |
+
...
|
| 215 |
+
}
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
## Configuration Options
|
| 221 |
+
|
| 222 |
+
```python
|
| 223 |
+
from vine_hf import VineConfig
|
| 224 |
+
|
| 225 |
+
config = VineConfig(
|
| 226 |
+
model_name="openai/clip-vit-base-patch32", # CLIP backbone
|
| 227 |
+
segmentation_method="grounding_dino_sam2", # or "sam2"
|
| 228 |
+
box_threshold=0.35, # GroundingDINO threshold
|
| 229 |
+
text_threshold=0.25, # GroundingDINO threshold
|
| 230 |
+
target_fps=5, # Video sampling rate
|
| 231 |
+
visualize=True, # Enable visualizations
|
| 232 |
+
visualization_dir="outputs/", # Output directory
|
| 233 |
+
debug_visualizations=False, # Debug mode
|
| 234 |
+
device="cuda:0" # Device
|
| 235 |
+
)
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
## Deployment Examples
|
| 239 |
+
|
| 240 |
+
### Local Script
|
| 241 |
+
```python
|
| 242 |
+
# test_vine.py
|
| 243 |
+
from transformers import AutoModel
|
| 244 |
+
from vine_hf import VinePipeline
|
| 245 |
+
|
| 246 |
+
model = AutoModel.from_pretrained('video-fm/vine', trust_remote_code=True)
|
| 247 |
+
pipeline = VinePipeline(model=model, ...)
|
| 248 |
+
results = pipeline("video.mp4", ...)
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
### HuggingFace Spaces
|
| 252 |
+
```python
|
| 253 |
+
# app.py for Gradio Space
|
| 254 |
+
import gradio as gr
|
| 255 |
+
from transformers import AutoModel
|
| 256 |
+
from vine_hf import VinePipeline
|
| 257 |
+
|
| 258 |
+
model = AutoModel.from_pretrained('video-fm/vine', trust_remote_code=True)
|
| 259 |
+
# ... set up pipeline and Gradio interface
|
| 260 |
+
```
|
| 261 |
+
|
| 262 |
+
### API Server
|
| 263 |
+
```python
|
| 264 |
+
# FastAPI server
|
| 265 |
+
from fastapi import FastAPI
|
| 266 |
+
from transformers import AutoModel
|
| 267 |
+
from vine_hf import VinePipeline
|
| 268 |
+
|
| 269 |
+
app = FastAPI()
|
| 270 |
+
model = AutoModel.from_pretrained('video-fm/vine', trust_remote_code=True)
|
| 271 |
+
pipeline = VinePipeline(model=model, ...)
|
| 272 |
+
|
| 273 |
+
@app.post("/process")
|
| 274 |
+
async def process_video(video_path: str):
|
| 275 |
+
return pipeline(video_path, ...)
|
| 276 |
+
```
|
| 277 |
+
|
| 278 |
+
## Troubleshooting
|
| 279 |
+
|
| 280 |
+
### Import Errors
|
| 281 |
+
```bash
|
| 282 |
+
# Make sure all dependencies are installed
|
| 283 |
+
pip list | grep -E "laser|sam2|groundingdino"
|
| 284 |
+
|
| 285 |
+
# Reinstall if needed
|
| 286 |
+
pip install -e ./LASER
|
| 287 |
+
pip install -e ./video-sam2
|
| 288 |
+
pip install -e ./GroundingDINO
|
| 289 |
+
```
|
| 290 |
+
|
| 291 |
+
### CUDA Errors
|
| 292 |
+
```python
|
| 293 |
+
# Check CUDA availability
|
| 294 |
+
import torch
|
| 295 |
+
print(torch.cuda.is_available())
|
| 296 |
+
print(torch.version.cuda)
|
| 297 |
+
|
| 298 |
+
# Use CPU if needed
|
| 299 |
+
pipeline = VinePipeline(model=model, device="cpu", ...)
|
| 300 |
+
```
|
| 301 |
+
|
| 302 |
+
### Checkpoint Not Found
|
| 303 |
+
```bash
|
| 304 |
+
# Verify checkpoint paths
|
| 305 |
+
ls -lh /path/to/sam2_hiera_tiny.pt
|
| 306 |
+
ls -lh /path/to/groundingdino_swint_ogc.pth
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
## System Requirements
|
| 310 |
+
|
| 311 |
+
- **Python**: 3.10+
|
| 312 |
+
- **CUDA**: 11.8+ (for GPU)
|
| 313 |
+
- **GPU**: 8GB+ VRAM recommended (T4, V100, A100, etc.)
|
| 314 |
+
- **RAM**: 16GB+ recommended
|
| 315 |
+
- **Storage**: ~3GB for checkpoints
|
| 316 |
+
|
| 317 |
+
## Citation
|
| 318 |
+
|
| 319 |
+
```bibtex
|
| 320 |
+
@article{laser2024,
|
| 321 |
+
title={LASER: Language-guided Object Grounding and Relation Understanding in Videos},
|
| 322 |
+
author={Your Authors},
|
| 323 |
+
journal={Your Conference/Journal},
|
| 324 |
+
year={2024}
|
| 325 |
+
}
|
| 326 |
+
```
|
| 327 |
+
|
| 328 |
+
## License
|
| 329 |
+
|
| 330 |
+
This model and code are released under the MIT License. Note that SAM2 and GroundingDINO have their own respective licenses.
|
| 331 |
+
|
| 332 |
+
## Links
|
| 333 |
+
|
| 334 |
+
- **Model**: https://huggingface.co/video-fm/vine
|
| 335 |
+
- **Code**: https://github.com/kevinxuez/LASER
|
| 336 |
+
- **vine_hf Package**: https://github.com/kevinxuez/vine_hf
|
| 337 |
+
- **SAM2**: https://github.com/facebookresearch/sam2
|
| 338 |
+
- **GroundingDINO**: https://github.com/IDEA-Research/GroundingDINO
|
| 339 |
+
|
| 340 |
+
## Support
|
| 341 |
+
|
| 342 |
+
For issues or questions:
|
| 343 |
+
- **Model/Architecture**: [HuggingFace Discussions](https://huggingface.co/video-fm/vine/discussions)
|
| 344 |
+
- **LASER Framework**: [GitHub Issues](https://github.com/kevinxuez/LASER/issues)
|
| 345 |
+
- **vine_hf Package**: [GitHub Issues](https://github.com/kevinxuez/vine_hf/issues)
|
vine_hf/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VINE HuggingFace Interface
|
| 3 |
+
|
| 4 |
+
VINE (Video Understanding with Natural Language) is a model that processes videos
|
| 5 |
+
along with categorical, unary, and binary keywords to return probability
|
| 6 |
+
distributions over those keywords for detected objects and their relationships.
|
| 7 |
+
|
| 8 |
+
This package provides a HuggingFace-compatible interface for the VINE model,
|
| 9 |
+
including configuration, model, and pipeline classes.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from .vine_config import VineConfig
|
| 13 |
+
from .vine_model import VineModel
|
| 14 |
+
from .vine_pipeline import VinePipeline
|
| 15 |
+
|
| 16 |
+
__version__ = "1.0.0"
|
| 17 |
+
__author__ = "LASER Team"
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"VineConfig",
|
| 21 |
+
"VineModel",
|
| 22 |
+
"VinePipeline"
|
| 23 |
+
]
|
vine_hf/convert_inference.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to convert existing inference.py workflow to use VINE HuggingFace interface
|
| 3 |
+
|
| 4 |
+
This script demonstrates how to migrate from the original inference.py approach
|
| 5 |
+
to the new HuggingFace-compatible interface.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from typing import Dict, List, Tuple, Any
|
| 13 |
+
|
| 14 |
+
# Add paths for imports
|
| 15 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 16 |
+
|
| 17 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 18 |
+
from laser.loading import load_video
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_pretrained_vine_model(model_dir: str, model_name: str, epoch: int = 0) -> VineModel:
|
| 22 |
+
"""
|
| 23 |
+
Load a pretrained VINE model from the original format into HuggingFace format.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
model_dir: Directory containing the model
|
| 27 |
+
model_name: Name of the model file (without .{epoch}.model extension)
|
| 28 |
+
epoch: Epoch number to load
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
VineModel instance with loaded weights
|
| 32 |
+
"""
|
| 33 |
+
print(f"Loading pretrained VINE model from {model_dir}")
|
| 34 |
+
|
| 35 |
+
# Create configuration (adjust parameters as needed)
|
| 36 |
+
# We expect local ensemble weights in `model_dir`, so configure
|
| 37 |
+
# VineConfig to load from local directory/filename.
|
| 38 |
+
model_file = f"{model_name}.{epoch}.model"
|
| 39 |
+
config = VineConfig(
|
| 40 |
+
model_name="openai/clip-vit-base-patch32",
|
| 41 |
+
segmentation_method="grounding_dino_sam2",
|
| 42 |
+
target_fps=1,
|
| 43 |
+
box_threshold=0.35,
|
| 44 |
+
text_threshold=0.25,
|
| 45 |
+
use_hf_repo=False,
|
| 46 |
+
local_dir=model_dir,
|
| 47 |
+
local_filename=model_file,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Initialize model (VineModel will consult the config when loading)
|
| 51 |
+
vine_model = VineModel(config)
|
| 52 |
+
|
| 53 |
+
# Load original weights
|
| 54 |
+
model_file = f"{model_name}.{epoch}.model"
|
| 55 |
+
model_path = os.path.join(model_dir, model_file)
|
| 56 |
+
|
| 57 |
+
if os.path.exists(model_path):
|
| 58 |
+
print(f"Loading weights from: {model_path}")
|
| 59 |
+
try:
|
| 60 |
+
# Add safe globals for PyTorch 2.6+
|
| 61 |
+
import torch.serialization
|
| 62 |
+
from laser.models.llava_clip_model_v3 import PredicateModel
|
| 63 |
+
torch.serialization.add_safe_globals([PredicateModel])
|
| 64 |
+
|
| 65 |
+
# Load the original model
|
| 66 |
+
original_model = torch.load(model_path, map_location='cpu', weights_only=False)
|
| 67 |
+
|
| 68 |
+
# Transfer weights to HuggingFace model
|
| 69 |
+
# This assumes the original model has the same structure
|
| 70 |
+
# You may need to adjust this based on your specific model structure
|
| 71 |
+
|
| 72 |
+
if hasattr(original_model, 'clip_cate_model'):
|
| 73 |
+
vine_model.clip_cate_model.load_state_dict(original_model.clip_cate_model.state_dict())
|
| 74 |
+
if hasattr(original_model, 'clip_unary_model'):
|
| 75 |
+
vine_model.clip_unary_model.load_state_dict(original_model.clip_unary_model.state_dict())
|
| 76 |
+
if hasattr(original_model, 'clip_binary_model'):
|
| 77 |
+
vine_model.clip_binary_model.load_state_dict(original_model.clip_binary_model.state_dict())
|
| 78 |
+
if hasattr(original_model, 'clip_tokenizer'):
|
| 79 |
+
vine_model.clip_tokenizer = original_model.clip_tokenizer
|
| 80 |
+
if hasattr(original_model, 'clip_processor'):
|
| 81 |
+
vine_model.clip_processor = original_model.clip_processor
|
| 82 |
+
|
| 83 |
+
print("✓ Weights transferred successfully")
|
| 84 |
+
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"✗ Error loading weights: {e}")
|
| 87 |
+
print("You may need to adjust the weight loading logic for your specific model")
|
| 88 |
+
|
| 89 |
+
else:
|
| 90 |
+
print(f"✗ Model file not found: {model_path}")
|
| 91 |
+
|
| 92 |
+
return vine_model
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def convert_inference_workflow():
|
| 96 |
+
"""
|
| 97 |
+
Convert the original inference.py workflow to use HuggingFace interface.
|
| 98 |
+
|
| 99 |
+
This function demonstrates how to replicate the original inference workflow
|
| 100 |
+
using the new HuggingFace-compatible components.
|
| 101 |
+
"""
|
| 102 |
+
print("=== Converting Inference Workflow ===")
|
| 103 |
+
|
| 104 |
+
# Original parameters from inference.py
|
| 105 |
+
video_id = 'v1'
|
| 106 |
+
target_fps = 1
|
| 107 |
+
classes = ['human', 'dog', 'frisbee']
|
| 108 |
+
unary_keywords = ['running', 'jumping', 'sitting', 'standing']
|
| 109 |
+
binary_keywords = ['behind', 'bite', 'front', 'jump over', 'right', 'left']
|
| 110 |
+
|
| 111 |
+
# Paths (adjust these to match your setup)
|
| 112 |
+
demo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../demo"))
|
| 113 |
+
video_dir = os.path.join(demo_dir, "videos")
|
| 114 |
+
video_path = os.path.join(video_dir, f"{video_id}.mp4")
|
| 115 |
+
|
| 116 |
+
# Model paths (adjust these to match your setup)
|
| 117 |
+
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
|
| 118 |
+
model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
|
| 119 |
+
model_name = "ensemble-2025-02-10-14-57-22"
|
| 120 |
+
|
| 121 |
+
# Segmentation model paths (adjust these to your actual paths)
|
| 122 |
+
sam_config_path = "/path/to/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml"
|
| 123 |
+
sam_checkpoint_path = "/path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
|
| 124 |
+
gd_config_path = "/path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py"
|
| 125 |
+
gd_checkpoint_path = "/path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth"
|
| 126 |
+
|
| 127 |
+
print(f"Video path: {video_path}")
|
| 128 |
+
print(f"Model dir: {model_dir}")
|
| 129 |
+
print(f"SAM2 config: {sam_config_path}")
|
| 130 |
+
print(f"GroundingDINO config: {gd_config_path}")
|
| 131 |
+
|
| 132 |
+
# Check if video exists
|
| 133 |
+
if not os.path.exists(video_path):
|
| 134 |
+
print(f"✗ Video not found: {video_path}")
|
| 135 |
+
print("Please adjust the video path or use your own video file")
|
| 136 |
+
return
|
| 137 |
+
|
| 138 |
+
# 1. Load video (same as original)
|
| 139 |
+
print(f"Loading video: {video_id}")
|
| 140 |
+
video_tensor = load_video(video_path, target_fps=target_fps)
|
| 141 |
+
print(f"Video shape: {video_tensor.shape}")
|
| 142 |
+
|
| 143 |
+
# 2. Load VINE model with HuggingFace interface
|
| 144 |
+
print("Loading VINE model...")
|
| 145 |
+
if os.path.exists(model_dir):
|
| 146 |
+
vine_model = load_pretrained_vine_model(model_dir, model_name, epoch=0)
|
| 147 |
+
else:
|
| 148 |
+
print(f"Model directory not found: {model_dir}")
|
| 149 |
+
print("Creating new model with random weights for demonstration")
|
| 150 |
+
config = VineConfig()
|
| 151 |
+
vine_model = VineModel(config)
|
| 152 |
+
|
| 153 |
+
# 3. Create pipeline for easier use
|
| 154 |
+
print("Creating VINE pipeline...")
|
| 155 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
| 156 |
+
|
| 157 |
+
# Register pipeline if not already registered
|
| 158 |
+
try:
|
| 159 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 160 |
+
"vine-video-understanding",
|
| 161 |
+
pipeline_class=VinePipeline,
|
| 162 |
+
pt_model=VineModel,
|
| 163 |
+
type="multimodal",
|
| 164 |
+
)
|
| 165 |
+
except Exception:
|
| 166 |
+
pass # Already registered
|
| 167 |
+
|
| 168 |
+
# Create pipeline instance with segmentation model paths
|
| 169 |
+
vine_pipeline = VinePipeline(
|
| 170 |
+
model=vine_model,
|
| 171 |
+
tokenizer=None,
|
| 172 |
+
# SAM2 configuration
|
| 173 |
+
sam_config_path=sam_config_path,
|
| 174 |
+
sam_checkpoint_path=sam_checkpoint_path,
|
| 175 |
+
# GroundingDINO configuration
|
| 176 |
+
gd_config_path=gd_config_path,
|
| 177 |
+
gd_checkpoint_path=gd_checkpoint_path
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# 4. Process video with new interface
|
| 181 |
+
print("Processing video with VINE HuggingFace interface...")
|
| 182 |
+
|
| 183 |
+
try:
|
| 184 |
+
# Use the pipeline to process the video
|
| 185 |
+
results = vine_pipeline(
|
| 186 |
+
video_path,
|
| 187 |
+
categorical_keywords=classes,
|
| 188 |
+
unary_keywords=unary_keywords,
|
| 189 |
+
binary_keywords=binary_keywords,
|
| 190 |
+
object_pairs=[(1, 2), (2, 3)], # Example object pairs
|
| 191 |
+
segmentation_method='grounding_dino_sam2',
|
| 192 |
+
target_fps=target_fps,
|
| 193 |
+
return_top_k=3,
|
| 194 |
+
include_visualizations=False
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# 5. Display results (similar to original format)
|
| 198 |
+
print("\n=== VINE Results (HuggingFace Interface) ===")
|
| 199 |
+
|
| 200 |
+
# Categorical predictions
|
| 201 |
+
print("\nCategorical Predictions:")
|
| 202 |
+
for obj_id, predictions in results['categorical_predictions'].items():
|
| 203 |
+
print(f" Object {obj_id}:")
|
| 204 |
+
for prob, category in predictions:
|
| 205 |
+
print(f" {prob:.3f}: {category}")
|
| 206 |
+
|
| 207 |
+
# Unary predictions
|
| 208 |
+
print("\nUnary Predictions:")
|
| 209 |
+
for (frame_id, obj_id), predictions in results['unary_predictions'].items():
|
| 210 |
+
print(f" Frame {frame_id}, Object {obj_id}:")
|
| 211 |
+
for prob, action in predictions:
|
| 212 |
+
print(f" {prob:.3f}: {action}")
|
| 213 |
+
|
| 214 |
+
# Binary predictions
|
| 215 |
+
print("\nBinary Predictions:")
|
| 216 |
+
for (frame_id, obj_pair), predictions in results['binary_predictions'].items():
|
| 217 |
+
print(f" Frame {frame_id}, Objects {obj_pair}:")
|
| 218 |
+
for prob, relation in predictions:
|
| 219 |
+
print(f" {prob:.3f}: {relation}")
|
| 220 |
+
|
| 221 |
+
# Summary
|
| 222 |
+
print(f"\nSummary:")
|
| 223 |
+
print(f" Objects detected: {results['summary']['num_objects_detected']}")
|
| 224 |
+
print(f" Top categories: {results['summary']['top_categories']}")
|
| 225 |
+
print(f" Top actions: {results['summary']['top_actions']}")
|
| 226 |
+
print(f" Top relations: {results['summary']['top_relations']}")
|
| 227 |
+
|
| 228 |
+
print("\n✓ Successfully processed video with VINE HuggingFace interface!")
|
| 229 |
+
|
| 230 |
+
except Exception as e:
|
| 231 |
+
print(f"✗ Error processing video: {e}")
|
| 232 |
+
print("This may be due to missing segmentation models or other dependencies")
|
| 233 |
+
print("The interface is set up correctly, but full functionality requires:")
|
| 234 |
+
print(" 1. Properly installed Grounding DINO and SAM2")
|
| 235 |
+
print(" 2. Correct model weights")
|
| 236 |
+
print(" 3. Proper configuration paths")
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def compare_interfaces():
|
| 240 |
+
"""
|
| 241 |
+
Compare the original inference.py approach with the new HuggingFace interface.
|
| 242 |
+
"""
|
| 243 |
+
print("\n=== Interface Comparison ===")
|
| 244 |
+
|
| 245 |
+
print("\nOriginal inference.py approach:")
|
| 246 |
+
print("✓ Direct access to model internals")
|
| 247 |
+
print("✓ Full control over segmentation pipeline")
|
| 248 |
+
print("✗ Complex setup and configuration")
|
| 249 |
+
print("✗ Not compatible with HuggingFace ecosystem")
|
| 250 |
+
print("✗ Requires manual handling of all components")
|
| 251 |
+
|
| 252 |
+
print("\nNew HuggingFace interface:")
|
| 253 |
+
print("✓ Easy to use pipeline interface")
|
| 254 |
+
print("✓ Compatible with HuggingFace Hub")
|
| 255 |
+
print("✓ Standardized configuration")
|
| 256 |
+
print("✓ Automatic handling of preprocessing/postprocessing")
|
| 257 |
+
print("✓ Easy sharing and distribution")
|
| 258 |
+
print("✓ Configurable segmentation model paths")
|
| 259 |
+
print("✗ Slightly less direct control (can still access model directly)")
|
| 260 |
+
|
| 261 |
+
print("\nMigration benefits:")
|
| 262 |
+
print("• Share your model easily on HuggingFace Hub")
|
| 263 |
+
print("• Users can load your model with a single line")
|
| 264 |
+
print("• Standardized interface for video understanding")
|
| 265 |
+
print("• Better integration with other HuggingFace tools")
|
| 266 |
+
print("• Simplified deployment and inference")
|
| 267 |
+
print("• Flexible segmentation model configuration")
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
if __name__ == "__main__":
|
| 271 |
+
print("VINE HuggingFace Interface Conversion")
|
| 272 |
+
print("=" * 50)
|
| 273 |
+
|
| 274 |
+
# Run conversion demonstration
|
| 275 |
+
convert_inference_workflow()
|
| 276 |
+
|
| 277 |
+
# Show comparison
|
| 278 |
+
compare_interfaces()
|
| 279 |
+
|
| 280 |
+
print("\n" + "=" * 50)
|
| 281 |
+
print("Next steps:")
|
| 282 |
+
print("1. Install SAM2 and GroundingDINO dependencies")
|
| 283 |
+
print("2. Download the required model checkpoints")
|
| 284 |
+
print("3. Update the paths in this script to point to your models")
|
| 285 |
+
print("4. Test the interface with your specific model weights")
|
| 286 |
+
print("5. Adjust configuration parameters as needed")
|
| 287 |
+
print("6. Push your model to HuggingFace Hub using push_to_hub.py")
|
| 288 |
+
print("7. Share with the community!")
|
vine_hf/example_ensemble_weights.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example demonstrating how to load and use VINE ensemble weights
|
| 3 |
+
|
| 4 |
+
This script shows the correct way to load your pretrained VINE ensemble weights
|
| 5 |
+
and use them with the HuggingFace interface, based on the actual inference.py workflow.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
| 13 |
+
|
| 14 |
+
#os.environ["OPENAI_API_KEY"]="dummy-key" # Set your OpenAI API key here or via environment variable
|
| 15 |
+
|
| 16 |
+
# Add the parent directory to the path to import vine_hf
|
| 17 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 18 |
+
|
| 19 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 20 |
+
from laser.loading import load_video
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def example_load_ensemble_weights():
|
| 24 |
+
"""Example of loading ensemble weights correctly."""
|
| 25 |
+
print("=== Loading Ensemble VINE Weights ===")
|
| 26 |
+
|
| 27 |
+
# Path to your ensemble model (adjust this to your actual path)
|
| 28 |
+
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
|
| 29 |
+
model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
|
| 30 |
+
|
| 31 |
+
print(f"Looking for ensemble weights in: {model_dir}")
|
| 32 |
+
|
| 33 |
+
if os.path.exists(model_dir):
|
| 34 |
+
print("✓ Model directory found")
|
| 35 |
+
|
| 36 |
+
# List available model files
|
| 37 |
+
model_files = [f for f in os.listdir(model_dir) if f.endswith('.model')]
|
| 38 |
+
print(f"Available model files: {model_files}")
|
| 39 |
+
|
| 40 |
+
if model_files:
|
| 41 |
+
# Create configuration with ensemble path (local directory with .model files)
|
| 42 |
+
config = VineConfig(
|
| 43 |
+
segmentation_method="grounding_dino_sam2",
|
| 44 |
+
use_hf_repo=False,
|
| 45 |
+
local_dir=model_dir,
|
| 46 |
+
local_filename=None,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
print("Creating VINE model with ensemble weights...")
|
| 50 |
+
vine_model = VineModel(config)
|
| 51 |
+
|
| 52 |
+
print("✓ VINE model created with ensemble weights!")
|
| 53 |
+
return vine_model
|
| 54 |
+
else:
|
| 55 |
+
print("✗ No .model files found in directory")
|
| 56 |
+
return None
|
| 57 |
+
else:
|
| 58 |
+
print(f"✗ Model directory not found: {model_dir}")
|
| 59 |
+
print("Please adjust the path to point to your ensemble weights")
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def example_direct_ensemble_loading():
|
| 64 |
+
"""Example of loading ensemble weights using from_pretrained_vine."""
|
| 65 |
+
print("\n=== Direct Ensemble Loading ===")
|
| 66 |
+
|
| 67 |
+
# Path to specific ensemble file
|
| 68 |
+
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
|
| 69 |
+
model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
|
| 70 |
+
|
| 71 |
+
if os.path.exists(model_dir):
|
| 72 |
+
try:
|
| 73 |
+
# Use the class method for direct loading
|
| 74 |
+
vine_model = VineModel.from_pretrained_vine(
|
| 75 |
+
model_path=model_dir,
|
| 76 |
+
epoch=0 # Load epoch 0
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
print("✓ Model loaded using from_pretrained_vine!")
|
| 80 |
+
return vine_model
|
| 81 |
+
|
| 82 |
+
except Exception as e:
|
| 83 |
+
print(f"✗ Error loading with from_pretrained_vine: {e}")
|
| 84 |
+
return None
|
| 85 |
+
else:
|
| 86 |
+
print(f"✗ Model directory not found: {model_dir}")
|
| 87 |
+
return None
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def example_compare_original_vs_hf():
|
| 91 |
+
"""Compare the original inference.py approach with HuggingFace interface."""
|
| 92 |
+
print("\n=== Comparing Original vs HuggingFace Interface ===")
|
| 93 |
+
|
| 94 |
+
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
|
| 95 |
+
model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
|
| 96 |
+
model_name = "ensemble-2025-02-10-14-57-22"
|
| 97 |
+
epoch = 0
|
| 98 |
+
|
| 99 |
+
if not os.path.exists(model_dir):
|
| 100 |
+
print(f"Model directory not found: {model_dir}")
|
| 101 |
+
return
|
| 102 |
+
|
| 103 |
+
print("Original approach (from inference.py):")
|
| 104 |
+
print("```python")
|
| 105 |
+
print("def load_model(model_dir, model_name, epoch, device):")
|
| 106 |
+
print(" model_name = model_name + f'.{epoch}.model'")
|
| 107 |
+
print(" predicate_model = torch.load(os.path.join(model_dir, model_name), map_location=device, weights_only=False)")
|
| 108 |
+
print(" return predicate_model")
|
| 109 |
+
print("")
|
| 110 |
+
print("predicate_model = load_model(model_dir, model_name, epoch, device)")
|
| 111 |
+
print("```")
|
| 112 |
+
|
| 113 |
+
print("\nNew HuggingFace approach:")
|
| 114 |
+
print("```python")
|
| 115 |
+
print("config = VineConfig(pretrained_vine_path=model_dir)")
|
| 116 |
+
print("vine_model = VineModel(config)")
|
| 117 |
+
print("# or")
|
| 118 |
+
print("vine_model = VineModel.from_pretrained_vine(model_dir, epoch=0)")
|
| 119 |
+
print("```")
|
| 120 |
+
|
| 121 |
+
# Try to load with both approaches if possible
|
| 122 |
+
try:
|
| 123 |
+
# Original approach
|
| 124 |
+
def load_model(model_dir, model_name, epoch, device):
|
| 125 |
+
model_name = model_name + f'.{epoch}.model'
|
| 126 |
+
model_path = os.path.join(model_dir, model_name)
|
| 127 |
+
if os.path.exists(model_path):
|
| 128 |
+
return torch.load(model_path, map_location=device, weights_only=False)
|
| 129 |
+
else:
|
| 130 |
+
print(f"Model file not found: {model_path}")
|
| 131 |
+
return None
|
| 132 |
+
|
| 133 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 134 |
+
original_model = load_model(model_dir, model_name, epoch, device)
|
| 135 |
+
|
| 136 |
+
if original_model:
|
| 137 |
+
print(f"✓ Original model loaded: {type(original_model)}")
|
| 138 |
+
print(f" Has clip_cate_model: {hasattr(original_model, 'clip_cate_model')}")
|
| 139 |
+
print(f" Has clip_unary_model: {hasattr(original_model, 'clip_unary_model')}")
|
| 140 |
+
print(f" Has clip_binary_model: {hasattr(original_model, 'clip_binary_model')}")
|
| 141 |
+
|
| 142 |
+
# HuggingFace approach
|
| 143 |
+
vine_model = VineModel.from_pretrained_vine(model_dir, epoch=epoch)
|
| 144 |
+
|
| 145 |
+
if vine_model:
|
| 146 |
+
print(f"✓ HuggingFace model loaded: {type(vine_model)}")
|
| 147 |
+
print(f" Has clip_cate_model: {hasattr(vine_model, 'clip_cate_model')}")
|
| 148 |
+
print(f" Has clip_unary_model: {hasattr(vine_model, 'clip_unary_model')}")
|
| 149 |
+
print(f" Has clip_binary_model: {hasattr(vine_model, 'clip_binary_model')}")
|
| 150 |
+
|
| 151 |
+
print("\n✓ Both approaches work! HuggingFace interface successfully loads ensemble weights.")
|
| 152 |
+
|
| 153 |
+
except Exception as e:
|
| 154 |
+
print(f"Error in comparison: {e}")
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def example_ensemble_with_pipeline():
|
| 158 |
+
"""Example using ensemble weights with the pipeline."""
|
| 159 |
+
print("\n=== Using Ensemble Weights with Pipeline ===")
|
| 160 |
+
|
| 161 |
+
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
|
| 162 |
+
model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
|
| 163 |
+
|
| 164 |
+
if not os.path.exists(model_dir):
|
| 165 |
+
print(f"Model directory not found: {model_dir}")
|
| 166 |
+
return
|
| 167 |
+
|
| 168 |
+
# Register pipeline
|
| 169 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 170 |
+
"vine-video-understanding",
|
| 171 |
+
pipeline_class=VinePipeline,
|
| 172 |
+
pt_model=VineModel,
|
| 173 |
+
type="multimodal",
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# Create model with ensemble weights (local directory)
|
| 177 |
+
config = VineConfig(
|
| 178 |
+
segmentation_method="grounding_dino_sam2",
|
| 179 |
+
use_hf_repo=False,
|
| 180 |
+
local_dir=model_dir,
|
| 181 |
+
local_filename=None,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
vine_model = VineModel(config)
|
| 185 |
+
# Create pipeline with segmentation model paths
|
| 186 |
+
vine_pipeline = VinePipeline(
|
| 187 |
+
model=vine_model,
|
| 188 |
+
tokenizer=None,
|
| 189 |
+
# SAM2 configuration
|
| 190 |
+
sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
|
| 191 |
+
sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
|
| 192 |
+
# GroundingDINO configuration
|
| 193 |
+
gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
|
| 194 |
+
gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
| 195 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
print("✓ Pipeline created with ensemble VINE weights")
|
| 199 |
+
|
| 200 |
+
# Check for demo video
|
| 201 |
+
demo_video = os.path.join(os.path.dirname(__file__), "../demo/videos/v1.mp4")
|
| 202 |
+
|
| 203 |
+
if os.path.exists(demo_video):
|
| 204 |
+
print(f"Found demo video: {demo_video}")
|
| 205 |
+
|
| 206 |
+
# Use the same keywords as in the original inference.py
|
| 207 |
+
categorical_keywords = ['human', 'dog', 'frisbee']
|
| 208 |
+
unary_keywords = ['running', 'jumping', 'catching', 'throwing']
|
| 209 |
+
binary_keywords = ['behind', 'bite', 'front', 'jump over', 'right', 'left']
|
| 210 |
+
|
| 211 |
+
print("Example pipeline usage:")
|
| 212 |
+
print("```python")
|
| 213 |
+
print("results = vine_pipeline(")
|
| 214 |
+
print(f" '{demo_video}',")
|
| 215 |
+
print(f" categorical_keywords={categorical_keywords},")
|
| 216 |
+
print(f" unary_keywords={unary_keywords},")
|
| 217 |
+
print(f" binary_keywords={binary_keywords},")
|
| 218 |
+
print(" segmentation_method='grounding_dino_sam2'")
|
| 219 |
+
print(")")
|
| 220 |
+
print("```")
|
| 221 |
+
|
| 222 |
+
# Uncomment to actually run (requires segmentation models)
|
| 223 |
+
# try:
|
| 224 |
+
# results = vine_pipeline(
|
| 225 |
+
# demo_video,
|
| 226 |
+
# categorical_keywords=categorical_keywords,
|
| 227 |
+
# unary_keywords=unary_keywords,
|
| 228 |
+
# binary_keywords=binary_keywords,
|
| 229 |
+
# segmentation_method='grounding_dino_sam2'
|
| 230 |
+
# )
|
| 231 |
+
# print("Results:", results['summary'])
|
| 232 |
+
# except Exception as e:
|
| 233 |
+
# print(f"Pipeline execution failed: {e}")
|
| 234 |
+
# print("This is expected if segmentation models are not set up")
|
| 235 |
+
|
| 236 |
+
return vine_pipeline
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def demonstrate_weight_transfer():
|
| 241 |
+
"""Demonstrate how weights are transferred from ensemble to HuggingFace format."""
|
| 242 |
+
print("\n=== Weight Transfer Demonstration ===")
|
| 243 |
+
|
| 244 |
+
print("The ensemble model structure (PredicateModel):")
|
| 245 |
+
print("- clip_cate_model: CLIP model for categorical classification")
|
| 246 |
+
print("- clip_unary_model: CLIP model for unary predicates")
|
| 247 |
+
print("- clip_binary_model: CLIP model for binary relations")
|
| 248 |
+
print("- clip_tokenizer: Tokenizer for text processing")
|
| 249 |
+
print("- clip_processor: Processor for image processing")
|
| 250 |
+
|
| 251 |
+
print("\nWeight transfer process:")
|
| 252 |
+
print("1. Load ensemble model with torch.load()")
|
| 253 |
+
print("2. Initialize base CLIP models in HuggingFace format")
|
| 254 |
+
print("3. Transfer state_dict from ensemble to HuggingFace models:")
|
| 255 |
+
print(" - ensemble.clip_cate_model → hf.clip_cate_model")
|
| 256 |
+
print(" - ensemble.clip_unary_model → hf.clip_unary_model")
|
| 257 |
+
print(" - ensemble.clip_binary_model → hf.clip_binary_model")
|
| 258 |
+
print("4. Transfer tokenizer and processor")
|
| 259 |
+
|
| 260 |
+
print("\nThis preserves all your fine-tuned weights while making them HuggingFace compatible!")
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def troubleshooting_guide():
|
| 264 |
+
"""Provide troubleshooting guide for common issues."""
|
| 265 |
+
print("\n=== Troubleshooting Guide ===")
|
| 266 |
+
|
| 267 |
+
print("Common Issues:")
|
| 268 |
+
print("1. 'No model file found for epoch X'")
|
| 269 |
+
print(" → Check that .model files exist in the directory")
|
| 270 |
+
print(" → Verify the epoch number is correct")
|
| 271 |
+
print(" → List files: ls /path/to/model/dir/*.model")
|
| 272 |
+
|
| 273 |
+
print("\n2. 'Error loading VINE weights'")
|
| 274 |
+
print(" → Check file permissions")
|
| 275 |
+
print(" → Verify the model file is not corrupted")
|
| 276 |
+
print(" → Try loading with torch.load() directly first")
|
| 277 |
+
|
| 278 |
+
print("\n3. 'CLIP model mismatch'")
|
| 279 |
+
print(" → Ensure config.model_name matches the base model used in training")
|
| 280 |
+
|
| 281 |
+
print("\n4. 'Device mismatch errors'")
|
| 282 |
+
print(" → Models are loaded to CPU first, then moved to device")
|
| 283 |
+
print(" → Check CUDA availability with torch.cuda.is_available()")
|
| 284 |
+
|
| 285 |
+
print("\nDebugging steps:")
|
| 286 |
+
print("1. Test loading ensemble model directly:")
|
| 287 |
+
print(" model = torch.load('path/to/model.0.model', map_location='cpu')")
|
| 288 |
+
print("2. Check model attributes:")
|
| 289 |
+
print(" print(dir(model))")
|
| 290 |
+
print("3. Verify state_dict keys:")
|
| 291 |
+
print(" print(model.clip_cate_model.state_dict().keys())")
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
if __name__ == "__main__":
|
| 295 |
+
print("VINE Ensemble Weights Loading Examples")
|
| 296 |
+
print("=" * 50)
|
| 297 |
+
|
| 298 |
+
# Test ensemble weight loading
|
| 299 |
+
try:
|
| 300 |
+
model1 = example_load_ensemble_weights()
|
| 301 |
+
except Exception as e:
|
| 302 |
+
print(f"Ensemble loading example failed: {e}")
|
| 303 |
+
|
| 304 |
+
try:
|
| 305 |
+
model2 = example_direct_ensemble_loading()
|
| 306 |
+
except Exception as e:
|
| 307 |
+
print(f"Direct loading example failed: {e}")
|
| 308 |
+
|
| 309 |
+
# Compare approaches
|
| 310 |
+
try:
|
| 311 |
+
example_compare_original_vs_hf()
|
| 312 |
+
except Exception as e:
|
| 313 |
+
print(f"Comparison example failed: {e}")
|
| 314 |
+
|
| 315 |
+
# Test pipeline with ensemble weights
|
| 316 |
+
try:
|
| 317 |
+
pipeline = example_ensemble_with_pipeline()
|
| 318 |
+
except Exception as e:
|
| 319 |
+
print(f"Pipeline example failed: {e}")
|
| 320 |
+
|
| 321 |
+
# Educational content
|
| 322 |
+
demonstrate_weight_transfer()
|
| 323 |
+
troubleshooting_guide()
|
| 324 |
+
|
| 325 |
+
print("\n" + "=" * 50)
|
| 326 |
+
print("Key Points:")
|
| 327 |
+
print("1. AutoModel.from_pretrained() won't work with .pt ensemble weights")
|
| 328 |
+
print("2. Use torch.load() to load the ensemble, then transfer weights")
|
| 329 |
+
print("3. The HuggingFace interface preserves your fine-tuned weights")
|
| 330 |
+
print("4. Specify pretrained_vine_path in VineConfig to auto-load weights")
|
| 331 |
+
print("5. Use VineModel.from_pretrained_vine() for direct loading")
|
| 332 |
+
|
| 333 |
+
|
vine_hf/example_sam2_masks.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example demonstrating SAM2 mask generation in VINE HuggingFace interface
|
| 3 |
+
|
| 4 |
+
This script shows how to use both SAM2-only and Grounding DINO + SAM2
|
| 5 |
+
segmentation methods with the VINE model.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
| 13 |
+
|
| 14 |
+
# Add the parent directory to the path to import vine_hf
|
| 15 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 16 |
+
# Add the parent directory to the path to import vine_hf
|
| 17 |
+
|
| 18 |
+
#Either uncomment the below or set a environemental key, though it isn't needed to run.
|
| 19 |
+
#os.environ['OPENAI_API_KEY'] = 'dummy-key'
|
| 20 |
+
|
| 21 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 22 |
+
from laser.loading import load_video
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def example_sam2_only_segmentation():
|
| 26 |
+
"""Example using SAM2 automatic mask generation only."""
|
| 27 |
+
print("=== SAM2-Only Segmentation Example ===")
|
| 28 |
+
|
| 29 |
+
# Create configuration for SAM2-only
|
| 30 |
+
config = VineConfig(
|
| 31 |
+
use_hf_repo=True,
|
| 32 |
+
model_repo="video-fm/vine_v0",
|
| 33 |
+
segmentation_method="sam2", # Use SAM2 only
|
| 34 |
+
target_fps=1,
|
| 35 |
+
debug_visualizations=True,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Register pipeline
|
| 39 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 40 |
+
"vine-video-understanding",
|
| 41 |
+
pipeline_class=VinePipeline,
|
| 42 |
+
pt_model=VineModel,
|
| 43 |
+
type="multimodal",
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# Create model and pipeline with SAM2 paths
|
| 47 |
+
vine_model = VineModel(config)
|
| 48 |
+
vine_pipeline = VinePipeline(
|
| 49 |
+
model=vine_model,
|
| 50 |
+
tokenizer=None,
|
| 51 |
+
sam_config_path="path/to/your/sam2/sam_config.yaml",
|
| 52 |
+
sam_checkpoint_path="path/to/your/sam2/sam_checkpoint.pth",
|
| 53 |
+
gd_config_path="path/to/your/groundingdino/config.py",
|
| 54 |
+
gd_checkpoint_path="path/to/your/groundingdino/checkpoint.pth",
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Check for demo video
|
| 58 |
+
demo_video = os.path.join(os.path.dirname(__file__), "../demo/videos/output.mp4")
|
| 59 |
+
|
| 60 |
+
if os.path.exists(demo_video):
|
| 61 |
+
print(f"Processing video: {demo_video}")
|
| 62 |
+
|
| 63 |
+
# Define keywords (SAM2 will find all objects, then classify them)
|
| 64 |
+
categorical_keywords = ['human', 'dog', 'frisbee', 'object', 'person', 'animal']
|
| 65 |
+
unary_keywords = ['running', 'jumping', 'sitting', 'standing', 'moving', 'static']
|
| 66 |
+
binary_keywords = ['behind', 'in front of', 'next to', 'chasing', 'following']
|
| 67 |
+
object_pairs = [(0,1), (0, 2), (1, 2), (1, 3), (2, 3), (0,4)]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
print("Using SAM2 automatic mask generation...")
|
| 71 |
+
print("This will find all objects in the video automatically")
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
# Process with SAM2 only
|
| 75 |
+
results = vine_pipeline(
|
| 76 |
+
demo_video,
|
| 77 |
+
categorical_keywords=categorical_keywords,
|
| 78 |
+
unary_keywords=unary_keywords,
|
| 79 |
+
binary_keywords=binary_keywords,
|
| 80 |
+
object_pairs=object_pairs,
|
| 81 |
+
segmentation_method="sam2",
|
| 82 |
+
return_top_k=3,
|
| 83 |
+
debug_visualizations=True,
|
| 84 |
+
debug_visualization_path=os.path.join(os.getcwd(), "sam2_debug_masks.png"),
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
print("\n✓ SAM2 segmentation completed!")
|
| 88 |
+
print("Results summary:")
|
| 89 |
+
print(f" Objects detected: {results['summary']['num_objects_detected']}")
|
| 90 |
+
print(f" Top categories: {results['summary']['top_categories']}")
|
| 91 |
+
print(f" Top actions: {results['summary']['top_actions']}")
|
| 92 |
+
|
| 93 |
+
return results
|
| 94 |
+
|
| 95 |
+
except Exception as e:
|
| 96 |
+
print(f"SAM2 segmentation failed: {e}")
|
| 97 |
+
print("Make sure SAM2 models are properly installed")
|
| 98 |
+
return None
|
| 99 |
+
else:
|
| 100 |
+
print(f"Demo video not found: {demo_video}")
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
def example_grounding_dino_sam2_segmentation():
|
| 104 |
+
"""Example using Grounding DINO + SAM2 text-guided segmentation."""
|
| 105 |
+
print("\n=== Grounding DINO + SAM2 Segmentation Example ===")
|
| 106 |
+
|
| 107 |
+
# Create configuration for Grounding DINO + SAM2
|
| 108 |
+
config = VineConfig(
|
| 109 |
+
use_hf_repo=True,
|
| 110 |
+
model_repo="video-fm/vine_v0",
|
| 111 |
+
segmentation_method="grounding_dino_sam2", # Use text-guided segmentation
|
| 112 |
+
box_threshold=0.35,
|
| 113 |
+
text_threshold=0.25,
|
| 114 |
+
target_fps=1,
|
| 115 |
+
debug_visualizations=True,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# Create model and pipeline with both SAM2 and GroundingDINO paths
|
| 119 |
+
vine_model = VineModel(config)
|
| 120 |
+
vine_pipeline = VinePipeline(
|
| 121 |
+
model=vine_model,
|
| 122 |
+
tokenizer=None,
|
| 123 |
+
# SAM2 configuration
|
| 124 |
+
sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
|
| 125 |
+
sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
|
| 126 |
+
gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
|
| 127 |
+
gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
| 128 |
+
device=0,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# Check for demo video
|
| 132 |
+
demo_video = os.path.join(os.path.dirname(__file__), "../demo/videos/output.mp4")
|
| 133 |
+
|
| 134 |
+
if os.path.exists(demo_video):
|
| 135 |
+
print(f"Processing video: {demo_video}")
|
| 136 |
+
|
| 137 |
+
# Define keywords (Grounding DINO will look specifically for these)
|
| 138 |
+
categorical_keywords = ['human', 'dog', 'frisbee'] # Specific objects to find
|
| 139 |
+
unary_keywords = ['running', 'jumping', 'catching', 'throwing']
|
| 140 |
+
binary_keywords = ['behind', 'chasing', 'next to', 'throwing to']
|
| 141 |
+
object_pairs = [(0,1), (0, 2), (1, 2), (1, 3), (2, 3), (0,4)]
|
| 142 |
+
print("Using Grounding DINO + SAM2 text-guided segmentation...")
|
| 143 |
+
print(f"Looking specifically for: {categorical_keywords}")
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
# Process with Grounding DINO + SAM2
|
| 147 |
+
results = vine_pipeline(
|
| 148 |
+
demo_video,
|
| 149 |
+
categorical_keywords=categorical_keywords,
|
| 150 |
+
unary_keywords=unary_keywords,
|
| 151 |
+
binary_keywords=binary_keywords,
|
| 152 |
+
object_pairs=object_pairs,
|
| 153 |
+
segmentation_method="grounding_dino_sam2",
|
| 154 |
+
box_threshold=0.35,
|
| 155 |
+
text_threshold=0.25,
|
| 156 |
+
return_top_k=3,
|
| 157 |
+
debug_visualizations=True,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
print("\n✓ Grounding DINO + SAM2 segmentation completed!")
|
| 161 |
+
print("Results summary:")
|
| 162 |
+
print(f" Objects detected: {results['summary']['num_objects_detected']}")
|
| 163 |
+
print(f" Top categories: {results['summary']['top_categories']}")
|
| 164 |
+
print(f" Top actions: {results['summary']['top_actions']}")
|
| 165 |
+
print(f" Top relations: {results['summary']['top_relations']}")
|
| 166 |
+
|
| 167 |
+
return results
|
| 168 |
+
|
| 169 |
+
except Exception as e:
|
| 170 |
+
print(f"Grounding DINO + SAM2 segmentation failed: {e}")
|
| 171 |
+
print("Make sure both Grounding DINO and SAM2 models are properly installed")
|
| 172 |
+
return None
|
| 173 |
+
else:
|
| 174 |
+
print(f"Demo video not found: {demo_video}")
|
| 175 |
+
return None
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def compare_segmentation_methods():
|
| 179 |
+
"""Compare SAM2-only vs Grounding DINO + SAM2 approaches."""
|
| 180 |
+
print("\n=== Comparing Segmentation Methods ===")
|
| 181 |
+
|
| 182 |
+
print("\nSAM2-Only Approach:")
|
| 183 |
+
print("✓ Finds all objects automatically")
|
| 184 |
+
print("✓ No need to specify what to look for")
|
| 185 |
+
print("✓ Good for exploratory analysis")
|
| 186 |
+
print("✗ May find too many irrelevant objects")
|
| 187 |
+
print("✗ Less precise for specific object types")
|
| 188 |
+
|
| 189 |
+
print("\nGrounding DINO + SAM2 Approach:")
|
| 190 |
+
print("✓ Finds specific objects based on text prompts")
|
| 191 |
+
print("✓ More precise and targeted")
|
| 192 |
+
print("✓ Better for known object categories")
|
| 193 |
+
print("✓ Integrates object detection with segmentation")
|
| 194 |
+
print("✗ Limited to specified categories")
|
| 195 |
+
print("✗ Requires knowing what objects to look for")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def demonstrate_mask_processing():
|
| 199 |
+
"""Demonstrate how masks are processed internally."""
|
| 200 |
+
print("\n=== Mask Processing Demonstration ===")
|
| 201 |
+
|
| 202 |
+
# Load a video to show the processing pipeline
|
| 203 |
+
demo_video = os.path.join(os.path.dirname(__file__), "../demo/videos/output.mp4")
|
| 204 |
+
|
| 205 |
+
if os.path.exists(demo_video):
|
| 206 |
+
print("Loading video for mask processing demo...")
|
| 207 |
+
|
| 208 |
+
# Load video tensor
|
| 209 |
+
video_tensor = np.asarray(load_video(demo_video, target_fps=1))
|
| 210 |
+
print(f"Video shape: {video_tensor.shape}")
|
| 211 |
+
|
| 212 |
+
# Create pipeline with segmentation model paths
|
| 213 |
+
config = VineConfig(segmentation_method="sam2")
|
| 214 |
+
vine_model = VineModel(config)
|
| 215 |
+
vine_pipeline = VinePipeline(
|
| 216 |
+
model=vine_model,
|
| 217 |
+
tokenizer=None,
|
| 218 |
+
# SAM2 configuration
|
| 219 |
+
sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
|
| 220 |
+
sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
|
| 221 |
+
gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
|
| 222 |
+
gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
try:
|
| 226 |
+
# Process just the first few frames to show the pipeline
|
| 227 |
+
print("\nProcessing first 2 frames with SAM2...")
|
| 228 |
+
|
| 229 |
+
# Manually call the preprocessing to show the steps
|
| 230 |
+
processed_data = vine_pipeline.preprocess(
|
| 231 |
+
video_tensor[:2], # Just first 2 frames
|
| 232 |
+
segmentation_method="sam2",
|
| 233 |
+
categorical_keywords=['object']
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
print("Mask processing results:")
|
| 237 |
+
print(f" Number of frames processed: {processed_data['num_frames']}")
|
| 238 |
+
print(f" Frames with masks: {list(processed_data['masks'].keys())}")
|
| 239 |
+
|
| 240 |
+
# Show mask details
|
| 241 |
+
for frame_id, frame_masks in processed_data['masks'].items():
|
| 242 |
+
print(f" Frame {frame_id}: {len(frame_masks)} objects detected")
|
| 243 |
+
for obj_id, mask in frame_masks.items():
|
| 244 |
+
print(f" Object {obj_id}: mask shape {mask.shape}")
|
| 245 |
+
|
| 246 |
+
print("\nBounding box extraction:")
|
| 247 |
+
for frame_id, frame_bboxes in processed_data['bboxes'].items():
|
| 248 |
+
print(f" Frame {frame_id}: {len(frame_bboxes)} bounding boxes")
|
| 249 |
+
for obj_id, bbox in frame_bboxes.items():
|
| 250 |
+
print(f" Object {obj_id}: bbox {bbox}")
|
| 251 |
+
|
| 252 |
+
except Exception as e:
|
| 253 |
+
print(f"Mask processing failed: {e}")
|
| 254 |
+
print("This is expected if SAM2 models are not properly set up")
|
| 255 |
+
else:
|
| 256 |
+
print(f"Demo video not found: {demo_video}")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def test_mask_formats():
|
| 260 |
+
"""Test different mask input formats."""
|
| 261 |
+
print("\n=== Testing Mask Formats ===")
|
| 262 |
+
|
| 263 |
+
# Create dummy data to test mask processing
|
| 264 |
+
height, width = 224, 224
|
| 265 |
+
|
| 266 |
+
# Test different mask formats
|
| 267 |
+
print("Testing mask format conversions...")
|
| 268 |
+
|
| 269 |
+
# Format 1: NumPy boolean array
|
| 270 |
+
mask_np = np.random.rand(height, width) > 0.5
|
| 271 |
+
print(f"NumPy mask: {mask_np.shape}, dtype: {mask_np.dtype}")
|
| 272 |
+
|
| 273 |
+
# Format 2: PyTorch tensor
|
| 274 |
+
mask_torch = torch.from_numpy(mask_np)
|
| 275 |
+
print(f"PyTorch mask: {mask_torch.shape}, dtype: {mask_torch.dtype}")
|
| 276 |
+
|
| 277 |
+
# Format 3: 3D mask with singleton dimension
|
| 278 |
+
mask_3d = mask_torch.unsqueeze(-1)
|
| 279 |
+
print(f"3D mask: {mask_3d.shape}")
|
| 280 |
+
|
| 281 |
+
# Test bounding box extraction
|
| 282 |
+
from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox
|
| 283 |
+
|
| 284 |
+
try:
|
| 285 |
+
bbox = mask_to_bbox(mask_torch)
|
| 286 |
+
print(f"Extracted bbox: {bbox}")
|
| 287 |
+
print("✓ Mask format testing successful")
|
| 288 |
+
except Exception as e:
|
| 289 |
+
print(f"Mask format testing failed: {e}")
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
if __name__ == "__main__":
|
| 293 |
+
print("VINE SAM2 Mask Generation Examples")
|
| 294 |
+
print("=" * 50)
|
| 295 |
+
|
| 296 |
+
# Test SAM2-only approach
|
| 297 |
+
try:
|
| 298 |
+
sam2_results = example_sam2_only_segmentation()
|
| 299 |
+
except Exception as e:
|
| 300 |
+
print(f"SAM2-only example failed: {e}")
|
| 301 |
+
|
| 302 |
+
# Test Grounding DINO + SAM2 approach
|
| 303 |
+
try:
|
| 304 |
+
gd_sam2_results = example_grounding_dino_sam2_segmentation()
|
| 305 |
+
except Exception as e:
|
| 306 |
+
print(f"Grounding DINO + SAM2 example failed: {e}")
|
| 307 |
+
|
| 308 |
+
# Compare approaches
|
| 309 |
+
compare_segmentation_methods()
|
| 310 |
+
|
| 311 |
+
# Demonstrate mask processing
|
| 312 |
+
try:
|
| 313 |
+
demonstrate_mask_processing()
|
| 314 |
+
except Exception as e:
|
| 315 |
+
print(f"Mask processing demo failed: {e}")
|
| 316 |
+
|
| 317 |
+
# Test mask formats
|
| 318 |
+
try:
|
| 319 |
+
test_mask_formats()
|
| 320 |
+
except Exception as e:
|
| 321 |
+
print(f"Mask format testing failed: {e}")
|
| 322 |
+
|
| 323 |
+
print("\n" + "=" * 50)
|
| 324 |
+
print("Examples completed!")
|
| 325 |
+
print("\nKey takeaways:")
|
| 326 |
+
print("1. SAM2-only: Automatic object detection and segmentation")
|
| 327 |
+
print("2. Grounding DINO + SAM2: Text-guided object detection and segmentation")
|
| 328 |
+
print("3. Both methods provide masks and bounding boxes for VINE model")
|
| 329 |
+
print("4. Choose method based on whether you know what objects to look for")
|
| 330 |
+
|
| 331 |
+
|
vine_hf/example_usage.ipynb
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "44d53281",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stderr",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"/home/kevinx/miniconda3/envs/laser_env/lib/python3.10/site-packages/pydantic/_internal/_config.py:383: UserWarning: Valid config keys have changed in V2:\n",
|
| 14 |
+
"* 'schema_extra' has been renamed to 'json_schema_extra'\n",
|
| 15 |
+
" warnings.warn(message, UserWarning)\n",
|
| 16 |
+
"Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n",
|
| 17 |
+
"Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n"
|
| 18 |
+
]
|
| 19 |
+
}
|
| 20 |
+
],
|
| 21 |
+
"source": [
|
| 22 |
+
"import os\n",
|
| 23 |
+
"import sys\n",
|
| 24 |
+
"import torch\n",
|
| 25 |
+
"from transformers import pipeline, AutoModel\n",
|
| 26 |
+
"from transformers.pipelines import PIPELINE_REGISTRY\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"# Uncomment or set your own\n",
|
| 29 |
+
"#os.environ['OPENAI_API_KEY'] = 'dummy-key'\n",
|
| 30 |
+
"from vine_hf import VineConfig, VineModel, VinePipeline"
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"execution_count": 2,
|
| 36 |
+
"id": "174e479f",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"PIPELINE_REGISTRY.register_pipeline(\n",
|
| 41 |
+
" \"vine-video-understanding\",\n",
|
| 42 |
+
" pipeline_class=VinePipeline,\n",
|
| 43 |
+
" pt_model=VineModel,\n",
|
| 44 |
+
" type=\"multimodal\",\n",
|
| 45 |
+
")"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": null,
|
| 51 |
+
"id": "a9af2770",
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"vine_config = VineConfig(\n",
|
| 56 |
+
" model_name=\"openai/clip-vit-base-patch32\",\n",
|
| 57 |
+
" # Local file example: set use_hf_repo=False and provide local_dir/local_filename\n",
|
| 58 |
+
" use_hf_repo=False,\n",
|
| 59 |
+
" local_dir=os.path.dirname('/path/to/your/pretrained/model.pt'),\n",
|
| 60 |
+
" local_filename=os.path.basename('/path/to/your/pretrained/model.pt'), # Local file path\n",
|
| 61 |
+
" segmentation_method=\"grounding_dino_sam2\",\n",
|
| 62 |
+
" visualize=True,\n",
|
| 63 |
+
" visualization_dir=\"path/to/visualization/dir\",\n",
|
| 64 |
+
" debug_visualizations=True,\n",
|
| 65 |
+
" device=0, # Change to your desired device\n",
|
| 66 |
+
")"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "code",
|
| 71 |
+
"execution_count": null,
|
| 72 |
+
"id": "274e6515",
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"outputs": [
|
| 75 |
+
{
|
| 76 |
+
"name": "stdout",
|
| 77 |
+
"output_type": "stream",
|
| 78 |
+
"text": [
|
| 79 |
+
"Loaded state type: <class 'collections.OrderedDict'>\n"
|
| 80 |
+
]
|
| 81 |
+
}
|
| 82 |
+
],
|
| 83 |
+
"source": [
|
| 84 |
+
"vine_pipeline = VinePipeline(\n",
|
| 85 |
+
" model=VineModel(vine_config), \n",
|
| 86 |
+
" tokenizer=None,\n",
|
| 87 |
+
" sam_config_path=\"path/to/sam2/configs/sam2_hiera_base_plus.yaml\",\n",
|
| 88 |
+
" sam_checkpoint_path=\"path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt\",\n",
|
| 89 |
+
" gd_config_path=\"path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py\",\n",
|
| 90 |
+
" gd_checkpoint_path=\"path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth\",\n",
|
| 91 |
+
")"
|
| 92 |
+
]
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"cell_type": "code",
|
| 96 |
+
"execution_count": 6,
|
| 97 |
+
"id": "123a090d",
|
| 98 |
+
"metadata": {},
|
| 99 |
+
"outputs": [],
|
| 100 |
+
"source": [
|
| 101 |
+
"categorical_keywords = ['human', 'dog', 'frisbee']\n",
|
| 102 |
+
"unary_keywords = ['running', 'jumping', 'catching', 'throwing']\n",
|
| 103 |
+
"binary_keywords = ['behind', 'in front of', 'next to', 'chasing']\n",
|
| 104 |
+
"object_pairs = [(0, 1), (0, 2), (1, 2)] # human-dog, dog-frisbee relationships "
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"cell_type": "code",
|
| 109 |
+
"execution_count": 7,
|
| 110 |
+
"id": "0b42f032",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"outputs": [],
|
| 113 |
+
"source": [
|
| 114 |
+
"demo_video_path = \"/home/kevinx/LASER/LASER/demo/videos/v1.mp4\" # Replace with your video file path"
|
| 115 |
+
]
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"cell_type": "code",
|
| 119 |
+
"execution_count": 8,
|
| 120 |
+
"id": "8202c654",
|
| 121 |
+
"metadata": {},
|
| 122 |
+
"outputs": [
|
| 123 |
+
{
|
| 124 |
+
"name": "stdout",
|
| 125 |
+
"output_type": "stream",
|
| 126 |
+
"text": [
|
| 127 |
+
"Segmentation method: grounding_dino_sam2\n",
|
| 128 |
+
"Generating Grounding DINO + SAM2 masks...\n",
|
| 129 |
+
"<class 'int'>\n",
|
| 130 |
+
"✓ SAM2 models initialized successfully\n",
|
| 131 |
+
"<class 'int'>\n"
|
| 132 |
+
]
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"name": "stderr",
|
| 136 |
+
"output_type": "stream",
|
| 137 |
+
"text": [
|
| 138 |
+
"UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /pytorch/aten/src/ATen/native/TensorShape.cpp:4314.)\n"
|
| 139 |
+
]
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"name": "stdout",
|
| 143 |
+
"output_type": "stream",
|
| 144 |
+
"text": [
|
| 145 |
+
"final text_encoder_type: bert-base-uncased\n",
|
| 146 |
+
"✓ GroundingDINO model initialized successfully\n",
|
| 147 |
+
"Start detecting objects at time 05:08:58.178592\n"
|
| 148 |
+
]
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"name": "stderr",
|
| 152 |
+
"output_type": "stream",
|
| 153 |
+
"text": [
|
| 154 |
+
"Detecting objects: 0%| | 0/3 [00:00<?, ?it/s]FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
|
| 155 |
+
"UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
|
| 156 |
+
"UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
|
| 157 |
+
"FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 158 |
+
"Detecting objects: 100%|██████████| 3/3 [00:01<00:00, 2.82it/s]\n"
|
| 159 |
+
]
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"name": "stdout",
|
| 163 |
+
"output_type": "stream",
|
| 164 |
+
"text": [
|
| 165 |
+
"Finished detecting objects at time 05:08:59.250419\n",
|
| 166 |
+
"Loading inference state at time 05:08:59.544425\n",
|
| 167 |
+
"Number of frames: 3\n",
|
| 168 |
+
"None\n"
|
| 169 |
+
]
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"name": "stderr",
|
| 173 |
+
"output_type": "stream",
|
| 174 |
+
"text": [
|
| 175 |
+
"Processing frames: 100%|██████████| 3/3 [00:00<00:00, 11.77it/s]\n"
|
| 176 |
+
]
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"name": "stdout",
|
| 180 |
+
"output_type": "stream",
|
| 181 |
+
"text": [
|
| 182 |
+
"Annotated frames: []\n",
|
| 183 |
+
"Find the most dense prompt at time 05:09:01.413703\n",
|
| 184 |
+
"Most dense frame: 0\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"\n",
|
| 187 |
+
"Start propagating objects at time 05:09:01.416367\n",
|
| 188 |
+
"Pass count: 0\n"
|
| 189 |
+
]
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"name": "stderr",
|
| 193 |
+
"output_type": "stream",
|
| 194 |
+
"text": [
|
| 195 |
+
"propagate in video: 100%|██████████| 3/3 [00:00<00:00, 20.20it/s]\n",
|
| 196 |
+
"propagate in video: 0it [00:00, ?it/s]\n"
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"name": "stdout",
|
| 201 |
+
"output_type": "stream",
|
| 202 |
+
"text": [
|
| 203 |
+
"Most dense frame: 1\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"\n",
|
| 206 |
+
"Pass count: 1\n"
|
| 207 |
+
]
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
"name": "stderr",
|
| 211 |
+
"output_type": "stream",
|
| 212 |
+
"text": [
|
| 213 |
+
"propagate in video: 100%|██████████| 3/3 [00:00<00:00, 19.25it/s]\n",
|
| 214 |
+
"propagate in video: 0it [00:00, ?it/s]\n"
|
| 215 |
+
]
|
| 216 |
+
},
|
| 217 |
+
{
|
| 218 |
+
"name": "stdout",
|
| 219 |
+
"output_type": "stream",
|
| 220 |
+
"text": [
|
| 221 |
+
"Most dense frame: 2\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"\n",
|
| 224 |
+
"Pass count: 2\n"
|
| 225 |
+
]
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"name": "stderr",
|
| 229 |
+
"output_type": "stream",
|
| 230 |
+
"text": [
|
| 231 |
+
"propagate in video: 100%|██████████| 3/3 [00:00<00:00, 25.92it/s]\n",
|
| 232 |
+
"propagate in video: 0it [00:00, ?it/s]\n"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"name": "stdout",
|
| 237 |
+
"output_type": "stream",
|
| 238 |
+
"text": [
|
| 239 |
+
"Most dense frame: -1\n",
|
| 240 |
+
"\n",
|
| 241 |
+
"\n",
|
| 242 |
+
"\n",
|
| 243 |
+
"Results:\n",
|
| 244 |
+
"Summary: {'num_objects_detected': 4, 'num_unary_predictions': 10, 'num_binary_predictions': 3, 'top_categories': [('frisbee', 0.9989640712738037), ('dog', 0.957672655582428), ('dog', 0.957672655582428)], 'top_actions': [('running', 0.8483631610870361), ('running', 0.832377016544342), ('running', 0.8178836107254028)], 'top_relations': [('chasing', 0.9616015553474426), ('chasing', 0.9478002786636353), ('chasing', 0.6380977630615234)]}\n"
|
| 245 |
+
]
|
| 246 |
+
}
|
| 247 |
+
],
|
| 248 |
+
"source": [
|
| 249 |
+
"try:\n",
|
| 250 |
+
" results = vine_pipeline(\n",
|
| 251 |
+
" demo_video_path,\n",
|
| 252 |
+
" categorical_keywords=categorical_keywords,\n",
|
| 253 |
+
" unary_keywords=unary_keywords,\n",
|
| 254 |
+
" binary_keywords=binary_keywords,\n",
|
| 255 |
+
" object_pairs=object_pairs,\n",
|
| 256 |
+
" segmentation_method='grounding_dino_sam2',\n",
|
| 257 |
+
" return_top_k=3,\n",
|
| 258 |
+
" include_visualizations=False,\n",
|
| 259 |
+
" debug_visualizations=False,\n",
|
| 260 |
+
" )\n",
|
| 261 |
+
" \n",
|
| 262 |
+
" print(\"\\nResults:\")\n",
|
| 263 |
+
" print(f\"Summary: {results['summary']}\")\n",
|
| 264 |
+
" \n",
|
| 265 |
+
"except Exception as e:\n",
|
| 266 |
+
" print(f\"Note: Full execution requires segmentation models to be properly set up.\")\n",
|
| 267 |
+
" print(f\"Error: {e}\")"
|
| 268 |
+
]
|
| 269 |
+
},
|
| 270 |
+
{
|
| 271 |
+
"cell_type": "code",
|
| 272 |
+
"execution_count": 9,
|
| 273 |
+
"id": "414ede9b",
|
| 274 |
+
"metadata": {},
|
| 275 |
+
"outputs": [
|
| 276 |
+
{
|
| 277 |
+
"name": "stdout",
|
| 278 |
+
"output_type": "stream",
|
| 279 |
+
"text": [
|
| 280 |
+
"Summary: {'num_objects_detected': 4, 'num_unary_predictions': 10, 'num_binary_predictions': 3, 'top_categories': [('frisbee', 0.9989640712738037), ('dog', 0.957672655582428), ('dog', 0.957672655582428)], 'top_actions': [('running', 0.8483631610870361), ('running', 0.832377016544342), ('running', 0.8178836107254028)], 'top_relations': [('chasing', 0.9616015553474426), ('chasing', 0.9478002786636353), ('chasing', 0.6380977630615234)]}\n"
|
| 281 |
+
]
|
| 282 |
+
}
|
| 283 |
+
],
|
| 284 |
+
"source": [
|
| 285 |
+
"print(f\"Summary: {results['summary']}\")"
|
| 286 |
+
]
|
| 287 |
+
}
|
| 288 |
+
],
|
| 289 |
+
"metadata": {
|
| 290 |
+
"kernelspec": {
|
| 291 |
+
"display_name": "laser_env",
|
| 292 |
+
"language": "python",
|
| 293 |
+
"name": "python3"
|
| 294 |
+
},
|
| 295 |
+
"language_info": {
|
| 296 |
+
"codemirror_mode": {
|
| 297 |
+
"name": "ipython",
|
| 298 |
+
"version": 3
|
| 299 |
+
},
|
| 300 |
+
"file_extension": ".py",
|
| 301 |
+
"mimetype": "text/x-python",
|
| 302 |
+
"name": "python",
|
| 303 |
+
"nbconvert_exporter": "python",
|
| 304 |
+
"pygments_lexer": "ipython3",
|
| 305 |
+
"version": "3.10.0"
|
| 306 |
+
}
|
| 307 |
+
},
|
| 308 |
+
"nbformat": 4,
|
| 309 |
+
"nbformat_minor": 5
|
| 310 |
+
}
|
vine_hf/example_usage.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example usage of VINE HuggingFace interface
|
| 3 |
+
|
| 4 |
+
This script demonstrates how to use the VINE model through the HuggingFace interface
|
| 5 |
+
for video understanding with categorical, unary, and binary keyword predictions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import pipeline, AutoModel
|
| 12 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
| 13 |
+
|
| 14 |
+
# Add the parent directory to the path to import vine_hf
|
| 15 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 16 |
+
|
| 17 |
+
# Uncomment or set your own
|
| 18 |
+
#os.environ['OPENAI_API_KEY'] = 'dummy-key'
|
| 19 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 20 |
+
|
| 21 |
+
def example_direct_model_usage():
|
| 22 |
+
"""Example of using the VINE model directly."""
|
| 23 |
+
print("=== Direct Model Usage ===")
|
| 24 |
+
|
| 25 |
+
# Create configuration
|
| 26 |
+
config = VineConfig(
|
| 27 |
+
model_name="openai/clip-vit-base-patch32",
|
| 28 |
+
segmentation_method="grounding_dino_sam2",
|
| 29 |
+
use_hf_repo=True,
|
| 30 |
+
model_repo="video-fm/vine_v0", # Your HF Hub model
|
| 31 |
+
debug_visualizations=True,
|
| 32 |
+
debug_visualization_path=os.path.join(os.getcwd(), "debug_masks.png"),
|
| 33 |
+
target_fps=30,
|
| 34 |
+
box_threshold=0.35,
|
| 35 |
+
text_threshold=0.25
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Initialize model
|
| 39 |
+
model = VineModel(config)
|
| 40 |
+
|
| 41 |
+
print(f"Model initialized with CLIP backbone: {config.model_name}")
|
| 42 |
+
print(f"Segmentation method: {config.segmentation_method}")
|
| 43 |
+
print(f"Device: {model.device}")
|
| 44 |
+
|
| 45 |
+
# Example video data (placeholder - in real usage, load from video file)
|
| 46 |
+
num_frames, height, width = 3, 224, 224
|
| 47 |
+
video_frames = torch.randn(num_frames, height, width, 3) * 255
|
| 48 |
+
video_frames = video_frames.clamp(0, 255).byte()
|
| 49 |
+
|
| 50 |
+
# Example masks and bboxes (placeholder - in real usage, generated by segmentation)
|
| 51 |
+
masks = {
|
| 52 |
+
0: {1: torch.ones(height, width, 1), 2: torch.ones(height, width, 1)},
|
| 53 |
+
1: {1: torch.ones(height, width, 1), 2: torch.ones(height, width, 1)},
|
| 54 |
+
2: {1: torch.ones(height, width, 1), 2: torch.ones(height, width, 1)}
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
bboxes = {
|
| 58 |
+
0: {1: [50, 50, 150, 150], 2: [100, 100, 200, 200]},
|
| 59 |
+
1: {1: [52, 52, 152, 152], 2: [102, 102, 202, 202]},
|
| 60 |
+
2: {1: [54, 54, 154, 154], 2: [104, 104, 204, 204]}
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
# Define keywords
|
| 64 |
+
categorical_keywords = ["human", "dog", "frisbee"]
|
| 65 |
+
unary_keywords = ["running", "jumping", "sitting", "standing"]
|
| 66 |
+
binary_keywords = ["behind", "in front of", "next to", "throwing to", "catching from"]
|
| 67 |
+
object_pairs = [(1, 2)] # Object 1 relates to Object 2
|
| 68 |
+
|
| 69 |
+
# Run prediction
|
| 70 |
+
print("\nRunning prediction...")
|
| 71 |
+
results = model.predict(
|
| 72 |
+
video_frames=video_frames,
|
| 73 |
+
masks=masks,
|
| 74 |
+
bboxes=bboxes,
|
| 75 |
+
categorical_keywords=categorical_keywords,
|
| 76 |
+
unary_keywords=unary_keywords,
|
| 77 |
+
binary_keywords=binary_keywords,
|
| 78 |
+
object_pairs=object_pairs,
|
| 79 |
+
return_top_k=3
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
print("\nResults:")
|
| 83 |
+
print(f"Categorical predictions: {len(results['categorical_predictions'])} objects")
|
| 84 |
+
print(f"Unary predictions: {len(results['unary_predictions'])} actions")
|
| 85 |
+
print(f"Binary predictions: {len(results['binary_predictions'])} relations")
|
| 86 |
+
print(f"Confidence scores: {results['confidence_scores']}")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def example_pipeline_usage():
|
| 90 |
+
"""Example of using the VINE pipeline."""
|
| 91 |
+
print("\n=== Pipeline Usage ===")
|
| 92 |
+
|
| 93 |
+
# Register the pipeline
|
| 94 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 95 |
+
"vine-video-understanding",
|
| 96 |
+
pipeline_class=VinePipeline,
|
| 97 |
+
pt_model=VineModel,
|
| 98 |
+
type="multimodal",
|
| 99 |
+
)
|
| 100 |
+
vine_config = VineConfig(
|
| 101 |
+
model_name="openai/clip-vit-base-patch32",
|
| 102 |
+
use_hf_repo=True,
|
| 103 |
+
model_repo="video-fm/vine_v0", # Your HF Hub model
|
| 104 |
+
segmentation_method="grounding_dino_sam2",
|
| 105 |
+
debug_visualizations=True,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
vine_pipe = VinePipeline(
|
| 109 |
+
model=VineModel(vine_config),
|
| 110 |
+
tokenizer=None,
|
| 111 |
+
trust_remote_code=True,
|
| 112 |
+
# SAM2 configuration
|
| 113 |
+
sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
|
| 114 |
+
sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
|
| 115 |
+
gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
|
| 116 |
+
gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
| 117 |
+
device=0,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
print("Pipeline created successfully!")
|
| 122 |
+
|
| 123 |
+
# Example usage with video path
|
| 124 |
+
video_path = "path/to/your/video.mp4" # Replace with actual video path
|
| 125 |
+
|
| 126 |
+
# For demonstration, we'll show the expected usage format
|
| 127 |
+
print(f"\nExample pipeline call (replace with actual video path):")
|
| 128 |
+
print(f"results = vine_pipeline(")
|
| 129 |
+
print(f" '{video_path}',")
|
| 130 |
+
print(f" categorical_keywords=['human', 'dog', 'frisbee'],")
|
| 131 |
+
print(f" unary_keywords=['running', 'jumping', 'sitting'],")
|
| 132 |
+
print(f" binary_keywords=['behind', 'in front of', 'next to'],")
|
| 133 |
+
print(f" object_pairs=[(1, 2)],")
|
| 134 |
+
print(f" segmentation_method='grounding_dino_sam2',")
|
| 135 |
+
print(f" return_top_k=3,")
|
| 136 |
+
print(f" return_flattened_segments=True,")
|
| 137 |
+
print(f" return_valid_pairs=True,")
|
| 138 |
+
print(f" include_visualizations=True,")
|
| 139 |
+
print(f" debug_visualizations=True")
|
| 140 |
+
print(f")")
|
| 141 |
+
|
| 142 |
+
# Note: Actual execution would require proper video file and segmentation models
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def example_huggingface_hub_usage():
|
| 146 |
+
"""Example of how to push and load from HuggingFace Hub."""
|
| 147 |
+
print("\n=== HuggingFace Hub Usage ===")
|
| 148 |
+
|
| 149 |
+
# Example of preparing model for Hub
|
| 150 |
+
config = VineConfig()
|
| 151 |
+
model = VineModel(config)
|
| 152 |
+
|
| 153 |
+
# Register for auto classes
|
| 154 |
+
config.register_for_auto_class()
|
| 155 |
+
model.register_for_auto_class("AutoModel")
|
| 156 |
+
|
| 157 |
+
print("Model registered for auto classes")
|
| 158 |
+
|
| 159 |
+
# Example push to hub (commented out - requires actual model weights and credentials)
|
| 160 |
+
# config.push_to_hub('your-username/vine-model')
|
| 161 |
+
# model.push_to_hub('your-username/vine-model')
|
| 162 |
+
|
| 163 |
+
# Example load from hub (commented out - requires actual model on hub)
|
| 164 |
+
# model = AutoModel.from_pretrained('your-username/vine-model', trust_remote_code=True)
|
| 165 |
+
# pipeline = pipeline('vine-video-understanding', model='your-username/vine-model', trust_remote_code=True)
|
| 166 |
+
|
| 167 |
+
print("To push to Hub:")
|
| 168 |
+
print("1. config.push_to_hub('your-username/vine-model')")
|
| 169 |
+
print("2. model.push_to_hub('your-username/vine-model')")
|
| 170 |
+
print("\nTo load from Hub:")
|
| 171 |
+
print("model = AutoModel.from_pretrained('your-username/vine-model', trust_remote_code=True)")
|
| 172 |
+
print("pipe = pipeline('vine-video-understanding', model='your-username/vine-model', trust_remote_code=True)")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def example_with_real_video():
|
| 176 |
+
"""Example showing how to use with a real video file."""
|
| 177 |
+
print("\n=== Real Video Usage Example ===")
|
| 178 |
+
|
| 179 |
+
# Check if demo video exists
|
| 180 |
+
demo_video_path = os.path.join(os.path.dirname(__file__), "../demo/videos/v1.mp4")
|
| 181 |
+
|
| 182 |
+
if os.path.exists(demo_video_path):
|
| 183 |
+
print(f"Found demo video: {demo_video_path}")
|
| 184 |
+
|
| 185 |
+
# Create pipeline with segmentation model paths
|
| 186 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 187 |
+
"vine-video-understanding",
|
| 188 |
+
pipeline_class=VinePipeline,
|
| 189 |
+
pt_model=VineModel,
|
| 190 |
+
type="multimodal",
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
vine_config = VineConfig(
|
| 194 |
+
model_name="openai/clip-vit-base-patch32",
|
| 195 |
+
use_hf_repo=True,
|
| 196 |
+
model_repo="video-fm/vine_v0", # Your HF Hub model
|
| 197 |
+
segmentation_method="grounding_dino_sam2",
|
| 198 |
+
debug_visualizations=True,
|
| 199 |
+
debug_visualization_path=os.path.join(os.getcwd(), "real_video_debug_masks.png"),
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
vine_pipeline = VinePipeline(
|
| 203 |
+
model=VineModel(vine_config),
|
| 204 |
+
tokenizer=None,
|
| 205 |
+
trust_remote_code=True,
|
| 206 |
+
# SAM2 configuration
|
| 207 |
+
sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
|
| 208 |
+
sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
|
| 209 |
+
gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
|
| 210 |
+
gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# Define keywords based on the demo
|
| 214 |
+
categorical_keywords = ['human', 'dog', 'frisbee']
|
| 215 |
+
unary_keywords = ['running', 'jumping', 'catching', 'throwing']
|
| 216 |
+
binary_keywords = ['behind', 'in front of', 'next to', 'chasing']
|
| 217 |
+
object_pairs = [(0, 1), (0, 2), (1, 2)] # human-dog, dog-frisbee relationships
|
| 218 |
+
|
| 219 |
+
print("\nProcessing video with VINE...")
|
| 220 |
+
print("Keywords:")
|
| 221 |
+
print(f" Categorical: {categorical_keywords}")
|
| 222 |
+
print(f" Unary: {unary_keywords}")
|
| 223 |
+
print(f" Binary: {binary_keywords}")
|
| 224 |
+
print(f" Object pairs: {object_pairs}")
|
| 225 |
+
|
| 226 |
+
# Note: This would require proper segmentation models to be set up
|
| 227 |
+
try:
|
| 228 |
+
results = vine_pipeline(
|
| 229 |
+
demo_video_path,
|
| 230 |
+
categorical_keywords=categorical_keywords,
|
| 231 |
+
unary_keywords=unary_keywords,
|
| 232 |
+
binary_keywords=binary_keywords,
|
| 233 |
+
object_pairs=object_pairs,
|
| 234 |
+
segmentation_method='grounding_dino_sam2',
|
| 235 |
+
return_top_k=3,
|
| 236 |
+
include_visualizations=False,
|
| 237 |
+
debug_visualizations=True,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
print("\nResults:")
|
| 241 |
+
print(f"Summary: {results['summary']}")
|
| 242 |
+
|
| 243 |
+
except Exception as e:
|
| 244 |
+
print(f"Note: Full execution requires segmentation models to be properly set up.")
|
| 245 |
+
print(f"Error: {e}")
|
| 246 |
+
|
| 247 |
+
else:
|
| 248 |
+
print(f"Demo video not found at: {demo_video_path}")
|
| 249 |
+
print("To use with a real video, provide the path to your video file.")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
if __name__ == "__main__":
|
| 253 |
+
print("VINE HuggingFace Interface Examples")
|
| 254 |
+
print("=" * 50)
|
| 255 |
+
|
| 256 |
+
# Run examples
|
| 257 |
+
try:
|
| 258 |
+
example_direct_model_usage()
|
| 259 |
+
except Exception as e:
|
| 260 |
+
print(f"Direct model usage failed: {e}")
|
| 261 |
+
|
| 262 |
+
try:
|
| 263 |
+
example_pipeline_usage()
|
| 264 |
+
except Exception as e:
|
| 265 |
+
print(f"Pipeline usage failed: {e}")
|
| 266 |
+
|
| 267 |
+
try:
|
| 268 |
+
example_huggingface_hub_usage()
|
| 269 |
+
except Exception as e:
|
| 270 |
+
print(f"Hub usage example failed: {e}")
|
| 271 |
+
|
| 272 |
+
try:
|
| 273 |
+
example_with_real_video()
|
| 274 |
+
except Exception as e:
|
| 275 |
+
print(f"Real video example failed: {e}")
|
| 276 |
+
|
| 277 |
+
print("\n" + "=" * 50)
|
| 278 |
+
print("Examples completed!")
|
| 279 |
+
print("\nNext steps:")
|
| 280 |
+
print("1. Set up Grounding DINO and SAM2 models for segmentation")
|
| 281 |
+
print("2. Load your pretrained VINE model weights")
|
| 282 |
+
print("3. Test with your own videos")
|
| 283 |
+
print("4. Push to HuggingFace Hub for sharing")
|
vine_hf/example_visualization.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Example visualization runner for VINE
|
| 2 |
+
# - Loads a video (path, demo, or random)
|
| 3 |
+
# - Runs the VINE pipeline
|
| 4 |
+
# - Saves annotated frames and an MP4 if available
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import argparse
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
from collections.abc import Mapping, Sequence
|
| 12 |
+
|
| 13 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
| 14 |
+
from transformers import pipeline
|
| 15 |
+
|
| 16 |
+
# Set your OpenAI API key here or via environment variable
|
| 17 |
+
os.environ['OPENAI_API_KEY'] = "dummy-key"
|
| 18 |
+
|
| 19 |
+
# Local imports (workspace)
|
| 20 |
+
sys.path.append(os.path.dirname(__file__))
|
| 21 |
+
|
| 22 |
+
from vine_hf.vine_pipeline import VinePipeline # https://github.com link not needed; local path used
|
| 23 |
+
from vine_hf.vine_model import VineModel
|
| 24 |
+
from vine_hf.vine_config import VineConfig
|
| 25 |
+
from laser.loading import load_video
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def build_pipeline(args) -> VinePipeline:
|
| 29 |
+
# Register pipeline type
|
| 30 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 31 |
+
"vine-video-understanding",
|
| 32 |
+
pipeline_class=VinePipeline,
|
| 33 |
+
pt_model=VineModel,
|
| 34 |
+
type="multimodal",
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
config = VineConfig(
|
| 38 |
+
segmentation_method="grounding_dino_sam2",
|
| 39 |
+
model_name="openai/clip-vit-base-patch32",
|
| 40 |
+
# Example: load from HF repo
|
| 41 |
+
use_hf_repo=True,
|
| 42 |
+
model_repo="video-fm/vine_v0",
|
| 43 |
+
# Alternatively use a local path by setting use_hf_repo=False and local_dir/local_filename
|
| 44 |
+
box_threshold=args.box_threshold,
|
| 45 |
+
text_threshold=args.text_threshold,
|
| 46 |
+
target_fps=args.fps,
|
| 47 |
+
topk_cate=args.topk_cate,
|
| 48 |
+
visualization_dir=args.out_dir,
|
| 49 |
+
visualize=True,
|
| 50 |
+
debug_visualizations=True,
|
| 51 |
+
device=args.device,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
model = VineModel(config)
|
| 55 |
+
|
| 56 |
+
# Create pipeline instance with segmentation model paths (if provided)
|
| 57 |
+
vine_pipe = VinePipeline(
|
| 58 |
+
model=model,
|
| 59 |
+
tokenizer=None,
|
| 60 |
+
sam_config_path="//home/kevinx/LASER/video-sam2/sam2/sam2_hiera_t.yaml",
|
| 61 |
+
sam_checkpoint_path="//home/kevinx/LASER/video-sam2/sam2_hiera_tiny.pt",
|
| 62 |
+
gd_config_path="//home/kevinx/LASER/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
|
| 63 |
+
gd_checkpoint_path="//home/kevinx/LASER/GroundingDINO/weights/groundingdino_swint_ogc.pth",
|
| 64 |
+
device=args.device,
|
| 65 |
+
trust_remote_code=True,
|
| 66 |
+
)
|
| 67 |
+
return vine_pipe
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def resolve_video(args) -> np.ndarray | str:
|
| 71 |
+
# Priority: user --video -> demo video -> random frames
|
| 72 |
+
if args.video and os.path.exists(args.video):
|
| 73 |
+
return args.video
|
| 74 |
+
|
| 75 |
+
demo_video = "//home/kevinx/LASER/LASER/demo/videos/v1.mp4"
|
| 76 |
+
demo_alt = "//home/kevinx/LASER/LASER/demo/videos/v2.mp4"
|
| 77 |
+
if os.path.exists(demo_video):
|
| 78 |
+
return demo_video
|
| 79 |
+
if os.path.exists(demo_alt):
|
| 80 |
+
return demo_alt
|
| 81 |
+
|
| 82 |
+
# Fallback to random frames (uint8 HxWx3) shaped as T x H x W x 3
|
| 83 |
+
print("No video found; using random frames.")
|
| 84 |
+
rng = np.random.default_rng(0)
|
| 85 |
+
frames = rng.integers(0, 255, size=(args.rand_frames, args.height, args.width, 3), dtype=np.uint8)
|
| 86 |
+
return frames
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def main():
|
| 91 |
+
parser = argparse.ArgumentParser(description="VINE visualization example")
|
| 92 |
+
parser.add_argument("--video", type=str, default=None, help="Path to a video file")
|
| 93 |
+
parser.add_argument("--out_dir", type=str, default="output", help="Output directory")
|
| 94 |
+
parser.add_argument("--method", type=str, default="grounding_dino_sam2", choices=["sam2", "grounding_dino_sam2"], help="Segmentation method")
|
| 95 |
+
parser.add_argument("--fps", type=int, default=5, help="Target FPS for processing")
|
| 96 |
+
parser.add_argument("--box_threshold", type=float, default=0.3, help="GroundingDINO box threshold")
|
| 97 |
+
parser.add_argument("--text_threshold", type=float, default=0.3, help="GroundingDINO text threshold")
|
| 98 |
+
parser.add_argument("--topk_cate", type=int, default=5, help="Top-K categories to display")
|
| 99 |
+
parser.add_argument("--device", type=int, default=0, help="CUDA device index or -1 for CPU")
|
| 100 |
+
parser.add_argument("--debug_visualizations", action="store_true", help="Enable debug visualizations")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
args = parser.parse_args()
|
| 104 |
+
|
| 105 |
+
vine_pipe = build_pipeline(args)
|
| 106 |
+
video = resolve_video(args)
|
| 107 |
+
|
| 108 |
+
# Keywords similar to examples/tests
|
| 109 |
+
categorical_keywords = ["dog", "frisbee", "cat"]
|
| 110 |
+
unary_keywords = ["running", "jumping", "sitting", "flying"]
|
| 111 |
+
binary_keywords = ["behind", "next to", "chasing","biting"]
|
| 112 |
+
object_pairs = [(0,1), (0, 2), (1, 2), (1, 3), (2, 3)]
|
| 113 |
+
|
| 114 |
+
print("Running VINE pipeline...")
|
| 115 |
+
call_kwargs = dict(
|
| 116 |
+
categorical_keywords=categorical_keywords,
|
| 117 |
+
unary_keywords=unary_keywords,
|
| 118 |
+
binary_keywords=binary_keywords,
|
| 119 |
+
object_pairs=object_pairs,
|
| 120 |
+
segmentation_method=args.method,
|
| 121 |
+
return_top_k=args.topk_cate,
|
| 122 |
+
include_visualizations=True,
|
| 123 |
+
debug_visualizations=args.debug_visualizations,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
results = vine_pipe(
|
| 128 |
+
video,
|
| 129 |
+
**call_kwargs,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Normalize pipeline output to a dict (can be dict or list[dict])
|
| 133 |
+
if isinstance(results, Mapping):
|
| 134 |
+
result = results
|
| 135 |
+
elif isinstance(results, Sequence) and results and isinstance(results[0], Mapping):
|
| 136 |
+
result = results[0]
|
| 137 |
+
else:
|
| 138 |
+
result = {}
|
| 139 |
+
|
| 140 |
+
# Print brief summary
|
| 141 |
+
summary = result.get("summary", {}) if isinstance(result, dict) else {}
|
| 142 |
+
print("Summary:", summary)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
main()
|
vine_hf/example_with_pretrained_vine.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example usage of VINE HuggingFace interface with pretrained VINE weights
|
| 3 |
+
|
| 4 |
+
This script demonstrates how to use the VINE model with your pretrained weights
|
| 5 |
+
from the ensemble format or from video-fm/vine_v0.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import pipeline
|
| 12 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
| 13 |
+
|
| 14 |
+
# Set your OpenAI API key here or via environment variable
|
| 15 |
+
#os.environ['OPENAI_API_KEY'] = "dummy-key"
|
| 16 |
+
|
| 17 |
+
# Add the parent directory to the path to import vine_hf
|
| 18 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 19 |
+
|
| 20 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def example_with_local_pretrained_weights():
|
| 24 |
+
print("=== Using Local Pretrained VINE Weights ===")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Download https://huggingface.co/video-fm/vine_v0/tree/main/laser_model_v1.pt
|
| 28 |
+
pretrained_vine_file = "/path/to/your/local/laser_model_v1.pt" # Replace with your local path
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Create configuration with your pretrained path (local file)
|
| 32 |
+
config = VineConfig(
|
| 33 |
+
model_name="openai/clip-vit-base-patch32",
|
| 34 |
+
segmentation_method="grounding_dino_sam2",
|
| 35 |
+
target_fps=1,
|
| 36 |
+
visualize=True,
|
| 37 |
+
visualization_dir="path/to/visualization/dir",
|
| 38 |
+
debug_visualizations=True,
|
| 39 |
+
use_hf_repo=False,
|
| 40 |
+
local_dir=os.path.dirname(pretrained_vine_file),
|
| 41 |
+
local_filename=os.path.basename(pretrained_vine_file),
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Method 1: Initialize model directly
|
| 45 |
+
print("Method 1: Direct model initialization")
|
| 46 |
+
vine_model = VineModel(config)
|
| 47 |
+
print(f"✓ Model initialized with pretrained weights from: {pretrained_vine_file}")
|
| 48 |
+
|
| 49 |
+
# Method 2: Use the from_pretrained_vine class method
|
| 50 |
+
print("\nMethod 2: Using from_pretrained_vine class method")
|
| 51 |
+
vine_model_2 = VineModel.from_pretrained_vine(
|
| 52 |
+
model_path=pretrained_vine_file,
|
| 53 |
+
config=config,
|
| 54 |
+
epoch=0 # Specify epoch number
|
| 55 |
+
)
|
| 56 |
+
print("✓ Model loaded using from_pretrained_vine method")
|
| 57 |
+
|
| 58 |
+
return vine_model
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def example_with_huggingface_hub():
|
| 62 |
+
"""Example using VINE weights from HuggingFace Hub."""
|
| 63 |
+
print("\n=== Using HuggingFace Hub Weights ===")
|
| 64 |
+
|
| 65 |
+
# Create configuration to use HuggingFace Hub weights
|
| 66 |
+
config = VineConfig(
|
| 67 |
+
model_name="openai/clip-vit-base-patch32",
|
| 68 |
+
use_hf_repo=True,
|
| 69 |
+
model_repo="video-fm/vine_v0", # Your HF Hub model
|
| 70 |
+
segmentation_method="grounding_dino_sam2",
|
| 71 |
+
visualize=True,
|
| 72 |
+
visualization_dir="path/to/visualization/dir",
|
| 73 |
+
debug_visualizations=True,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
# Initialize model (will try to load from HF Hub)
|
| 78 |
+
vine_model = VineModel(config)
|
| 79 |
+
print("✓ Model loaded from HuggingFace Hub: video-fm/vine_v0")
|
| 80 |
+
return vine_model
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"✗ Could not load from HuggingFace Hub: {e}")
|
| 83 |
+
print("Make sure your model is pushed to video-fm/vine_v0")
|
| 84 |
+
return None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def example_pipeline_with_pretrained():
|
| 88 |
+
"""Example using pipeline with pretrained VINE weights."""
|
| 89 |
+
print("\n=== Pipeline with Pretrained VINE ===")
|
| 90 |
+
|
| 91 |
+
# Register the pipeline
|
| 92 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 93 |
+
"vine-video-understanding",
|
| 94 |
+
pipeline_class=VinePipeline,
|
| 95 |
+
pt_model=VineModel,
|
| 96 |
+
type="multimodal",
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Create configuration with your weights
|
| 100 |
+
pretrained_vine_file = "/path/to/your/local/laser_model_v1.pt" # Replace with your local path
|
| 101 |
+
config = VineConfig(
|
| 102 |
+
model_name="openai/clip-vit-base-patch32",
|
| 103 |
+
segmentation_method="grounding_dino_sam2",
|
| 104 |
+
visualize=True,
|
| 105 |
+
visualization_dir="path/to/visualization/dir",
|
| 106 |
+
debug_visualizations=True,
|
| 107 |
+
use_hf_repo=False,
|
| 108 |
+
local_dir=os.path.dirname(pretrained_vine_file),
|
| 109 |
+
local_filename=os.path.basename(pretrained_vine_file),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Create model with pretrained weights
|
| 113 |
+
vine_model = VineModel(config)
|
| 114 |
+
|
| 115 |
+
# Create pipeline with segmentation model paths
|
| 116 |
+
vine_pipeline = VinePipeline(
|
| 117 |
+
model=vine_model,
|
| 118 |
+
tokenizer=None,
|
| 119 |
+
sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
|
| 120 |
+
sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
|
| 121 |
+
gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
|
| 122 |
+
gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
|
| 123 |
+
device=0
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
print("✓ Pipeline created with pretrained VINE weights")
|
| 127 |
+
|
| 128 |
+
# Example usage (would require actual video file)
|
| 129 |
+
demo_video = os.path.join(os.path.dirname(__file__), "../demo/videos/v1.mp4")
|
| 130 |
+
|
| 131 |
+
if os.path.exists(demo_video):
|
| 132 |
+
print(f"Found demo video: {demo_video}")
|
| 133 |
+
print("Example pipeline call:")
|
| 134 |
+
print(f"results = vine_pipeline(")
|
| 135 |
+
print(f" '{demo_video}',")
|
| 136 |
+
print(f" categorical_keywords=['human', 'dog', 'frisbee'],")
|
| 137 |
+
print(f" unary_keywords=['running', 'jumping', 'sitting'],")
|
| 138 |
+
print(f" binary_keywords=['behind', 'chasing', 'next to']")
|
| 139 |
+
print(f" debug_visualizations=True")
|
| 140 |
+
print(f")")
|
| 141 |
+
|
| 142 |
+
# Uncomment to actually run (requires segmentation models)
|
| 143 |
+
# results = vine_pipeline(
|
| 144 |
+
# demo_video,
|
| 145 |
+
# categorical_keywords=['human', 'dog', 'frisbee'],
|
| 146 |
+
# unary_keywords=['running', 'jumping', 'sitting'],
|
| 147 |
+
# binary_keywords=['behind', 'chasing', 'next to'],
|
| 148 |
+
# debug_visualizations=True,
|
| 149 |
+
# )
|
| 150 |
+
# print("Results:", results['summary'])
|
| 151 |
+
|
| 152 |
+
return vine_pipeline
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def example_manual_weight_loading():
|
| 157 |
+
"""Example of manually loading weights after model creation."""
|
| 158 |
+
print("\n=== Manual Weight Loading ===")
|
| 159 |
+
|
| 160 |
+
# Create model with base CLIP weights
|
| 161 |
+
# No pretrained path: create base config (no HF repo or local file configured)
|
| 162 |
+
config = VineConfig()
|
| 163 |
+
vine_model = VineModel(config)
|
| 164 |
+
print("✓ Model created with base CLIP weights")
|
| 165 |
+
model_dir = "/path/to/your/local/ensemble/model_dir.pt" # Replace with your model directory
|
| 166 |
+
|
| 167 |
+
if os.path.exists(model_dir):
|
| 168 |
+
success = vine_model.load_pretrained_vine_weights(model_dir, epoch=0)
|
| 169 |
+
if success:
|
| 170 |
+
print("✓ Successfully loaded pretrained VINE weights manually")
|
| 171 |
+
else:
|
| 172 |
+
print("✗ Failed to load pretrained weights")
|
| 173 |
+
else:
|
| 174 |
+
print(f"✗ Model directory not found: {model_dir}")
|
| 175 |
+
|
| 176 |
+
return vine_model
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def compare_model_outputs():
|
| 180 |
+
"""Compare outputs between base CLIP and pretrained VINE."""
|
| 181 |
+
print("\n=== Comparing Model Outputs ===")
|
| 182 |
+
|
| 183 |
+
# Create dummy data for testing
|
| 184 |
+
video_frames = torch.randn(3, 224, 224, 3) * 255 # 3 frames
|
| 185 |
+
video_frames = video_frames.clamp(0, 255).byte()
|
| 186 |
+
|
| 187 |
+
masks = {
|
| 188 |
+
0: {1: torch.ones(224, 224, 1)},
|
| 189 |
+
1: {1: torch.ones(224, 224, 1)},
|
| 190 |
+
2: {1: torch.ones(224, 224, 1)}
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
bboxes = {
|
| 194 |
+
0: {1: [50, 50, 150, 150]},
|
| 195 |
+
1: {1: [52, 52, 152, 152]},
|
| 196 |
+
2: {1: [54, 54, 154, 154]}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
keywords = ['human', 'dog', 'frisbee']
|
| 200 |
+
|
| 201 |
+
# Model 1: Base CLIP
|
| 202 |
+
print("Creating model with base CLIP weights...")
|
| 203 |
+
config_base = VineConfig()
|
| 204 |
+
model_base = VineModel(config_base)
|
| 205 |
+
|
| 206 |
+
# Model 2: Pretrained VINE (if available)
|
| 207 |
+
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
|
| 208 |
+
model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
|
| 209 |
+
|
| 210 |
+
if os.path.exists(model_dir):
|
| 211 |
+
print("Creating model with pretrained VINE weights...")
|
| 212 |
+
config_vine = VineConfig(
|
| 213 |
+
use_hf_repo=False,
|
| 214 |
+
local_dir=model_dir,
|
| 215 |
+
local_filename=None,
|
| 216 |
+
)
|
| 217 |
+
model_vine = VineModel(config_vine)
|
| 218 |
+
|
| 219 |
+
print("\nComparing predictions...")
|
| 220 |
+
|
| 221 |
+
# Get predictions from both models
|
| 222 |
+
with torch.no_grad():
|
| 223 |
+
results_base = model_base.predict(
|
| 224 |
+
video_frames=video_frames,
|
| 225 |
+
masks=masks,
|
| 226 |
+
bboxes=bboxes,
|
| 227 |
+
categorical_keywords=keywords,
|
| 228 |
+
return_top_k=3
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
results_vine = model_vine.predict(
|
| 232 |
+
video_frames=video_frames,
|
| 233 |
+
masks=masks,
|
| 234 |
+
bboxes=bboxes,
|
| 235 |
+
categorical_keywords=keywords,
|
| 236 |
+
return_top_k=3
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
print("Base CLIP confidence scores:", results_base['confidence_scores'])
|
| 240 |
+
print("Pretrained VINE confidence scores:", results_vine['confidence_scores'])
|
| 241 |
+
|
| 242 |
+
print("✓ Successfully compared both models")
|
| 243 |
+
else:
|
| 244 |
+
print(f"Pretrained model not found at: {model_dir}")
|
| 245 |
+
print("Skipping comparison")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
if __name__ == "__main__":
|
| 249 |
+
print("VINE HuggingFace Interface - Pretrained Weights Examples")
|
| 250 |
+
print("=" * 60)
|
| 251 |
+
|
| 252 |
+
try:
|
| 253 |
+
# Test local pretrained weights
|
| 254 |
+
model1 = example_with_local_pretrained_weights()
|
| 255 |
+
except Exception as e:
|
| 256 |
+
print(f"Local weights example failed: {e}")
|
| 257 |
+
|
| 258 |
+
try:
|
| 259 |
+
# Test HuggingFace Hub weights
|
| 260 |
+
model2 = example_with_huggingface_hub()
|
| 261 |
+
except Exception as e:
|
| 262 |
+
print(f"HuggingFace Hub example failed: {e}")
|
| 263 |
+
|
| 264 |
+
try:
|
| 265 |
+
# Test pipeline with pretrained weights
|
| 266 |
+
pipeline = example_pipeline_with_pretrained()
|
| 267 |
+
except Exception as e:
|
| 268 |
+
print(f"Pipeline example failed: {e}")
|
| 269 |
+
|
| 270 |
+
# try:
|
| 271 |
+
# # Test manual weight loading
|
| 272 |
+
# #model3 = example_manual_weight_loading()
|
| 273 |
+
# except Exception as e:
|
| 274 |
+
# print(f"Manual loading example failed: {e}")
|
| 275 |
+
|
| 276 |
+
# try:
|
| 277 |
+
# # Compare model outputs
|
| 278 |
+
# #compare_model_outputs()
|
| 279 |
+
# except Exception as e:
|
| 280 |
+
# print(f"Comparison example failed: {e}")
|
| 281 |
+
|
| 282 |
+
print("\n" + "=" * 60)
|
| 283 |
+
print("Examples completed!")
|
| 284 |
+
print("\nUsage Summary:")
|
| 285 |
+
print("1. Configure VineConfig with `use_hf_repo` + `model_repo` for Hub models, or `use_hf_repo=False` + `local_dir`/`local_filename` for local weights")
|
| 286 |
+
print("2. Use VineModel.from_pretrained_vine() for direct loading")
|
| 287 |
+
|
vine_hf/flattening.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
MaskType = Union[np.ndarray, torch.Tensor]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _to_numpy_mask(mask: MaskType) -> np.ndarray:
|
| 14 |
+
"""
|
| 15 |
+
Convert assorted mask formats to a 2D numpy boolean array.
|
| 16 |
+
"""
|
| 17 |
+
if isinstance(mask, torch.Tensor):
|
| 18 |
+
mask_np = mask.detach().cpu().numpy()
|
| 19 |
+
else:
|
| 20 |
+
mask_np = np.asarray(mask)
|
| 21 |
+
|
| 22 |
+
# Remove singleton dimensions at the front/back
|
| 23 |
+
while mask_np.ndim > 2 and mask_np.shape[0] == 1:
|
| 24 |
+
mask_np = np.squeeze(mask_np, axis=0)
|
| 25 |
+
if mask_np.ndim > 2 and mask_np.shape[-1] == 1:
|
| 26 |
+
mask_np = np.squeeze(mask_np, axis=-1)
|
| 27 |
+
|
| 28 |
+
if mask_np.ndim != 2:
|
| 29 |
+
raise ValueError(f"Expected mask to be 2D after squeezing, got shape {mask_np.shape}")
|
| 30 |
+
|
| 31 |
+
return mask_np.astype(bool)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _mask_to_bbox(mask: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
|
| 35 |
+
"""
|
| 36 |
+
Compute a bounding box for a 2D boolean mask.
|
| 37 |
+
"""
|
| 38 |
+
if not mask.any():
|
| 39 |
+
return None
|
| 40 |
+
rows, cols = np.nonzero(mask)
|
| 41 |
+
y_min, y_max = rows.min(), rows.max()
|
| 42 |
+
x_min, x_max = cols.min(), cols.max()
|
| 43 |
+
return x_min, y_min, x_max, y_max
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def flatten_segments_for_batch(
|
| 47 |
+
video_id: int,
|
| 48 |
+
segments: Dict[int, Dict[int, MaskType]],
|
| 49 |
+
bbox_min_dim: int = 5,
|
| 50 |
+
) -> Dict[str, List]:
|
| 51 |
+
"""
|
| 52 |
+
Flatten nested segmentation data into batched lists suitable for predicate
|
| 53 |
+
models or downstream visualizations. Mirrors the notebook helper but is
|
| 54 |
+
robust to differing mask dtypes/shapes.
|
| 55 |
+
"""
|
| 56 |
+
batched_object_ids: List[Tuple[int, int, int]] = []
|
| 57 |
+
batched_masks: List[np.ndarray] = []
|
| 58 |
+
batched_bboxes: List[Tuple[int, int, int, int]] = []
|
| 59 |
+
frame_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
|
| 60 |
+
|
| 61 |
+
for frame_id, frame_objects in segments.items():
|
| 62 |
+
valid_objects: List[int] = []
|
| 63 |
+
for object_id, raw_mask in frame_objects.items():
|
| 64 |
+
mask = _to_numpy_mask(raw_mask)
|
| 65 |
+
bbox = _mask_to_bbox(mask)
|
| 66 |
+
if bbox is None:
|
| 67 |
+
continue
|
| 68 |
+
|
| 69 |
+
x_min, y_min, x_max, y_max = bbox
|
| 70 |
+
if abs(y_max - y_min) < bbox_min_dim or abs(x_max - x_min) < bbox_min_dim:
|
| 71 |
+
continue
|
| 72 |
+
|
| 73 |
+
valid_objects.append(object_id)
|
| 74 |
+
batched_object_ids.append((video_id, frame_id, object_id))
|
| 75 |
+
batched_masks.append(mask)
|
| 76 |
+
batched_bboxes.append(bbox)
|
| 77 |
+
|
| 78 |
+
for i in valid_objects:
|
| 79 |
+
for j in valid_objects:
|
| 80 |
+
if i == j:
|
| 81 |
+
continue
|
| 82 |
+
frame_pairs.append((video_id, frame_id, (i, j)))
|
| 83 |
+
|
| 84 |
+
return {
|
| 85 |
+
"object_ids": batched_object_ids,
|
| 86 |
+
"masks": batched_masks,
|
| 87 |
+
"bboxes": batched_bboxes,
|
| 88 |
+
"pairs": frame_pairs,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def extract_valid_object_pairs(
|
| 93 |
+
batched_object_ids: Sequence[Tuple[int, int, int]],
|
| 94 |
+
interested_object_pairs: Optional[Iterable[Tuple[int, int]]] = None,
|
| 95 |
+
) -> List[Tuple[int, int, Tuple[int, int]]]:
|
| 96 |
+
"""
|
| 97 |
+
Filter object pairs per frame. If `interested_object_pairs` is provided, only
|
| 98 |
+
emit those combinations when both objects are present; otherwise emit all
|
| 99 |
+
permutations (i, j) with i != j for each frame.
|
| 100 |
+
"""
|
| 101 |
+
frame_to_objects: Dict[Tuple[int, int], set] = defaultdict(set)
|
| 102 |
+
for vid, fid, oid in batched_object_ids:
|
| 103 |
+
frame_to_objects[(vid, fid)].add(oid)
|
| 104 |
+
|
| 105 |
+
interested = (
|
| 106 |
+
list(interested_object_pairs)
|
| 107 |
+
if interested_object_pairs is not None
|
| 108 |
+
else None
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
valid_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
|
| 112 |
+
for (vid, fid), object_ids in frame_to_objects.items():
|
| 113 |
+
if interested:
|
| 114 |
+
for src, dst in interested:
|
| 115 |
+
if src in object_ids and dst in object_ids:
|
| 116 |
+
valid_pairs.append((vid, fid, (src, dst)))
|
| 117 |
+
else:
|
| 118 |
+
for src in object_ids:
|
| 119 |
+
for dst in object_ids:
|
| 120 |
+
if src == dst:
|
| 121 |
+
continue
|
| 122 |
+
valid_pairs.append((vid, fid, (src, dst)))
|
| 123 |
+
|
| 124 |
+
return valid_pairs
|
vine_hf/push_to_hub.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to push VINE model to HuggingFace Hub
|
| 3 |
+
|
| 4 |
+
This script helps you push your trained VINE model to the HuggingFace Hub
|
| 5 |
+
for easy sharing and distribution.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import torch
|
| 11 |
+
import argparse
|
| 12 |
+
from huggingface_hub import notebook_login
|
| 13 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
| 14 |
+
|
| 15 |
+
# Add the parent directory to the path to import vine_hf
|
| 16 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 17 |
+
|
| 18 |
+
os.environ['OPENAI_API_KEY'] = "dummy-key"
|
| 19 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def push_vine_to_hub(
|
| 23 |
+
model_weights_path: str,
|
| 24 |
+
repo_name: str,
|
| 25 |
+
model_name: str = "openai/clip-vit-base-patch32",
|
| 26 |
+
segmentation_method: str = "grounding_dino_sam2",
|
| 27 |
+
commit_message: str = "Upload VINE model",
|
| 28 |
+
private: bool = False
|
| 29 |
+
):
|
| 30 |
+
"""
|
| 31 |
+
Push VINE model to HuggingFace Hub.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
model_weights_path: Path to the trained model weights (.pth file)
|
| 35 |
+
repo_name: Name for the repository (e.g., "username/vine-model")
|
| 36 |
+
model_name: CLIP model backbone name
|
| 37 |
+
segmentation_method: Segmentation method used
|
| 38 |
+
commit_message: Commit message for the push
|
| 39 |
+
private: Whether to create a private repository
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
print("=== Pushing VINE Model to HuggingFace Hub ===")
|
| 43 |
+
|
| 44 |
+
# 1. Create configuration
|
| 45 |
+
print(f"Creating configuration with backbone: {model_name}")
|
| 46 |
+
config = VineConfig(
|
| 47 |
+
model_name=model_name,
|
| 48 |
+
segmentation_method=segmentation_method
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# 2. Initialize model
|
| 52 |
+
print("Initializing model...")
|
| 53 |
+
model = VineModel(config)
|
| 54 |
+
|
| 55 |
+
# 3. Load trained weights
|
| 56 |
+
if os.path.exists(model_weights_path):
|
| 57 |
+
print(f"Loading weights from: {model_weights_path}")
|
| 58 |
+
try:
|
| 59 |
+
# Try loading with weights_only=False for compatibility
|
| 60 |
+
weights = torch.load(model_weights_path, map_location='cpu', weights_only=False)
|
| 61 |
+
|
| 62 |
+
# Handle different weight formats
|
| 63 |
+
if isinstance(weights, dict):
|
| 64 |
+
if 'state_dict' in weights:
|
| 65 |
+
model.load_state_dict(weights['state_dict'])
|
| 66 |
+
elif 'model' in weights:
|
| 67 |
+
model.load_state_dict(weights['model'])
|
| 68 |
+
else:
|
| 69 |
+
model.load_state_dict(weights)
|
| 70 |
+
else:
|
| 71 |
+
# Assume it's the model directly
|
| 72 |
+
model = weights
|
| 73 |
+
|
| 74 |
+
print("✓ Weights loaded successfully")
|
| 75 |
+
except Exception as e:
|
| 76 |
+
print(f"✗ Error loading weights: {e}")
|
| 77 |
+
print("Please check your weights file format")
|
| 78 |
+
return False
|
| 79 |
+
else:
|
| 80 |
+
print(f"✗ Weights file not found: {model_weights_path}")
|
| 81 |
+
return False
|
| 82 |
+
|
| 83 |
+
# 4. Register for auto classes
|
| 84 |
+
print("Registering for auto classes...")
|
| 85 |
+
config.register_for_auto_class()
|
| 86 |
+
model.register_for_auto_class("AutoModel")
|
| 87 |
+
|
| 88 |
+
# 5. Register pipeline
|
| 89 |
+
print("Registering pipeline...")
|
| 90 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 91 |
+
"vine-video-understanding",
|
| 92 |
+
pipeline_class=VinePipeline,
|
| 93 |
+
pt_model=VineModel,
|
| 94 |
+
type="multimodal",
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# 6. Create pipeline instance
|
| 98 |
+
print("Creating pipeline...")
|
| 99 |
+
vine_pipeline = VinePipeline(model=model, tokenizer=None)
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
# 7. Push configuration to hub
|
| 103 |
+
print(f"Pushing configuration to {repo_name}...")
|
| 104 |
+
config.push_to_hub(
|
| 105 |
+
repo_name,
|
| 106 |
+
commit_message=f"{commit_message} - config",
|
| 107 |
+
private=private
|
| 108 |
+
)
|
| 109 |
+
print("✓ Configuration pushed successfully")
|
| 110 |
+
|
| 111 |
+
# 8. Push model to hub
|
| 112 |
+
print(f"Pushing model to {repo_name}...")
|
| 113 |
+
model.push_to_hub(
|
| 114 |
+
repo_name,
|
| 115 |
+
commit_message=f"{commit_message} - model",
|
| 116 |
+
private=private
|
| 117 |
+
)
|
| 118 |
+
print("✓ Model pushed successfully")
|
| 119 |
+
|
| 120 |
+
# 9. Push pipeline to hub
|
| 121 |
+
print(f"Pushing pipeline to {repo_name}...")
|
| 122 |
+
vine_pipeline.push_to_hub(
|
| 123 |
+
repo_name,
|
| 124 |
+
commit_message=f"{commit_message} - pipeline",
|
| 125 |
+
private=private
|
| 126 |
+
)
|
| 127 |
+
print("✓ Pipeline pushed successfully")
|
| 128 |
+
|
| 129 |
+
print(f"\n🎉 Successfully pushed VINE model to: https://huggingface.co/{repo_name}")
|
| 130 |
+
print(f"\nTo use your model:")
|
| 131 |
+
print(f"```python")
|
| 132 |
+
print(f"from transformers import pipeline")
|
| 133 |
+
print(f"")
|
| 134 |
+
print(f"vine_pipeline = pipeline(")
|
| 135 |
+
print(f" 'vine-video-understanding',")
|
| 136 |
+
print(f" model='{repo_name}',")
|
| 137 |
+
print(f" trust_remote_code=True")
|
| 138 |
+
print(f")")
|
| 139 |
+
print(f"")
|
| 140 |
+
print(f"results = vine_pipeline(")
|
| 141 |
+
print(f" 'path/to/video.mp4',")
|
| 142 |
+
print(f" categorical_keywords=['human', 'dog', 'frisbee'],")
|
| 143 |
+
print(f" unary_keywords=['running', 'jumping'],")
|
| 144 |
+
print(f" binary_keywords=['chasing', 'behind']")
|
| 145 |
+
print(f")")
|
| 146 |
+
print(f"```")
|
| 147 |
+
|
| 148 |
+
return True
|
| 149 |
+
|
| 150 |
+
except Exception as e:
|
| 151 |
+
print(f"✗ Error pushing to hub: {e}")
|
| 152 |
+
print("Please check your HuggingFace credentials and repository permissions")
|
| 153 |
+
return False
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def main():
|
| 157 |
+
parser = argparse.ArgumentParser(description="Push VINE model to HuggingFace Hub")
|
| 158 |
+
|
| 159 |
+
parser.add_argument(
|
| 160 |
+
"--weights",
|
| 161 |
+
type=str,
|
| 162 |
+
required=True,
|
| 163 |
+
help="Path to the trained model weights (.pth file)"
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
parser.add_argument(
|
| 167 |
+
"--repo",
|
| 168 |
+
type=str,
|
| 169 |
+
required=True,
|
| 170 |
+
help="Repository name (e.g., 'username/vine-model')"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
parser.add_argument(
|
| 174 |
+
"--model-name",
|
| 175 |
+
type=str,
|
| 176 |
+
default="openai/clip-vit-base-patch32",
|
| 177 |
+
help="CLIP model backbone name"
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
parser.add_argument(
|
| 181 |
+
"--segmentation",
|
| 182 |
+
type=str,
|
| 183 |
+
default="grounding_dino_sam2",
|
| 184 |
+
choices=["sam2", "grounding_dino_sam2"],
|
| 185 |
+
help="Segmentation method"
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
parser.add_argument(
|
| 189 |
+
"--message",
|
| 190 |
+
type=str,
|
| 191 |
+
default="Upload VINE model",
|
| 192 |
+
help="Commit message"
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
parser.add_argument(
|
| 196 |
+
"--private",
|
| 197 |
+
action="store_true",
|
| 198 |
+
help="Create private repository"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--login",
|
| 203 |
+
action="store_true",
|
| 204 |
+
help="Login to HuggingFace Hub first"
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
args = parser.parse_args()
|
| 208 |
+
|
| 209 |
+
# Login if requested
|
| 210 |
+
if args.login:
|
| 211 |
+
print("Logging in to HuggingFace Hub...")
|
| 212 |
+
notebook_login()
|
| 213 |
+
|
| 214 |
+
# Push model
|
| 215 |
+
success = push_vine_to_hub(
|
| 216 |
+
model_weights_path=args.weights,
|
| 217 |
+
repo_name=args.repo,
|
| 218 |
+
model_name=args.model_name,
|
| 219 |
+
segmentation_method=args.segmentation,
|
| 220 |
+
commit_message=args.message,
|
| 221 |
+
private=args.private
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
if success:
|
| 225 |
+
print("\n✅ Model successfully pushed to HuggingFace Hub!")
|
| 226 |
+
else:
|
| 227 |
+
print("\n❌ Failed to push model to HuggingFace Hub")
|
| 228 |
+
sys.exit(1)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
if __name__ == "__main__":
|
| 232 |
+
main()
|
vine_hf/push_to_video_fm.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to push VINE model to video-fm organization on HuggingFace Hub
|
| 3 |
+
|
| 4 |
+
This script pushes the VINE architecture (config, model, pipeline) and model weights
|
| 5 |
+
to the video-fm organization for easy sharing and distribution.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import torch
|
| 11 |
+
import argparse
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from huggingface_hub import HfApi, login
|
| 14 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
| 15 |
+
from transformers import AutoModel
|
| 16 |
+
from safetensors.torch import save_file
|
| 17 |
+
|
| 18 |
+
# Add the parent directory to path to enable vine_hf imports
|
| 19 |
+
current_dir = Path(__file__).parent
|
| 20 |
+
parent_dir = current_dir.parent
|
| 21 |
+
sys.path.insert(0, str(parent_dir))
|
| 22 |
+
|
| 23 |
+
os.environ['OPENAI_API_KEY'] = "dummy-key"
|
| 24 |
+
|
| 25 |
+
# Import from vine_hf package
|
| 26 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def push_vine_to_video_fm(
|
| 30 |
+
source_repo_or_path: str = "KevinX-Penn28/testing",
|
| 31 |
+
target_repo: str = "video-fm/vine",
|
| 32 |
+
model_name: str = "openai/clip-vit-base-patch32",
|
| 33 |
+
commit_message: str = "Upload VINE model architecture and weights",
|
| 34 |
+
private: bool = False,
|
| 35 |
+
use_local_weights: bool = False,
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Push VINE model to video-fm organization on HuggingFace Hub.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
source_repo_or_path: Source HF repo or local path with model weights
|
| 42 |
+
target_repo: Target repository (e.g., "video-fm/vine")
|
| 43 |
+
model_name: CLIP model backbone name
|
| 44 |
+
commit_message: Commit message for the push
|
| 45 |
+
private: Whether to create a private repository
|
| 46 |
+
use_local_weights: If True, source_repo_or_path is a local file path
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
print("=" * 70)
|
| 50 |
+
print("🚀 Pushing VINE Model to HuggingFace Hub - video-fm Organization")
|
| 51 |
+
print("=" * 70)
|
| 52 |
+
|
| 53 |
+
# 1. Create configuration
|
| 54 |
+
print(f"\n📝 Creating configuration with backbone: {model_name}")
|
| 55 |
+
config = VineConfig(
|
| 56 |
+
model_name=model_name,
|
| 57 |
+
segmentation_method="grounding_dino_sam2",
|
| 58 |
+
use_hf_repo=not use_local_weights,
|
| 59 |
+
model_repo=source_repo_or_path if not use_local_weights else None,
|
| 60 |
+
local_dir=str(Path(source_repo_or_path).parent) if use_local_weights else None,
|
| 61 |
+
local_filename=Path(source_repo_or_path).name if use_local_weights else None,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# 2. Initialize model (will automatically load weights from source)
|
| 65 |
+
print(f"\n🔧 Initializing model and loading weights from: {source_repo_or_path}")
|
| 66 |
+
model = VineModel(config)
|
| 67 |
+
print("✓ Model initialized with weights loaded")
|
| 68 |
+
|
| 69 |
+
# 3. Register for auto classes
|
| 70 |
+
print("\n📋 Registering for auto classes...")
|
| 71 |
+
config.register_for_auto_class()
|
| 72 |
+
model.register_for_auto_class("AutoModel")
|
| 73 |
+
print("✓ Registered for AutoModel and AutoConfig")
|
| 74 |
+
|
| 75 |
+
# 4. Register pipeline
|
| 76 |
+
print("\n🔌 Registering custom pipeline...")
|
| 77 |
+
try:
|
| 78 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 79 |
+
"vine-video-understanding",
|
| 80 |
+
pipeline_class=VinePipeline,
|
| 81 |
+
pt_model=VineModel,
|
| 82 |
+
type="multimodal",
|
| 83 |
+
)
|
| 84 |
+
print("✓ Pipeline registered")
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"⚠ Pipeline registration: {e} (may already be registered)")
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
# 5. Push configuration to hub
|
| 90 |
+
print(f"\n⬆️ Pushing configuration to {target_repo}...")
|
| 91 |
+
config.push_to_hub(
|
| 92 |
+
target_repo,
|
| 93 |
+
commit_message=f"{commit_message} - config",
|
| 94 |
+
private=private
|
| 95 |
+
)
|
| 96 |
+
print("✓ Configuration pushed successfully")
|
| 97 |
+
|
| 98 |
+
# 6. Push model to hub
|
| 99 |
+
print(f"\n⬆️ Pushing model to {target_repo}...")
|
| 100 |
+
model.push_to_hub(
|
| 101 |
+
target_repo,
|
| 102 |
+
commit_message=f"{commit_message} - model and weights",
|
| 103 |
+
private=private
|
| 104 |
+
)
|
| 105 |
+
print("✓ Model and weights pushed successfully")
|
| 106 |
+
|
| 107 |
+
# 7. Copy additional necessary files to the repo
|
| 108 |
+
print(f"\n📦 Uploading additional architecture files...")
|
| 109 |
+
api = HfApi()
|
| 110 |
+
|
| 111 |
+
# Upload flattening.py and vis_utils.py as they're imported by the model
|
| 112 |
+
current_dir = Path(__file__).parent
|
| 113 |
+
additional_files = [
|
| 114 |
+
"flattening.py",
|
| 115 |
+
"vis_utils.py",
|
| 116 |
+
]
|
| 117 |
+
|
| 118 |
+
for filename in additional_files:
|
| 119 |
+
file_path = current_dir / filename
|
| 120 |
+
if file_path.exists():
|
| 121 |
+
api.upload_file(
|
| 122 |
+
path_or_fileobj=str(file_path),
|
| 123 |
+
path_in_repo=filename,
|
| 124 |
+
repo_id=target_repo,
|
| 125 |
+
commit_message=f"Add {filename}",
|
| 126 |
+
)
|
| 127 |
+
print(f"✓ Uploaded {filename}")
|
| 128 |
+
else:
|
| 129 |
+
print(f"⚠ Warning: {filename} not found at {file_path}")
|
| 130 |
+
|
| 131 |
+
# 8. Upload README if it exists
|
| 132 |
+
readme_path = current_dir / "README.md"
|
| 133 |
+
if readme_path.exists():
|
| 134 |
+
api.upload_file(
|
| 135 |
+
path_or_fileobj=str(readme_path),
|
| 136 |
+
path_in_repo="README.md",
|
| 137 |
+
repo_id=target_repo,
|
| 138 |
+
commit_message="Add README documentation",
|
| 139 |
+
)
|
| 140 |
+
print("✓ Uploaded README.md")
|
| 141 |
+
|
| 142 |
+
print("\n" + "=" * 70)
|
| 143 |
+
print("🎉 Successfully pushed VINE model to HuggingFace Hub!")
|
| 144 |
+
print("=" * 70)
|
| 145 |
+
print(f"\n📍 Model URL: https://huggingface.co/{target_repo}")
|
| 146 |
+
print(f"\n📚 To use your model:")
|
| 147 |
+
print(f"""
|
| 148 |
+
```python
|
| 149 |
+
from transformers import AutoModel, AutoConfig
|
| 150 |
+
from vine_hf import VineConfig, VineModel, VinePipeline
|
| 151 |
+
|
| 152 |
+
# Option 1: Load with AutoModel
|
| 153 |
+
model = AutoModel.from_pretrained('{target_repo}', trust_remote_code=True)
|
| 154 |
+
|
| 155 |
+
# Option 2: Load with VineModel directly
|
| 156 |
+
config = VineConfig.from_pretrained('{target_repo}')
|
| 157 |
+
model = VineModel.from_pretrained('{target_repo}')
|
| 158 |
+
|
| 159 |
+
# Option 3: Use with pipeline
|
| 160 |
+
from transformers import pipeline
|
| 161 |
+
|
| 162 |
+
vine_pipeline = pipeline(
|
| 163 |
+
'vine-video-understanding',
|
| 164 |
+
model='{target_repo}',
|
| 165 |
+
trust_remote_code=True
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
results = vine_pipeline(
|
| 169 |
+
'path/to/video.mp4',
|
| 170 |
+
categorical_keywords=['human', 'dog', 'frisbee'],
|
| 171 |
+
unary_keywords=['running', 'jumping'],
|
| 172 |
+
binary_keywords=['chasing', 'behind']
|
| 173 |
+
)
|
| 174 |
+
```
|
| 175 |
+
""")
|
| 176 |
+
|
| 177 |
+
return True
|
| 178 |
+
|
| 179 |
+
except Exception as e:
|
| 180 |
+
print(f"\n❌ Error pushing to hub: {e}")
|
| 181 |
+
import traceback
|
| 182 |
+
traceback.print_exc()
|
| 183 |
+
print("\nPlease check:")
|
| 184 |
+
print(" - HuggingFace credentials (run: huggingface-cli login)")
|
| 185 |
+
print(" - Repository permissions for video-fm organization")
|
| 186 |
+
print(" - Network connectivity")
|
| 187 |
+
return False
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def main():
|
| 191 |
+
parser = argparse.ArgumentParser(
|
| 192 |
+
description="Push VINE model to video-fm organization on HuggingFace Hub"
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
parser.add_argument(
|
| 196 |
+
"--source",
|
| 197 |
+
type=str,
|
| 198 |
+
default="KevinX-Penn28/testing",
|
| 199 |
+
help="Source HF repo or local path with model weights (default: KevinX-Penn28/testing)"
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
parser.add_argument(
|
| 203 |
+
"--target",
|
| 204 |
+
type=str,
|
| 205 |
+
default="video-fm/vine",
|
| 206 |
+
help="Target repository in video-fm org (default: video-fm/vine)"
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
parser.add_argument(
|
| 210 |
+
"--model-name",
|
| 211 |
+
type=str,
|
| 212 |
+
default="openai/clip-vit-base-patch32",
|
| 213 |
+
help="CLIP model backbone name"
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--message",
|
| 218 |
+
type=str,
|
| 219 |
+
default="Upload VINE model architecture and weights",
|
| 220 |
+
help="Commit message"
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--private",
|
| 225 |
+
action="store_true",
|
| 226 |
+
help="Create private repository"
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
parser.add_argument(
|
| 230 |
+
"--local-weights",
|
| 231 |
+
action="store_true",
|
| 232 |
+
help="Use local weights file instead of HF repo"
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
args = parser.parse_args()
|
| 236 |
+
|
| 237 |
+
# Check login status
|
| 238 |
+
try:
|
| 239 |
+
api = HfApi()
|
| 240 |
+
user_info = api.whoami()
|
| 241 |
+
print(f"✓ Logged in as: {user_info['name']}")
|
| 242 |
+
|
| 243 |
+
# Check if user has access to video-fm org
|
| 244 |
+
orgs = [org['name'] for org in user_info.get('orgs', [])]
|
| 245 |
+
if 'video-fm' in orgs:
|
| 246 |
+
print(f"✓ Confirmed access to video-fm organization")
|
| 247 |
+
else:
|
| 248 |
+
print(f"⚠ Warning: You may not have access to video-fm organization")
|
| 249 |
+
print(f" Your organizations: {orgs}")
|
| 250 |
+
except Exception as e:
|
| 251 |
+
print(f"❌ Not logged in to HuggingFace. Please run: huggingface-cli login")
|
| 252 |
+
print(f" Or use: python -c 'from huggingface_hub import login; login()'")
|
| 253 |
+
sys.exit(1)
|
| 254 |
+
|
| 255 |
+
# Push model
|
| 256 |
+
success = push_vine_to_video_fm(
|
| 257 |
+
source_repo_or_path=args.source,
|
| 258 |
+
target_repo=args.target,
|
| 259 |
+
model_name=args.model_name,
|
| 260 |
+
commit_message=args.message,
|
| 261 |
+
private=args.private,
|
| 262 |
+
use_local_weights=args.local_weights,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if success:
|
| 266 |
+
print("\n✅ Successfully completed!")
|
| 267 |
+
sys.exit(0)
|
| 268 |
+
else:
|
| 269 |
+
print("\n❌ Push failed!")
|
| 270 |
+
sys.exit(1)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
if __name__ == "__main__":
|
| 274 |
+
main()
|
vine_hf/setup.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Setup script for VINE HuggingFace Interface
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from setuptools import setup
|
| 6 |
+
|
| 7 |
+
with open("README.md", "r", encoding="utf-8") as fh:
|
| 8 |
+
long_description = fh.read()
|
| 9 |
+
|
| 10 |
+
setup(
|
| 11 |
+
name="vine-hf",
|
| 12 |
+
version="1.0.0",
|
| 13 |
+
author="LASER Team",
|
| 14 |
+
author_email="your-email@example.com",
|
| 15 |
+
description="HuggingFace interface for VINE (Video Understanding with Natural Language)",
|
| 16 |
+
long_description=long_description,
|
| 17 |
+
long_description_content_type="text/markdown",
|
| 18 |
+
url="https://github.com/your-username/vine-hf",
|
| 19 |
+
# Since all modules are in the root directory, we use py_modules instead of packages
|
| 20 |
+
py_modules=[
|
| 21 |
+
"vine_config",
|
| 22 |
+
"vine_model",
|
| 23 |
+
"vine_pipeline",
|
| 24 |
+
"vis_utils",
|
| 25 |
+
"flattening",
|
| 26 |
+
"convert_inference",
|
| 27 |
+
],
|
| 28 |
+
# Also include __init__.py to make it a package
|
| 29 |
+
packages=["vine_hf"],
|
| 30 |
+
package_dir={"vine_hf": "."},
|
| 31 |
+
classifiers=[
|
| 32 |
+
"Development Status :: 4 - Beta",
|
| 33 |
+
"Intended Audience :: Developers",
|
| 34 |
+
"Intended Audience :: Science/Research",
|
| 35 |
+
"License :: OSI Approved :: MIT License",
|
| 36 |
+
"Operating System :: OS Independent",
|
| 37 |
+
"Programming Language :: Python :: 3",
|
| 38 |
+
"Programming Language :: Python :: 3.7",
|
| 39 |
+
"Programming Language :: Python :: 3.8",
|
| 40 |
+
"Programming Language :: Python :: 3.9",
|
| 41 |
+
"Programming Language :: Python :: 3.10",
|
| 42 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 43 |
+
"Topic :: Multimedia :: Video",
|
| 44 |
+
],
|
| 45 |
+
python_requires=">=3.7",
|
| 46 |
+
install_requires=[
|
| 47 |
+
"torch>=1.9.0",
|
| 48 |
+
"torchvision>=0.10.0",
|
| 49 |
+
"transformers>=4.20.0",
|
| 50 |
+
"opencv-python>=4.5.0",
|
| 51 |
+
"pillow>=8.0.0",
|
| 52 |
+
"numpy>=1.20.0",
|
| 53 |
+
"huggingface-hub>=0.10.0",
|
| 54 |
+
"tqdm>=4.60.0",
|
| 55 |
+
],
|
| 56 |
+
extras_require={
|
| 57 |
+
"dev": [
|
| 58 |
+
"pytest>=6.0",
|
| 59 |
+
"black>=22.0",
|
| 60 |
+
"flake8>=4.0",
|
| 61 |
+
"isort>=5.0",
|
| 62 |
+
],
|
| 63 |
+
"segmentation": [
|
| 64 |
+
# Note: SAM2 and Grounding DINO need to be installed separately
|
| 65 |
+
# as they're not available on PyPI
|
| 66 |
+
],
|
| 67 |
+
},
|
| 68 |
+
entry_points={
|
| 69 |
+
"console_scripts": [
|
| 70 |
+
"vine-push-to-hub=vine_hf.push_to_hub:main",
|
| 71 |
+
],
|
| 72 |
+
},
|
| 73 |
+
)
|
vine_hf/vine_config.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import PretrainedConfig
|
| 3 |
+
from typing import List, Optional, Dict, Any, Tuple, Union
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class VineConfig(PretrainedConfig):
|
| 8 |
+
"""
|
| 9 |
+
Configuration class for VINE (Video Understanding with Natural Language) model.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
model_type = "vine"
|
| 13 |
+
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
model_name: str = "openai/clip-vit-base-patch32",
|
| 17 |
+
hidden_dim: int = 768,
|
| 18 |
+
use_hf_repo: bool = True,
|
| 19 |
+
model_repo: Optional[str] = "KevinX-Penn28/testing",
|
| 20 |
+
model_file: Optional[str] = None,
|
| 21 |
+
local_dir: Optional[str] = str(Path(__file__).resolve().parent),
|
| 22 |
+
local_filename: Optional[str] = "laser_model_v1.pkl",
|
| 23 |
+
num_top_pairs: int = 18,
|
| 24 |
+
segmentation_method: str = "grounding_dino_sam2",
|
| 25 |
+
box_threshold: float = 0.35,
|
| 26 |
+
text_threshold: float = 0.25,
|
| 27 |
+
target_fps: int = 1,
|
| 28 |
+
alpha: float = 0.5,
|
| 29 |
+
white_alpha: float = 0.8,
|
| 30 |
+
topk_cate: int = 3,
|
| 31 |
+
multi_class: bool = False,
|
| 32 |
+
output_logit: bool = False,
|
| 33 |
+
use_pretrained_cate_weights: bool = False,
|
| 34 |
+
categorical_pool: str = "mean", # "mean" or "max"
|
| 35 |
+
max_video_length: int = 100,
|
| 36 |
+
bbox_min_dim: int = 1,
|
| 37 |
+
visualize: bool = False,
|
| 38 |
+
visualization_dir: Optional[str] = None,
|
| 39 |
+
return_flattened_segments: bool = False,
|
| 40 |
+
return_valid_pairs: bool = False,
|
| 41 |
+
interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
|
| 42 |
+
debug_visualizations: bool = False,
|
| 43 |
+
device: Optional[Union[str, int]] = None,
|
| 44 |
+
**kwargs: Any,
|
| 45 |
+
):
|
| 46 |
+
self.model_name = model_name
|
| 47 |
+
self.use_hf_repo = use_hf_repo
|
| 48 |
+
if use_hf_repo:
|
| 49 |
+
self.model_repo = model_repo
|
| 50 |
+
self.model_file = model_file
|
| 51 |
+
self.local_dir = None
|
| 52 |
+
self.local_filename = None
|
| 53 |
+
else:
|
| 54 |
+
self.model_repo = None
|
| 55 |
+
self.model_file = None
|
| 56 |
+
self.local_dir = local_dir
|
| 57 |
+
self.local_filename = local_filename
|
| 58 |
+
|
| 59 |
+
self.hidden_dim = hidden_dim
|
| 60 |
+
self.num_top_pairs = num_top_pairs
|
| 61 |
+
self.segmentation_method = segmentation_method
|
| 62 |
+
self.box_threshold = box_threshold
|
| 63 |
+
self.text_threshold = text_threshold
|
| 64 |
+
self.target_fps = target_fps
|
| 65 |
+
self.alpha = alpha
|
| 66 |
+
self.white_alpha = white_alpha
|
| 67 |
+
self.topk_cate = topk_cate
|
| 68 |
+
self.multi_class = multi_class
|
| 69 |
+
self.output_logit = output_logit
|
| 70 |
+
self.use_pretrained_cate_weights = use_pretrained_cate_weights
|
| 71 |
+
self.categorical_pool = categorical_pool
|
| 72 |
+
self.max_video_length = max_video_length
|
| 73 |
+
self.bbox_min_dim = bbox_min_dim
|
| 74 |
+
self.visualize = visualize
|
| 75 |
+
self.visualization_dir = visualization_dir
|
| 76 |
+
self.return_flattened_segments = return_flattened_segments
|
| 77 |
+
self.return_valid_pairs = return_valid_pairs
|
| 78 |
+
self.interested_object_pairs = interested_object_pairs or []
|
| 79 |
+
self.debug_visualizations = debug_visualizations
|
| 80 |
+
|
| 81 |
+
if isinstance(device, int):
|
| 82 |
+
self._device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
|
| 83 |
+
else:
|
| 84 |
+
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 85 |
+
|
| 86 |
+
super().__init__(**kwargs)
|
vine_hf/vine_model.py
ADDED
|
@@ -0,0 +1,1001 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Dict, List, Tuple, Optional, Any, Union
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from safetensors.torch import load_file
|
| 9 |
+
from torch import nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torch.utils.checkpoint as cp
|
| 12 |
+
from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor
|
| 13 |
+
from huggingface_hub import snapshot_download
|
| 14 |
+
|
| 15 |
+
from .vine_config import VineConfig
|
| 16 |
+
from laser.models import llava_clip_model_v3
|
| 17 |
+
sys.modules["llava_clip_model_v3"] = llava_clip_model_v3
|
| 18 |
+
from laser.models.model_utils import (
|
| 19 |
+
extract_single_object,
|
| 20 |
+
extract_object_subject,
|
| 21 |
+
crop_image_contain_bboxes,
|
| 22 |
+
segment_list,
|
| 23 |
+
)
|
| 24 |
+
from .flattening import (
|
| 25 |
+
extract_valid_object_pairs,
|
| 26 |
+
flatten_segments_for_batch,
|
| 27 |
+
)
|
| 28 |
+
from .vis_utils import save_mask_one_image
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class VineModel(PreTrainedModel):
|
| 32 |
+
"""
|
| 33 |
+
VINE (Video Understanding with Natural Language) Model.
|
| 34 |
+
|
| 35 |
+
Internally, the core CLIP/text/image/pair logic mirrors
|
| 36 |
+
llava_clip_model_v3.PredicateModel as closely as possible for a single video,
|
| 37 |
+
with a small extension to re-normalize categorical probs after pooling.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
config_class = VineConfig
|
| 41 |
+
|
| 42 |
+
def __init__(self, config: VineConfig):
|
| 43 |
+
super().__init__(config)
|
| 44 |
+
self.config = config
|
| 45 |
+
self.visualize = getattr(config, "visualize", False)
|
| 46 |
+
self.visualization_dir = getattr(config, "visualization_dir", None)
|
| 47 |
+
self.debug_visualizations = getattr(config, "debug_visualizations", False)
|
| 48 |
+
self._device = getattr(config, "_device")
|
| 49 |
+
|
| 50 |
+
# CLIP components
|
| 51 |
+
self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
| 52 |
+
if self.clip_tokenizer.pad_token is None:
|
| 53 |
+
self.clip_tokenizer.pad_token = (
|
| 54 |
+
self.clip_tokenizer.unk_token
|
| 55 |
+
if self.clip_tokenizer.unk_token
|
| 56 |
+
else self.clip_tokenizer.eos_token
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
self.clip_processor = AutoProcessor.from_pretrained(config.model_name)
|
| 60 |
+
self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
|
| 61 |
+
self.clip_unary_model = AutoModel.from_pretrained(config.model_name)
|
| 62 |
+
self.clip_binary_model = AutoModel.from_pretrained(config.model_name)
|
| 63 |
+
|
| 64 |
+
# Load fine-tuned weights if available
|
| 65 |
+
if config.use_hf_repo:
|
| 66 |
+
self._load_huggingface_vine_weights(config.model_repo, config.model_file)
|
| 67 |
+
else:
|
| 68 |
+
self._load_local_pretrained_vine_weights(
|
| 69 |
+
config.local_dir, config.local_filename
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# Optionally reset categorical model to base CLIP (ignore fine-tune)
|
| 73 |
+
if not getattr(config, "use_pretrained_cate_weights", True):
|
| 74 |
+
self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
|
| 75 |
+
self.clip_cate_model.to(self._device)
|
| 76 |
+
|
| 77 |
+
self.to(self._device)
|
| 78 |
+
|
| 79 |
+
# ------------------------------------------------------------------ #
|
| 80 |
+
# Weight loading
|
| 81 |
+
# ------------------------------------------------------------------ #
|
| 82 |
+
def _load_huggingface_vine_weights(
|
| 83 |
+
self, model_repo: str, model_file: Optional[str] = None
|
| 84 |
+
):
|
| 85 |
+
try:
|
| 86 |
+
print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
|
| 87 |
+
repo_path = snapshot_download(model_repo, revision=model_file or "main")
|
| 88 |
+
weights = load_file(os.path.join(repo_path, "model.safetensors"))
|
| 89 |
+
self.load_state_dict(weights, strict=False)
|
| 90 |
+
print("✓ Successfully loaded VINE weights from HuggingFace Hub")
|
| 91 |
+
return True
|
| 92 |
+
except Exception as e:
|
| 93 |
+
print(f"✗ Error loading VINE weights from HuggingFace Hub: {e}")
|
| 94 |
+
print("Using base CLIP models instead")
|
| 95 |
+
return False
|
| 96 |
+
|
| 97 |
+
def _load_local_pretrained_vine_weights(
|
| 98 |
+
self, local_dir: str, local_filename: Optional[str] = None, epoch: int = 0
|
| 99 |
+
):
|
| 100 |
+
if local_dir is None and local_filename is None:
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
full_path = (
|
| 104 |
+
os.path.join(local_dir, local_filename) if local_filename else local_dir
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# .pkl – usually pickled PredicateModel
|
| 108 |
+
if isinstance(full_path, str) and full_path.endswith(".pkl"):
|
| 109 |
+
print(f"Loading VINE weights from: {full_path}")
|
| 110 |
+
loaded_vine_model = torch.load(
|
| 111 |
+
full_path, map_location=self._device, weights_only=False
|
| 112 |
+
)
|
| 113 |
+
print(f"Loaded state type: {type(loaded_vine_model)}")
|
| 114 |
+
|
| 115 |
+
if not isinstance(loaded_vine_model, dict):
|
| 116 |
+
if hasattr(loaded_vine_model, "clip_tokenizer"):
|
| 117 |
+
self.clip_tokenizer = loaded_vine_model.clip_tokenizer
|
| 118 |
+
if hasattr(loaded_vine_model, "clip_processor"):
|
| 119 |
+
self.clip_processor = loaded_vine_model.clip_processor
|
| 120 |
+
|
| 121 |
+
if hasattr(loaded_vine_model, "clip_cate_model"):
|
| 122 |
+
self.clip_cate_model.load_state_dict(
|
| 123 |
+
loaded_vine_model.clip_cate_model.state_dict()
|
| 124 |
+
)
|
| 125 |
+
if hasattr(loaded_vine_model, "clip_unary_model"):
|
| 126 |
+
self.clip_unary_model.load_state_dict(
|
| 127 |
+
loaded_vine_model.clip_unary_model.state_dict()
|
| 128 |
+
)
|
| 129 |
+
if hasattr(loaded_vine_model, "clip_binary_model"):
|
| 130 |
+
self.clip_binary_model.load_state_dict(
|
| 131 |
+
loaded_vine_model.clip_binary_model.state_dict()
|
| 132 |
+
)
|
| 133 |
+
print("✓ Loaded VINE weights from .pkl PredicateModel checkpoint")
|
| 134 |
+
return True
|
| 135 |
+
|
| 136 |
+
# .pt / .pth – plain state_dict
|
| 137 |
+
elif isinstance(full_path, str) and (
|
| 138 |
+
full_path.endswith(".pt") or full_path.endswith(".pth")
|
| 139 |
+
):
|
| 140 |
+
print(f"Loading VINE weights from: {full_path}")
|
| 141 |
+
state = torch.load(full_path, map_location=self._device, weights_only=True)
|
| 142 |
+
print(f"Loaded state type: {type(state)}")
|
| 143 |
+
self.load_state_dict(state, strict=False)
|
| 144 |
+
print("✓ Loaded VINE weights from state_dict")
|
| 145 |
+
return True
|
| 146 |
+
|
| 147 |
+
# .model – full PredicateModel object
|
| 148 |
+
elif isinstance(full_path, str) and full_path.endswith(".model"):
|
| 149 |
+
print(f"Loading VINE weights from: {full_path}")
|
| 150 |
+
pretrained_model = torch.load(
|
| 151 |
+
full_path, map_location="cpu", weights_only=False
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
if hasattr(pretrained_model, "clip_tokenizer"):
|
| 155 |
+
self.clip_tokenizer = pretrained_model.clip_tokenizer
|
| 156 |
+
if hasattr(pretrained_model, "clip_processor"):
|
| 157 |
+
self.clip_processor = pretrained_model.clip_processor
|
| 158 |
+
|
| 159 |
+
if hasattr(pretrained_model, "clip_cate_model"):
|
| 160 |
+
self.clip_cate_model.load_state_dict(
|
| 161 |
+
pretrained_model.clip_cate_model.state_dict()
|
| 162 |
+
)
|
| 163 |
+
if hasattr(pretrained_model, "clip_unary_model"):
|
| 164 |
+
self.clip_unary_model.load_state_dict(
|
| 165 |
+
pretrained_model.clip_unary_model.state_dict()
|
| 166 |
+
)
|
| 167 |
+
if hasattr(pretrained_model, "clip_binary_model"):
|
| 168 |
+
self.clip_binary_model.load_state_dict(
|
| 169 |
+
pretrained_model.clip_binary_model.state_dict()
|
| 170 |
+
)
|
| 171 |
+
print("✓ Loaded all sub-model weights from .model file")
|
| 172 |
+
return True
|
| 173 |
+
|
| 174 |
+
# directory of .model files
|
| 175 |
+
if isinstance(full_path, str) and os.path.isdir(full_path):
|
| 176 |
+
model_files = [
|
| 177 |
+
f for f in os.listdir(full_path) if f.endswith(f".{epoch}.model")
|
| 178 |
+
]
|
| 179 |
+
if model_files:
|
| 180 |
+
model_file = os.path.join(full_path, model_files[0])
|
| 181 |
+
print(f"Loading VINE weights from: {model_file}")
|
| 182 |
+
pretrained_model = torch.load(model_file, map_location="cpu")
|
| 183 |
+
|
| 184 |
+
if hasattr(pretrained_model, "clip_tokenizer"):
|
| 185 |
+
self.clip_tokenizer = pretrained_model.clip_tokenizer
|
| 186 |
+
if hasattr(pretrained_model, "clip_processor"):
|
| 187 |
+
self.clip_processor = pretrained_model.clip_processor
|
| 188 |
+
|
| 189 |
+
if hasattr(pretrained_model, "clip_cate_model"):
|
| 190 |
+
self.clip_cate_model.load_state_dict(
|
| 191 |
+
pretrained_model.clip_cate_model.state_dict()
|
| 192 |
+
)
|
| 193 |
+
if hasattr(pretrained_model, "clip_unary_model"):
|
| 194 |
+
self.clip_unary_model.load_state_dict(
|
| 195 |
+
pretrained_model.clip_unary_model.state_dict()
|
| 196 |
+
)
|
| 197 |
+
if hasattr(pretrained_model, "clip_binary_model"):
|
| 198 |
+
self.clip_binary_model.load_state_dict(
|
| 199 |
+
pretrained_model.clip_binary_model.state_dict()
|
| 200 |
+
)
|
| 201 |
+
print("✓ Loaded all sub-model weights from ensemble format")
|
| 202 |
+
return True
|
| 203 |
+
else:
|
| 204 |
+
print(f"No model file found for epoch {epoch} in {full_path}")
|
| 205 |
+
return False
|
| 206 |
+
|
| 207 |
+
print("Unsupported format for pretrained VINE path:", full_path)
|
| 208 |
+
return False
|
| 209 |
+
|
| 210 |
+
@classmethod
|
| 211 |
+
def from_pretrained_vine(
|
| 212 |
+
cls,
|
| 213 |
+
model_path: str,
|
| 214 |
+
config: Optional[VineConfig] = None,
|
| 215 |
+
epoch: int = 0,
|
| 216 |
+
**kwargs: Any,
|
| 217 |
+
):
|
| 218 |
+
if config is None:
|
| 219 |
+
if model_path and ("/" in model_path and not os.path.exists(model_path)):
|
| 220 |
+
config = VineConfig(use_hf_repo=True, model_repo=model_path)
|
| 221 |
+
else:
|
| 222 |
+
if os.path.isdir(model_path):
|
| 223 |
+
config = VineConfig(use_hf_repo=False, local_dir=model_path)
|
| 224 |
+
else:
|
| 225 |
+
config = VineConfig(
|
| 226 |
+
use_hf_repo=False,
|
| 227 |
+
local_dir=os.path.dirname(model_path) or None,
|
| 228 |
+
local_filename=os.path.basename(model_path) or None,
|
| 229 |
+
)
|
| 230 |
+
else:
|
| 231 |
+
if model_path and ("/" in model_path and not os.path.exists(model_path)):
|
| 232 |
+
config.use_hf_repo = True
|
| 233 |
+
config.model_repo = model_path
|
| 234 |
+
config.model_file = None
|
| 235 |
+
config.local_dir = None
|
| 236 |
+
config.local_filename = None
|
| 237 |
+
else:
|
| 238 |
+
config.use_hf_repo = False
|
| 239 |
+
if os.path.isdir(model_path):
|
| 240 |
+
config.local_dir = model_path
|
| 241 |
+
config.local_filename = None
|
| 242 |
+
else:
|
| 243 |
+
config.local_dir = os.path.dirname(model_path) or None
|
| 244 |
+
config.local_filename = os.path.basename(model_path) or None
|
| 245 |
+
|
| 246 |
+
model = cls(config, **kwargs)
|
| 247 |
+
return model
|
| 248 |
+
|
| 249 |
+
# ------------------------------------------------------------------ #
|
| 250 |
+
# Gradient checkpoint helpers
|
| 251 |
+
# ------------------------------------------------------------------ #
|
| 252 |
+
def _text_features_checkpoint(self, model, token_dict):
|
| 253 |
+
input_ids = token_dict["input_ids"]
|
| 254 |
+
attention_mask = token_dict["attention_mask"]
|
| 255 |
+
token_type_ids = token_dict.get("token_type_ids", None)
|
| 256 |
+
|
| 257 |
+
if token_type_ids is not None:
|
| 258 |
+
|
| 259 |
+
def forward_pass(input_ids, attention_mask, token_type_ids):
|
| 260 |
+
return model.get_text_features(
|
| 261 |
+
input_ids=input_ids,
|
| 262 |
+
attention_mask=attention_mask,
|
| 263 |
+
token_type_ids=token_type_ids,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
return cp.checkpoint(
|
| 267 |
+
forward_pass,
|
| 268 |
+
input_ids,
|
| 269 |
+
attention_mask,
|
| 270 |
+
token_type_ids,
|
| 271 |
+
use_reentrant=False,
|
| 272 |
+
)
|
| 273 |
+
else:
|
| 274 |
+
|
| 275 |
+
def forward_pass(input_ids, attention_mask):
|
| 276 |
+
return model.get_text_features(
|
| 277 |
+
input_ids=input_ids,
|
| 278 |
+
attention_mask=attention_mask,
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
return cp.checkpoint(
|
| 282 |
+
forward_pass, input_ids, attention_mask, use_reentrant=False
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
def _image_features_checkpoint(self, model, pixel_values):
|
| 286 |
+
def forward_pass(pixel_values):
|
| 287 |
+
return model.get_image_features(pixel_values=pixel_values)
|
| 288 |
+
|
| 289 |
+
return cp.checkpoint(forward_pass, pixel_values, use_reentrant=False)
|
| 290 |
+
|
| 291 |
+
# ------------------------------------------------------------------ #
|
| 292 |
+
# CLIP similarity
|
| 293 |
+
# ------------------------------------------------------------------ #
|
| 294 |
+
def clip_sim(self, model, nl_feat, img_feat):
|
| 295 |
+
img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
|
| 296 |
+
nl_feat = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
|
| 297 |
+
|
| 298 |
+
logit_scale = getattr(model, "logit_scale", None)
|
| 299 |
+
logits_per_text = torch.matmul(nl_feat, img_feat.t())
|
| 300 |
+
if logit_scale is not None:
|
| 301 |
+
logits_per_text = logits_per_text * logit_scale.exp()
|
| 302 |
+
return logits_per_text
|
| 303 |
+
|
| 304 |
+
# ------------------------------------------------------------------ #
|
| 305 |
+
# Forward: single-video PredicateModel-style logic
|
| 306 |
+
# ------------------------------------------------------------------ #
|
| 307 |
+
def forward(
|
| 308 |
+
self,
|
| 309 |
+
video_frames: torch.Tensor,
|
| 310 |
+
masks: Dict[int, Dict[int, torch.Tensor]],
|
| 311 |
+
bboxes: Dict[int, Dict[int, List]],
|
| 312 |
+
categorical_keywords: List[str],
|
| 313 |
+
unary_keywords: Optional[List[str]] = None,
|
| 314 |
+
binary_keywords: Optional[List[str]] = None,
|
| 315 |
+
object_pairs: Optional[List[Tuple[int, int]]] = None,
|
| 316 |
+
return_flattened_segments: Optional[bool] = None,
|
| 317 |
+
return_valid_pairs: Optional[bool] = None,
|
| 318 |
+
interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
|
| 319 |
+
debug_visualizations: Optional[bool] = None,
|
| 320 |
+
**kwargs: Any,
|
| 321 |
+
) -> Dict[str, Any]:
|
| 322 |
+
if unary_keywords is None:
|
| 323 |
+
unary_keywords = []
|
| 324 |
+
if binary_keywords is None:
|
| 325 |
+
binary_keywords = []
|
| 326 |
+
if object_pairs is None:
|
| 327 |
+
object_pairs = []
|
| 328 |
+
|
| 329 |
+
if return_flattened_segments is None:
|
| 330 |
+
return_flattened_segments = getattr(
|
| 331 |
+
self.config, "return_flattened_segments", False
|
| 332 |
+
)
|
| 333 |
+
if return_valid_pairs is None:
|
| 334 |
+
return_valid_pairs = getattr(self.config, "return_valid_pairs", False)
|
| 335 |
+
if interested_object_pairs is None or len(interested_object_pairs) == 0:
|
| 336 |
+
interested_object_pairs = (
|
| 337 |
+
getattr(self.config, "interested_object_pairs", []) or []
|
| 338 |
+
)
|
| 339 |
+
if debug_visualizations is None:
|
| 340 |
+
debug_visualizations = self.debug_visualizations
|
| 341 |
+
|
| 342 |
+
alpha = getattr(self.config, "alpha", 0.5)
|
| 343 |
+
white_alpha = getattr(self.config, "white_alpha", 0.8)
|
| 344 |
+
topk_cate = kwargs.pop("topk_cate", getattr(self.config, "topk_cate", 3))
|
| 345 |
+
dummy_str = kwargs.pop("dummy_str", getattr(self.config, "dummy_str", "$$$"))
|
| 346 |
+
multi_class = kwargs.pop("multi_class", getattr(self.config, "multi_class", False))
|
| 347 |
+
output_logit = kwargs.pop("output_logit", getattr(self.config, "output_logit", False))
|
| 348 |
+
output_embeddings = kwargs.pop("output_embeddings", False)
|
| 349 |
+
|
| 350 |
+
batched_video_ids = [0]
|
| 351 |
+
|
| 352 |
+
if torch.is_tensor(video_frames):
|
| 353 |
+
num_frames = video_frames.shape[0]
|
| 354 |
+
batched_videos = [
|
| 355 |
+
self._frame_to_numpy(video_frames[fid]) for fid in range(num_frames)
|
| 356 |
+
]
|
| 357 |
+
else:
|
| 358 |
+
num_frames = len(video_frames)
|
| 359 |
+
batched_videos = [
|
| 360 |
+
self._frame_to_numpy(video_frames[fid]) for fid in range(num_frames)
|
| 361 |
+
]
|
| 362 |
+
|
| 363 |
+
batched_masks: List[np.ndarray] = []
|
| 364 |
+
batched_bboxes: List[List[float]] = []
|
| 365 |
+
batched_object_ids: List[Tuple[int, int, int]] = []
|
| 366 |
+
|
| 367 |
+
for frame_id, frame_masks in masks.items():
|
| 368 |
+
if frame_id >= num_frames:
|
| 369 |
+
continue
|
| 370 |
+
frame_boxes = bboxes.get(frame_id, {})
|
| 371 |
+
for obj_id, mask in frame_masks.items():
|
| 372 |
+
if obj_id not in frame_boxes:
|
| 373 |
+
continue
|
| 374 |
+
bbox = frame_boxes[obj_id]
|
| 375 |
+
batched_object_ids.append((0, frame_id, obj_id))
|
| 376 |
+
batched_masks.append(self._mask_to_numpy(mask))
|
| 377 |
+
batched_bboxes.append(bbox)
|
| 378 |
+
|
| 379 |
+
batched_names = [list(categorical_keywords)]
|
| 380 |
+
batched_unary_kws = [list(unary_keywords)]
|
| 381 |
+
batched_binary_kws = [list(binary_keywords)]
|
| 382 |
+
|
| 383 |
+
batched_obj_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
|
| 384 |
+
if object_pairs:
|
| 385 |
+
for frame_id, frame_masks in masks.items():
|
| 386 |
+
if frame_id >= num_frames:
|
| 387 |
+
continue
|
| 388 |
+
present_ids = set(frame_masks.keys())
|
| 389 |
+
for (from_oid, to_oid) in object_pairs:
|
| 390 |
+
if from_oid in present_ids and to_oid in present_ids:
|
| 391 |
+
batched_obj_pairs.append((0, frame_id, (from_oid, to_oid)))
|
| 392 |
+
|
| 393 |
+
batched_video_splits = [0]
|
| 394 |
+
batched_binary_predicates = [None]
|
| 395 |
+
|
| 396 |
+
def fill_empty(batched_kw):
|
| 397 |
+
new_batched = []
|
| 398 |
+
for kw_ls in batched_kw:
|
| 399 |
+
if len(kw_ls) == 0:
|
| 400 |
+
new_batched.append([dummy_str])
|
| 401 |
+
else:
|
| 402 |
+
new_batched.append(list(kw_ls))
|
| 403 |
+
return new_batched
|
| 404 |
+
|
| 405 |
+
batched_names = fill_empty(batched_names)
|
| 406 |
+
batched_unary_kws = fill_empty(batched_unary_kws)
|
| 407 |
+
batched_binary_kws = fill_empty(batched_binary_kws)
|
| 408 |
+
|
| 409 |
+
dummy_prob = torch.tensor(0.0, device=self._device)
|
| 410 |
+
|
| 411 |
+
batched_obj_name_features = []
|
| 412 |
+
batched_unary_nl_features = []
|
| 413 |
+
batched_binary_nl_features = []
|
| 414 |
+
|
| 415 |
+
batched_object_ids_lookup: Dict[int, List[Tuple[int, int]]] = {0: []}
|
| 416 |
+
batch_size = len(batched_video_ids)
|
| 417 |
+
|
| 418 |
+
# Step 1: text features
|
| 419 |
+
for object_names, unary_kws, binary_kws in zip(
|
| 420 |
+
batched_names, batched_unary_kws, batched_binary_kws
|
| 421 |
+
):
|
| 422 |
+
if len(object_names) == 0:
|
| 423 |
+
batched_obj_name_features.append([])
|
| 424 |
+
else:
|
| 425 |
+
obj_tokens = self.clip_tokenizer(
|
| 426 |
+
object_names,
|
| 427 |
+
return_tensors="pt",
|
| 428 |
+
max_length=75,
|
| 429 |
+
truncation=True,
|
| 430 |
+
padding="max_length",
|
| 431 |
+
).to(self._device)
|
| 432 |
+
obj_feats = self._text_features_checkpoint(
|
| 433 |
+
self.clip_cate_model, obj_tokens
|
| 434 |
+
)
|
| 435 |
+
batched_obj_name_features.append(obj_feats)
|
| 436 |
+
|
| 437 |
+
if len(unary_kws) == 0:
|
| 438 |
+
batched_unary_nl_features.append([])
|
| 439 |
+
else:
|
| 440 |
+
unary_tokens = self.clip_tokenizer(
|
| 441 |
+
list(unary_kws),
|
| 442 |
+
return_tensors="pt",
|
| 443 |
+
max_length=75,
|
| 444 |
+
truncation=True,
|
| 445 |
+
padding="max_length",
|
| 446 |
+
).to(self._device)
|
| 447 |
+
unary_feats = self._text_features_checkpoint(
|
| 448 |
+
self.clip_unary_model, unary_tokens
|
| 449 |
+
)
|
| 450 |
+
batched_unary_nl_features.append(unary_feats)
|
| 451 |
+
|
| 452 |
+
if len(binary_kws) == 0:
|
| 453 |
+
batched_binary_nl_features.append([])
|
| 454 |
+
else:
|
| 455 |
+
binary_tokens = self.clip_tokenizer(
|
| 456 |
+
list(binary_kws),
|
| 457 |
+
return_tensors="pt",
|
| 458 |
+
max_length=75,
|
| 459 |
+
truncation=True,
|
| 460 |
+
padding="max_length",
|
| 461 |
+
).to(self._device)
|
| 462 |
+
binary_feats = self._text_features_checkpoint(
|
| 463 |
+
self.clip_binary_model, binary_tokens
|
| 464 |
+
)
|
| 465 |
+
batched_binary_nl_features.append(binary_feats)
|
| 466 |
+
|
| 467 |
+
# Step 2: crop objects
|
| 468 |
+
batched_frame_masks: Dict[Tuple[int, int, int], np.ndarray] = {}
|
| 469 |
+
batched_frame_bboxes: Dict[Tuple[int, int, int], List[float]] = {}
|
| 470 |
+
batched_cropped_objs: Dict[int, List[np.ndarray]] = {
|
| 471 |
+
vid: [] for vid in range(batch_size)
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
assert len(batched_object_ids) > 0, f"No object bbox: {batched_video_ids}"
|
| 475 |
+
|
| 476 |
+
batched_video_splits = [0] + batched_video_splits
|
| 477 |
+
|
| 478 |
+
for (video_id, frame_id, obj_id), mask, bbox in zip(
|
| 479 |
+
batched_object_ids, batched_masks, batched_bboxes
|
| 480 |
+
):
|
| 481 |
+
overall_frame_id = batched_video_splits[video_id] + frame_id
|
| 482 |
+
object_img = extract_single_object(
|
| 483 |
+
batched_videos[overall_frame_id], mask, white_alpha
|
| 484 |
+
)
|
| 485 |
+
cropped_object_img = crop_image_contain_bboxes(
|
| 486 |
+
object_img, [bbox], batched_video_ids
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
if self.visualization_dir:
|
| 490 |
+
debug_crop_dir = os.path.join(self.visualization_dir, "debug_crops")
|
| 491 |
+
os.makedirs(debug_crop_dir, exist_ok=True)
|
| 492 |
+
cv2.imwrite(
|
| 493 |
+
os.path.join(debug_crop_dir, f"frame_{frame_id}_obj_{obj_id}.jpg"),
|
| 494 |
+
cv2.cvtColor(cropped_object_img, cv2.COLOR_RGB2BGR),
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
batched_frame_masks[(video_id, frame_id, obj_id)] = mask
|
| 498 |
+
batched_frame_bboxes[(video_id, frame_id, obj_id)] = bbox
|
| 499 |
+
batched_object_ids_lookup[video_id].append((frame_id, obj_id))
|
| 500 |
+
batched_cropped_objs[video_id].append(cropped_object_img)
|
| 501 |
+
|
| 502 |
+
# Step 3: categorical + unary
|
| 503 |
+
batched_image_unary_probs: Dict[int, Dict] = {}
|
| 504 |
+
batched_image_cate_probs: Dict[int, Dict] = {}
|
| 505 |
+
batched_obj_cate_features: Dict[int, Any] = {}
|
| 506 |
+
batched_obj_unary_features: Dict[int, Any] = {}
|
| 507 |
+
batched_obj_per_cate: Dict[int, Dict[str, List[Tuple[torch.Tensor, int]]]] = {}
|
| 508 |
+
|
| 509 |
+
for vid in range(batch_size):
|
| 510 |
+
batched_image_unary_probs[vid] = {}
|
| 511 |
+
batched_image_cate_probs[vid] = {}
|
| 512 |
+
batched_obj_cate_features[vid] = {}
|
| 513 |
+
batched_obj_unary_features[vid] = {}
|
| 514 |
+
batched_obj_per_cate[vid] = {}
|
| 515 |
+
|
| 516 |
+
for vid_id, (
|
| 517 |
+
unary_nl_feats,
|
| 518 |
+
object_name_feats,
|
| 519 |
+
cate,
|
| 520 |
+
unary_pred,
|
| 521 |
+
binary_predicates,
|
| 522 |
+
) in enumerate(
|
| 523 |
+
zip(
|
| 524 |
+
batched_unary_nl_features,
|
| 525 |
+
batched_obj_name_features,
|
| 526 |
+
batched_names,
|
| 527 |
+
batched_unary_kws,
|
| 528 |
+
batched_binary_predicates,
|
| 529 |
+
)
|
| 530 |
+
):
|
| 531 |
+
cropped_objs = batched_cropped_objs[vid_id]
|
| 532 |
+
|
| 533 |
+
if len(cropped_objs) != 0:
|
| 534 |
+
inputs = self.clip_processor(
|
| 535 |
+
images=cropped_objs, return_tensors="pt"
|
| 536 |
+
).to(self._device)
|
| 537 |
+
cate_obj_clip_features = self._image_features_checkpoint(
|
| 538 |
+
self.clip_cate_model, inputs["pixel_values"]
|
| 539 |
+
)
|
| 540 |
+
unary_obj_clip_features = self._image_features_checkpoint(
|
| 541 |
+
self.clip_unary_model, inputs["pixel_values"]
|
| 542 |
+
)
|
| 543 |
+
batched_obj_unary_features[vid_id] = unary_obj_clip_features
|
| 544 |
+
batched_obj_cate_features[vid_id] = cate_obj_clip_features
|
| 545 |
+
else:
|
| 546 |
+
batched_obj_cate_features[vid_id] = torch.tensor([])
|
| 547 |
+
batched_obj_unary_features[vid_id] = torch.tensor([])
|
| 548 |
+
|
| 549 |
+
object_ids = batched_object_ids_lookup[vid_id]
|
| 550 |
+
|
| 551 |
+
# Categorical logits
|
| 552 |
+
if (
|
| 553 |
+
len(object_name_feats) == 0
|
| 554 |
+
or len(object_ids) == 0
|
| 555 |
+
or len(cropped_objs) == 0
|
| 556 |
+
):
|
| 557 |
+
cate_logits_per_text = torch.tensor([])
|
| 558 |
+
else:
|
| 559 |
+
cate_logits_per_text = self.clip_sim(
|
| 560 |
+
self.clip_cate_model, object_name_feats, cate_obj_clip_features
|
| 561 |
+
)
|
| 562 |
+
if not output_logit:
|
| 563 |
+
cate_logits_per_text = cate_logits_per_text.softmax(dim=0)
|
| 564 |
+
|
| 565 |
+
if not (
|
| 566 |
+
len(object_ids) == 0
|
| 567 |
+
or (
|
| 568 |
+
cate_logits_per_text.ndim == 2
|
| 569 |
+
and cate_logits_per_text.shape[1] == len(object_ids)
|
| 570 |
+
)
|
| 571 |
+
or len(object_name_feats) == 0
|
| 572 |
+
):
|
| 573 |
+
print("Object cate shape mismatch here")
|
| 574 |
+
|
| 575 |
+
assert (
|
| 576 |
+
len(object_name_feats) == 0
|
| 577 |
+
or len(object_ids) == 0
|
| 578 |
+
or (
|
| 579 |
+
cate_logits_per_text.ndim == 2
|
| 580 |
+
and cate_logits_per_text.shape[1] == len(object_ids)
|
| 581 |
+
)
|
| 582 |
+
), f"Mismatched object id and cate logic: {batched_video_ids}"
|
| 583 |
+
|
| 584 |
+
# Aggregate per object id across frames
|
| 585 |
+
cate_prob_per_obj: Dict[int, Dict[str, List[torch.Tensor]]] = {}
|
| 586 |
+
for cate_name, probs in zip(cate, cate_logits_per_text):
|
| 587 |
+
if cate_name == dummy_str:
|
| 588 |
+
dummy_prob += probs.sum()
|
| 589 |
+
else:
|
| 590 |
+
for prob, (fid, oid) in zip(probs, object_ids):
|
| 591 |
+
cate_prob_per_obj.setdefault(oid, {})
|
| 592 |
+
cate_prob_per_obj[oid].setdefault(cate_name, []).append(prob)
|
| 593 |
+
|
| 594 |
+
new_cate_prob_per_obj: Dict[Tuple[int, str], torch.Tensor] = {}
|
| 595 |
+
obj_per_cate: Dict[str, List[Tuple[torch.Tensor, int]]] = {}
|
| 596 |
+
|
| 597 |
+
for oid, object_cate_info in cate_prob_per_obj.items():
|
| 598 |
+
# Pool across frames per category
|
| 599 |
+
pooled: Dict[str, torch.Tensor] = {}
|
| 600 |
+
for cate_name, prob_list in object_cate_info.items():
|
| 601 |
+
stacked = torch.stack(prob_list)
|
| 602 |
+
if getattr(self.config, "categorical_pool", "mean") == "mean":
|
| 603 |
+
pooled_prob = stacked.mean()
|
| 604 |
+
else:
|
| 605 |
+
pooled_prob = stacked.max()
|
| 606 |
+
pooled[cate_name] = pooled_prob
|
| 607 |
+
|
| 608 |
+
if not pooled:
|
| 609 |
+
continue
|
| 610 |
+
|
| 611 |
+
# Renormalize across categories so they sum to 1 per object
|
| 612 |
+
probs_tensor = torch.stack(list(pooled.values()))
|
| 613 |
+
denom = probs_tensor.sum()
|
| 614 |
+
if denom.item() <= 0:
|
| 615 |
+
norm_tensor = torch.ones_like(probs_tensor) / len(pooled)
|
| 616 |
+
else:
|
| 617 |
+
norm_tensor = probs_tensor / denom
|
| 618 |
+
|
| 619 |
+
for (cate_name, _), norm_prob in zip(pooled.items(), norm_tensor):
|
| 620 |
+
obj_per_cate.setdefault(cate_name, []).append((norm_prob, oid))
|
| 621 |
+
new_cate_prob_per_obj[(oid, cate_name)] = norm_prob
|
| 622 |
+
|
| 623 |
+
for cate_name in obj_per_cate:
|
| 624 |
+
obj_per_cate[cate_name] = sorted(
|
| 625 |
+
obj_per_cate[cate_name], key=lambda x: x[0], reverse=True
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
# Unary
|
| 629 |
+
if len(unary_nl_feats) == 0 or len(cropped_objs) == 0:
|
| 630 |
+
unary_logits_per_text = torch.tensor([])
|
| 631 |
+
else:
|
| 632 |
+
unary_logits_per_text = self.clip_sim(
|
| 633 |
+
self.clip_unary_model, unary_nl_feats, unary_obj_clip_features
|
| 634 |
+
)
|
| 635 |
+
if not output_logit:
|
| 636 |
+
unary_logits_per_text = unary_logits_per_text.softmax(dim=0)
|
| 637 |
+
|
| 638 |
+
unary_prob_per_obj: Dict[Tuple[int, int, str], torch.Tensor] = {}
|
| 639 |
+
for unary_name, probs in zip(unary_pred, unary_logits_per_text):
|
| 640 |
+
if unary_name == dummy_str:
|
| 641 |
+
dummy_prob += probs.sum()
|
| 642 |
+
else:
|
| 643 |
+
for prob, (fid, oid) in zip(probs, object_ids):
|
| 644 |
+
unary_prob_per_obj[(fid, oid, unary_name)] = prob
|
| 645 |
+
|
| 646 |
+
batched_image_cate_probs[vid_id] = new_cate_prob_per_obj
|
| 647 |
+
batched_image_unary_probs[vid_id] = unary_prob_per_obj
|
| 648 |
+
batched_obj_per_cate[vid_id] = obj_per_cate
|
| 649 |
+
|
| 650 |
+
# Step 4: binary pairs
|
| 651 |
+
batched_cropped_obj_pairs: Dict[int, List[np.ndarray]] = {}
|
| 652 |
+
frame_splits: Dict[Tuple[int, int], Dict[str, int]] = {}
|
| 653 |
+
current_info = (0, 0)
|
| 654 |
+
frame_splits[current_info] = {"start": 0}
|
| 655 |
+
|
| 656 |
+
batched_topk_cate_candidates: Dict[int, Dict[str, List[int]]] = {
|
| 657 |
+
video_id: {} for video_id in range(batch_size)
|
| 658 |
+
}
|
| 659 |
+
for video_id, obj_per_cate in batched_obj_per_cate.items():
|
| 660 |
+
topk_cate_candidates: Dict[str, List[int]] = {}
|
| 661 |
+
for cate_name, pred_oid_ls in obj_per_cate.items():
|
| 662 |
+
for _, oid in pred_oid_ls[:topk_cate]:
|
| 663 |
+
topk_cate_candidates.setdefault(cate_name, []).append(oid)
|
| 664 |
+
batched_topk_cate_candidates[video_id] = topk_cate_candidates
|
| 665 |
+
|
| 666 |
+
obj_pair_lookup: Dict[int, Dict[Tuple[int, int], List[int]]] = {
|
| 667 |
+
video_id: {} for video_id in range(len(batched_video_ids))
|
| 668 |
+
}
|
| 669 |
+
for (vid, fid, (from_oid, to_oid)) in batched_obj_pairs:
|
| 670 |
+
if (from_oid, to_oid) not in obj_pair_lookup[vid]:
|
| 671 |
+
obj_pair_lookup[vid][(from_oid, to_oid)] = []
|
| 672 |
+
obj_pair_lookup[vid][(from_oid, to_oid)].append(fid)
|
| 673 |
+
|
| 674 |
+
selected_pairs = set()
|
| 675 |
+
if batched_binary_predicates[0] is None:
|
| 676 |
+
selected_pairs = set(batched_obj_pairs)
|
| 677 |
+
else:
|
| 678 |
+
for bp_vid, binary_predicates in enumerate(batched_binary_predicates):
|
| 679 |
+
topk_cate_candidates = batched_topk_cate_candidates[bp_vid]
|
| 680 |
+
for (rel_name, from_obj_name, to_obj_name) in binary_predicates:
|
| 681 |
+
if (
|
| 682 |
+
from_obj_name in topk_cate_candidates
|
| 683 |
+
and to_obj_name in topk_cate_candidates
|
| 684 |
+
):
|
| 685 |
+
from_oids = topk_cate_candidates[from_obj_name]
|
| 686 |
+
to_oids = topk_cate_candidates[to_obj_name]
|
| 687 |
+
for from_oid in from_oids:
|
| 688 |
+
for to_oid in to_oids:
|
| 689 |
+
if (
|
| 690 |
+
bp_vid in obj_pair_lookup
|
| 691 |
+
and (from_oid, to_oid) in obj_pair_lookup[bp_vid]
|
| 692 |
+
):
|
| 693 |
+
for fid in obj_pair_lookup[bp_vid][
|
| 694 |
+
(from_oid, to_oid)
|
| 695 |
+
]:
|
| 696 |
+
selected_pairs.add(
|
| 697 |
+
(bp_vid, fid, (from_oid, to_oid))
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
selected_pairs = list(selected_pairs)
|
| 701 |
+
|
| 702 |
+
new_select_pairs: Dict[int, List[Tuple[int, int, Tuple[int, int]]]] = {
|
| 703 |
+
video_id: [] for video_id in range(len(batched_video_ids))
|
| 704 |
+
}
|
| 705 |
+
for (vid, fid, (from_oid, to_oid)) in selected_pairs:
|
| 706 |
+
new_select_pairs[vid].append((vid, fid, (from_oid, to_oid)))
|
| 707 |
+
|
| 708 |
+
for vid in range(len(batched_video_ids)):
|
| 709 |
+
batched_cropped_obj_pairs[vid] = []
|
| 710 |
+
|
| 711 |
+
for (vid, fid, (from_id, to_id)) in selected_pairs:
|
| 712 |
+
if (vid, fid, from_id) not in batched_frame_masks or (
|
| 713 |
+
vid,
|
| 714 |
+
fid,
|
| 715 |
+
to_id,
|
| 716 |
+
) not in batched_frame_masks:
|
| 717 |
+
continue
|
| 718 |
+
if (vid, fid, from_id) not in batched_frame_bboxes or (
|
| 719 |
+
vid,
|
| 720 |
+
fid,
|
| 721 |
+
to_id,
|
| 722 |
+
) not in batched_frame_bboxes:
|
| 723 |
+
continue
|
| 724 |
+
|
| 725 |
+
overall_frame_id = batched_video_splits[vid] + fid
|
| 726 |
+
mask1 = batched_frame_masks[(vid, fid, from_id)]
|
| 727 |
+
mask2 = batched_frame_masks[(vid, fid, to_id)]
|
| 728 |
+
bbox1 = batched_frame_bboxes[(vid, fid, from_id)]
|
| 729 |
+
bbox2 = batched_frame_bboxes[(vid, fid, to_id)]
|
| 730 |
+
bb_pop_image = extract_object_subject(
|
| 731 |
+
batched_videos[overall_frame_id],
|
| 732 |
+
mask1,
|
| 733 |
+
mask2,
|
| 734 |
+
alpha=alpha,
|
| 735 |
+
white_alpha=white_alpha,
|
| 736 |
+
)
|
| 737 |
+
cropped_bb_pop_image = crop_image_contain_bboxes(
|
| 738 |
+
img=bb_pop_image,
|
| 739 |
+
bbox_ls=[bbox1, bbox2],
|
| 740 |
+
data_id=batched_video_ids,
|
| 741 |
+
)
|
| 742 |
+
batched_cropped_obj_pairs[vid].append(cropped_bb_pop_image)
|
| 743 |
+
|
| 744 |
+
if len(selected_pairs) == 0:
|
| 745 |
+
selected_pairs.append((0, -1, (-1, -1)))
|
| 746 |
+
new_select_pairs[0] = [(0, -1, (-1, -1))]
|
| 747 |
+
dummy_img = batched_videos[0]
|
| 748 |
+
batched_cropped_obj_pairs[0] = [dummy_img]
|
| 749 |
+
|
| 750 |
+
batched_image_binary_probs: List[
|
| 751 |
+
Dict[Tuple[int, Tuple[int, int], str], torch.Tensor]
|
| 752 |
+
] = []
|
| 753 |
+
batched_obj_pair_features: Dict[int, torch.Tensor] = {
|
| 754 |
+
vid: torch.tensor([]) for vid in range(batch_size)
|
| 755 |
+
}
|
| 756 |
+
|
| 757 |
+
if len(batched_cropped_obj_pairs) == 0:
|
| 758 |
+
batched_image_binary_probs.append({})
|
| 759 |
+
else:
|
| 760 |
+
for vid, binary_nl_features in enumerate(batched_binary_nl_features):
|
| 761 |
+
if len(binary_nl_features) == 0:
|
| 762 |
+
batched_image_binary_probs.append({})
|
| 763 |
+
continue
|
| 764 |
+
|
| 765 |
+
binary_kws = batched_binary_kws[vid]
|
| 766 |
+
cropped_obj_pairs = batched_cropped_obj_pairs[vid]
|
| 767 |
+
if len(cropped_obj_pairs) == 0:
|
| 768 |
+
batched_image_binary_probs.append({})
|
| 769 |
+
continue
|
| 770 |
+
|
| 771 |
+
inputs = self.clip_processor(
|
| 772 |
+
images=cropped_obj_pairs, return_tensors="pt"
|
| 773 |
+
).to(self._device)
|
| 774 |
+
obj_features = self._image_features_checkpoint(
|
| 775 |
+
self.clip_binary_model, inputs["pixel_values"]
|
| 776 |
+
)
|
| 777 |
+
batched_obj_pair_features[vid] = obj_features
|
| 778 |
+
|
| 779 |
+
obj_clip_features = obj_features / obj_features.norm(
|
| 780 |
+
p=2, dim=-1, keepdim=True
|
| 781 |
+
)
|
| 782 |
+
binary_nl_features = binary_nl_features / binary_nl_features.norm(
|
| 783 |
+
p=2, dim=-1, keepdim=True
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
logit_scale = self.clip_binary_model.logit_scale
|
| 787 |
+
binary_logits_per_text = torch.matmul(
|
| 788 |
+
binary_nl_features, obj_clip_features.t()
|
| 789 |
+
) * logit_scale.exp()
|
| 790 |
+
|
| 791 |
+
if not output_logit:
|
| 792 |
+
if not multi_class:
|
| 793 |
+
binary_logits_per_text = binary_logits_per_text.softmax(dim=0)
|
| 794 |
+
else:
|
| 795 |
+
binary_logits_per_text = binary_logits_per_text.sigmoid()
|
| 796 |
+
|
| 797 |
+
binary_prob_per_obj: Dict[
|
| 798 |
+
Tuple[int, Tuple[int, int], str], torch.Tensor
|
| 799 |
+
] = {}
|
| 800 |
+
for binary_name, probs in zip(binary_kws, binary_logits_per_text):
|
| 801 |
+
if binary_name == dummy_str:
|
| 802 |
+
dummy_prob += probs.sum()
|
| 803 |
+
else:
|
| 804 |
+
for prob, (vid_, fid, obj_pair) in zip(
|
| 805 |
+
probs, new_select_pairs[vid]
|
| 806 |
+
):
|
| 807 |
+
if fid == -1:
|
| 808 |
+
dummy_prob += prob
|
| 809 |
+
else:
|
| 810 |
+
binary_prob_per_obj[(fid, obj_pair, binary_name)] = prob
|
| 811 |
+
batched_image_binary_probs.append(binary_prob_per_obj)
|
| 812 |
+
|
| 813 |
+
result: Dict[str, Any] = {
|
| 814 |
+
"categorical_probs": batched_image_cate_probs,
|
| 815 |
+
"unary_probs": batched_image_unary_probs,
|
| 816 |
+
"binary_probs": batched_image_binary_probs,
|
| 817 |
+
"dummy_prob": dummy_prob,
|
| 818 |
+
}
|
| 819 |
+
|
| 820 |
+
if output_embeddings:
|
| 821 |
+
embeddings_dict = {
|
| 822 |
+
"cate_obj_clip_features": batched_obj_cate_features,
|
| 823 |
+
"cate_object_ids": batched_object_ids_lookup,
|
| 824 |
+
"unary_obj_clip_features": batched_obj_unary_features,
|
| 825 |
+
"unary_object_ids": batched_object_ids_lookup,
|
| 826 |
+
"binary_obj_pair_features": batched_obj_pair_features,
|
| 827 |
+
"binary_object_pairs": new_select_pairs,
|
| 828 |
+
}
|
| 829 |
+
result["embeddings"] = embeddings_dict
|
| 830 |
+
|
| 831 |
+
if return_flattened_segments or return_valid_pairs:
|
| 832 |
+
flattened = flatten_segments_for_batch(
|
| 833 |
+
video_id=0,
|
| 834 |
+
segments=masks,
|
| 835 |
+
bbox_min_dim=self.config.bbox_min_dim,
|
| 836 |
+
)
|
| 837 |
+
if return_flattened_segments:
|
| 838 |
+
result["flattened_segments"] = flattened
|
| 839 |
+
if return_valid_pairs:
|
| 840 |
+
interested_pairs = (
|
| 841 |
+
interested_object_pairs if interested_object_pairs else None
|
| 842 |
+
)
|
| 843 |
+
result["valid_pairs"] = extract_valid_object_pairs(
|
| 844 |
+
flattened["object_ids"],
|
| 845 |
+
interested_pairs,
|
| 846 |
+
)
|
| 847 |
+
if interested_pairs is None:
|
| 848 |
+
result["valid_pairs_metadata"] = {"pair_source": "all_pairs"}
|
| 849 |
+
else:
|
| 850 |
+
result["valid_pairs_metadata"] = {
|
| 851 |
+
"pair_source": "filtered",
|
| 852 |
+
"requested_pairs": interested_object_pairs,
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
return result
|
| 856 |
+
|
| 857 |
+
# ------------------------------------------------------------------ #
|
| 858 |
+
# Helpers
|
| 859 |
+
# ------------------------------------------------------------------ #
|
| 860 |
+
def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
|
| 861 |
+
if torch.is_tensor(frame):
|
| 862 |
+
frame_np = frame.detach().cpu().numpy()
|
| 863 |
+
else:
|
| 864 |
+
frame_np = np.asarray(frame)
|
| 865 |
+
return np.ascontiguousarray(frame_np)
|
| 866 |
+
|
| 867 |
+
def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
|
| 868 |
+
if torch.is_tensor(mask):
|
| 869 |
+
mask_np = mask.detach().cpu().numpy()
|
| 870 |
+
else:
|
| 871 |
+
mask_np = np.asarray(mask)
|
| 872 |
+
|
| 873 |
+
if mask_np.ndim == 3:
|
| 874 |
+
if mask_np.shape[0] == 1:
|
| 875 |
+
mask_np = mask_np.squeeze(0)
|
| 876 |
+
elif mask_np.shape[2] == 1:
|
| 877 |
+
mask_np = mask_np.squeeze(2)
|
| 878 |
+
|
| 879 |
+
if mask_np.ndim != 2:
|
| 880 |
+
raise ValueError(f"Mask must be 2D after squeezing, got shape {mask_np.shape}")
|
| 881 |
+
|
| 882 |
+
return mask_np.astype(bool, copy=False)
|
| 883 |
+
|
| 884 |
+
def _extract_text_features(self, model, keywords: List[str]):
|
| 885 |
+
tokens = self.clip_tokenizer(
|
| 886 |
+
keywords,
|
| 887 |
+
return_tensors="pt",
|
| 888 |
+
max_length=75,
|
| 889 |
+
truncation=True,
|
| 890 |
+
padding="max_length",
|
| 891 |
+
).to(self._device)
|
| 892 |
+
return self._text_features_checkpoint(model, tokens)
|
| 893 |
+
|
| 894 |
+
def _extract_image_features(self, model, image):
|
| 895 |
+
if torch.is_tensor(image):
|
| 896 |
+
image = image.detach().cpu().numpy()
|
| 897 |
+
elif isinstance(image, np.ndarray):
|
| 898 |
+
pass
|
| 899 |
+
|
| 900 |
+
inputs = self.clip_processor(images=image, return_tensors="pt").to(self._device)
|
| 901 |
+
return self._image_features_checkpoint(model, inputs["pixel_values"])
|
| 902 |
+
|
| 903 |
+
# ------------------------------------------------------------------ #
|
| 904 |
+
# High-level predict API
|
| 905 |
+
# ------------------------------------------------------------------ #
|
| 906 |
+
def predict(
|
| 907 |
+
self,
|
| 908 |
+
video_frames: torch.Tensor,
|
| 909 |
+
masks: Dict[int, Dict[int, torch.Tensor]],
|
| 910 |
+
bboxes: Dict[int, Dict[int, List]],
|
| 911 |
+
categorical_keywords: List[str],
|
| 912 |
+
unary_keywords: Optional[List[str]] = None,
|
| 913 |
+
binary_keywords: Optional[List[str]] = None,
|
| 914 |
+
object_pairs: Optional[List[Tuple[int, int]]] = None,
|
| 915 |
+
return_top_k: int = 3,
|
| 916 |
+
return_flattened_segments: Optional[bool] = None,
|
| 917 |
+
return_valid_pairs: Optional[bool] = None,
|
| 918 |
+
interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
|
| 919 |
+
debug_visualizations: Optional[bool] = None,
|
| 920 |
+
) -> Dict[str, Any]:
|
| 921 |
+
with torch.no_grad():
|
| 922 |
+
outputs = self.forward(
|
| 923 |
+
video_frames=video_frames,
|
| 924 |
+
masks=masks,
|
| 925 |
+
bboxes=bboxes,
|
| 926 |
+
categorical_keywords=categorical_keywords,
|
| 927 |
+
unary_keywords=unary_keywords,
|
| 928 |
+
binary_keywords=binary_keywords,
|
| 929 |
+
object_pairs=object_pairs,
|
| 930 |
+
return_flattened_segments=return_flattened_segments,
|
| 931 |
+
return_valid_pairs=return_valid_pairs,
|
| 932 |
+
interested_object_pairs=interested_object_pairs,
|
| 933 |
+
debug_visualizations=debug_visualizations,
|
| 934 |
+
)
|
| 935 |
+
|
| 936 |
+
formatted_categorical: Dict[int, List[Tuple[float, str]]] = {}
|
| 937 |
+
for (obj_id, category), prob in outputs["categorical_probs"][0].items():
|
| 938 |
+
if obj_id not in formatted_categorical:
|
| 939 |
+
formatted_categorical[obj_id] = []
|
| 940 |
+
prob_val = float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
|
| 941 |
+
formatted_categorical[obj_id].append((prob_val, category))
|
| 942 |
+
|
| 943 |
+
for obj_id in formatted_categorical:
|
| 944 |
+
formatted_categorical[obj_id] = sorted(
|
| 945 |
+
formatted_categorical[obj_id], reverse=True
|
| 946 |
+
)[:return_top_k]
|
| 947 |
+
|
| 948 |
+
formatted_unary: Dict[Tuple[int, int], List[Tuple[float, str]]] = {}
|
| 949 |
+
for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items():
|
| 950 |
+
key = (frame_id, obj_id)
|
| 951 |
+
if key not in formatted_unary:
|
| 952 |
+
formatted_unary[key] = []
|
| 953 |
+
prob_val = float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
|
| 954 |
+
formatted_unary[key].append((prob_val, predicate))
|
| 955 |
+
|
| 956 |
+
for key in formatted_unary:
|
| 957 |
+
formatted_unary[key] = sorted(
|
| 958 |
+
formatted_unary[key], reverse=True
|
| 959 |
+
)[:return_top_k]
|
| 960 |
+
|
| 961 |
+
formatted_binary: Dict[Tuple[int, Tuple[int, int]], List[Tuple[float, str]]] = {}
|
| 962 |
+
if len(outputs["binary_probs"]) > 0:
|
| 963 |
+
for (frame_id, obj_pair, predicate), prob in outputs["binary_probs"][0].items():
|
| 964 |
+
key = (frame_id, obj_pair)
|
| 965 |
+
if key not in formatted_binary:
|
| 966 |
+
formatted_binary[key] = []
|
| 967 |
+
prob_val = float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
|
| 968 |
+
formatted_binary[key].append((prob_val, predicate))
|
| 969 |
+
|
| 970 |
+
for key in formatted_binary:
|
| 971 |
+
formatted_binary[key] = sorted(
|
| 972 |
+
formatted_binary[key], reverse=True
|
| 973 |
+
)[:return_top_k]
|
| 974 |
+
|
| 975 |
+
def max_conf(d: Dict[Any, List[Tuple[float, str]]]) -> float:
|
| 976 |
+
if not d:
|
| 977 |
+
return 0.0
|
| 978 |
+
return max(
|
| 979 |
+
(max((p for p, _ in preds), default=0.0) for preds in d.values()),
|
| 980 |
+
default=0.0,
|
| 981 |
+
)
|
| 982 |
+
|
| 983 |
+
result: Dict[str, Any] = {
|
| 984 |
+
"categorical_predictions": formatted_categorical,
|
| 985 |
+
"unary_predictions": formatted_unary,
|
| 986 |
+
"binary_predictions": formatted_binary,
|
| 987 |
+
"confidence_scores": {
|
| 988 |
+
"categorical": max_conf(formatted_categorical),
|
| 989 |
+
"unary": max_conf(formatted_unary),
|
| 990 |
+
"binary": max_conf(formatted_binary),
|
| 991 |
+
},
|
| 992 |
+
}
|
| 993 |
+
|
| 994 |
+
if "flattened_segments" in outputs:
|
| 995 |
+
result["flattened_segments"] = outputs["flattened_segments"]
|
| 996 |
+
if "valid_pairs" in outputs:
|
| 997 |
+
result["valid_pairs"] = outputs["valid_pairs"]
|
| 998 |
+
if "valid_pairs_metadata" in outputs:
|
| 999 |
+
result["valid_pairs_metadata"] = outputs["valid_pairs_metadata"]
|
| 1000 |
+
|
| 1001 |
+
return result
|
vine_hf/vine_pipeline.py
ADDED
|
@@ -0,0 +1,923 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
import hashlib
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, List, Tuple, Optional, Any, Union
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import Pipeline
|
| 12 |
+
|
| 13 |
+
from .vine_config import VineConfig
|
| 14 |
+
from .vine_model import VineModel
|
| 15 |
+
from .vis_utils import render_dino_frames, render_sam_frames, render_vine_frame_sets
|
| 16 |
+
from laser.loading import load_video
|
| 17 |
+
from laser.preprocess.mask_generation_grounding_dino import generate_masks_grounding_dino
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class VinePipeline(Pipeline):
|
| 21 |
+
"""
|
| 22 |
+
Pipeline for VINE model that handles end-to-end video understanding.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
sam_config_path: Optional[str] = None,
|
| 28 |
+
sam_checkpoint_path: Optional[str] = None,
|
| 29 |
+
gd_config_path: Optional[str] = None,
|
| 30 |
+
gd_checkpoint_path: Optional[str] = None,
|
| 31 |
+
**kwargs: Any,
|
| 32 |
+
):
|
| 33 |
+
self.grounding_model = None
|
| 34 |
+
self.sam_predictor = None
|
| 35 |
+
self.mask_generator = None
|
| 36 |
+
|
| 37 |
+
self.sam_config_path = sam_config_path
|
| 38 |
+
self.sam_checkpoint_path = sam_checkpoint_path
|
| 39 |
+
self.gd_config_path = gd_config_path
|
| 40 |
+
self.gd_checkpoint_path = gd_checkpoint_path
|
| 41 |
+
|
| 42 |
+
super().__init__(**kwargs)
|
| 43 |
+
|
| 44 |
+
self.segmentation_method = getattr(
|
| 45 |
+
self.model.config, "segmentation_method", "grounding_dino_sam2"
|
| 46 |
+
)
|
| 47 |
+
self.box_threshold = getattr(self.model.config, "box_threshold", 0.35)
|
| 48 |
+
self.text_threshold = getattr(self.model.config, "text_threshold", 0.25)
|
| 49 |
+
self.target_fps = getattr(self.model.config, "target_fps", 1)
|
| 50 |
+
self.visualize = getattr(self.model.config, "visualize", False)
|
| 51 |
+
self.visualization_dir = getattr(self.model.config, "visualization_dir", None)
|
| 52 |
+
self.debug_visualizations = getattr(
|
| 53 |
+
self.model.config, "debug_visualizations", False
|
| 54 |
+
)
|
| 55 |
+
self._device = getattr(self.model.config, "_device")
|
| 56 |
+
if kwargs.get("device") is not None:
|
| 57 |
+
self._device = kwargs.get("device")
|
| 58 |
+
|
| 59 |
+
# ------------------------------------------------------------------ #
|
| 60 |
+
# Segmentation model injection
|
| 61 |
+
# ------------------------------------------------------------------ #
|
| 62 |
+
def set_segmentation_models(
|
| 63 |
+
self,
|
| 64 |
+
*,
|
| 65 |
+
sam_predictor=None,
|
| 66 |
+
mask_generator=None,
|
| 67 |
+
grounding_model=None,
|
| 68 |
+
):
|
| 69 |
+
if sam_predictor is not None:
|
| 70 |
+
self.sam_predictor = sam_predictor
|
| 71 |
+
if mask_generator is not None:
|
| 72 |
+
self.mask_generator = mask_generator
|
| 73 |
+
if grounding_model is not None:
|
| 74 |
+
self.grounding_model = grounding_model
|
| 75 |
+
|
| 76 |
+
# ------------------------------------------------------------------ #
|
| 77 |
+
# Pipeline protocol
|
| 78 |
+
# ------------------------------------------------------------------ #
|
| 79 |
+
def _sanitize_parameters(self, **kwargs: Any):
|
| 80 |
+
preprocess_kwargs: Dict[str, Any] = {}
|
| 81 |
+
forward_kwargs: Dict[str, Any] = {}
|
| 82 |
+
postprocess_kwargs: Dict[str, Any] = {}
|
| 83 |
+
|
| 84 |
+
if "segmentation_method" in kwargs:
|
| 85 |
+
preprocess_kwargs["segmentation_method"] = kwargs["segmentation_method"]
|
| 86 |
+
if "target_fps" in kwargs:
|
| 87 |
+
preprocess_kwargs["target_fps"] = kwargs["target_fps"]
|
| 88 |
+
if "box_threshold" in kwargs:
|
| 89 |
+
preprocess_kwargs["box_threshold"] = kwargs["box_threshold"]
|
| 90 |
+
if "text_threshold" in kwargs:
|
| 91 |
+
preprocess_kwargs["text_threshold"] = kwargs["text_threshold"]
|
| 92 |
+
if "categorical_keywords" in kwargs:
|
| 93 |
+
preprocess_kwargs["categorical_keywords"] = kwargs["categorical_keywords"]
|
| 94 |
+
|
| 95 |
+
if "categorical_keywords" in kwargs:
|
| 96 |
+
forward_kwargs["categorical_keywords"] = kwargs["categorical_keywords"]
|
| 97 |
+
if "unary_keywords" in kwargs:
|
| 98 |
+
forward_kwargs["unary_keywords"] = kwargs["unary_keywords"]
|
| 99 |
+
if "binary_keywords" in kwargs:
|
| 100 |
+
forward_kwargs["binary_keywords"] = kwargs["binary_keywords"]
|
| 101 |
+
if "object_pairs" in kwargs:
|
| 102 |
+
forward_kwargs["object_pairs"] = kwargs["object_pairs"]
|
| 103 |
+
if "return_flattened_segments" in kwargs:
|
| 104 |
+
forward_kwargs["return_flattened_segments"] = kwargs[
|
| 105 |
+
"return_flattened_segments"
|
| 106 |
+
]
|
| 107 |
+
if "return_valid_pairs" in kwargs:
|
| 108 |
+
forward_kwargs["return_valid_pairs"] = kwargs["return_valid_pairs"]
|
| 109 |
+
if "interested_object_pairs" in kwargs:
|
| 110 |
+
forward_kwargs["interested_object_pairs"] = kwargs[
|
| 111 |
+
"interested_object_pairs"
|
| 112 |
+
]
|
| 113 |
+
if "debug_visualizations" in kwargs:
|
| 114 |
+
forward_kwargs["debug_visualizations"] = kwargs["debug_visualizations"]
|
| 115 |
+
postprocess_kwargs["debug_visualizations"] = kwargs["debug_visualizations"]
|
| 116 |
+
|
| 117 |
+
if "return_top_k" in kwargs:
|
| 118 |
+
postprocess_kwargs["return_top_k"] = kwargs["return_top_k"]
|
| 119 |
+
if "self.visualize" in kwargs:
|
| 120 |
+
postprocess_kwargs["self.visualize"] = kwargs["self.visualize"]
|
| 121 |
+
|
| 122 |
+
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
|
| 123 |
+
|
| 124 |
+
# ------------------------------------------------------------------ #
|
| 125 |
+
# Preprocess: video + segmentation
|
| 126 |
+
# ------------------------------------------------------------------ #
|
| 127 |
+
def preprocess(
|
| 128 |
+
self,
|
| 129 |
+
video_input: Union[str, np.ndarray, torch.Tensor],
|
| 130 |
+
segmentation_method: Optional[str] = None,
|
| 131 |
+
target_fps: Optional[int] = None,
|
| 132 |
+
box_threshold: Optional[float] = None,
|
| 133 |
+
text_threshold: Optional[float] = None,
|
| 134 |
+
categorical_keywords: Optional[List[str]] = None,
|
| 135 |
+
**kwargs: Any,
|
| 136 |
+
) -> Dict[str, Any]:
|
| 137 |
+
if segmentation_method is None:
|
| 138 |
+
segmentation_method = self.segmentation_method
|
| 139 |
+
if target_fps is None:
|
| 140 |
+
target_fps = self.target_fps
|
| 141 |
+
else:
|
| 142 |
+
self.target_fps = target_fps
|
| 143 |
+
if box_threshold is None:
|
| 144 |
+
box_threshold = self.box_threshold
|
| 145 |
+
else:
|
| 146 |
+
self.box_threshold = box_threshold
|
| 147 |
+
if text_threshold is None:
|
| 148 |
+
text_threshold = self.text_threshold
|
| 149 |
+
else:
|
| 150 |
+
self.text_threshold = text_threshold
|
| 151 |
+
if categorical_keywords is None:
|
| 152 |
+
categorical_keywords = ["object"]
|
| 153 |
+
|
| 154 |
+
if isinstance(video_input, str):
|
| 155 |
+
video_tensor = load_video(video_input, target_fps=target_fps)
|
| 156 |
+
if isinstance(video_tensor, list):
|
| 157 |
+
video_tensor = np.array(video_tensor)
|
| 158 |
+
elif isinstance(video_tensor, torch.Tensor):
|
| 159 |
+
video_tensor = video_tensor.cpu().numpy()
|
| 160 |
+
elif isinstance(video_input, (np.ndarray, torch.Tensor)):
|
| 161 |
+
if isinstance(video_input, torch.Tensor):
|
| 162 |
+
video_tensor = video_input.numpy()
|
| 163 |
+
else:
|
| 164 |
+
video_tensor = video_input
|
| 165 |
+
else:
|
| 166 |
+
raise ValueError(f"Unsupported video input type: {type(video_input)}")
|
| 167 |
+
|
| 168 |
+
if not isinstance(video_tensor, np.ndarray):
|
| 169 |
+
video_tensor = np.array(video_tensor)
|
| 170 |
+
|
| 171 |
+
if len(video_tensor.shape) != 4:
|
| 172 |
+
raise ValueError(
|
| 173 |
+
f"Expected video tensor shape (frames, height, width, channels), got {video_tensor.shape}"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
visualization_data: Dict[str, Any] = {}
|
| 177 |
+
print(f"Segmentation method: {segmentation_method}")
|
| 178 |
+
if segmentation_method == "sam2":
|
| 179 |
+
masks, bboxes, vis_data = self._generate_sam2_masks(video_tensor)
|
| 180 |
+
elif segmentation_method == "grounding_dino_sam2":
|
| 181 |
+
masks, bboxes, vis_data = self._generate_grounding_dino_sam2_masks(
|
| 182 |
+
video_tensor,
|
| 183 |
+
categorical_keywords,
|
| 184 |
+
box_threshold,
|
| 185 |
+
text_threshold,
|
| 186 |
+
video_input,
|
| 187 |
+
)
|
| 188 |
+
else:
|
| 189 |
+
raise ValueError(f"Unsupported segmentation method: {segmentation_method}")
|
| 190 |
+
if vis_data:
|
| 191 |
+
visualization_data.update(vis_data)
|
| 192 |
+
visualization_data.setdefault("sam_masks", masks)
|
| 193 |
+
|
| 194 |
+
return {
|
| 195 |
+
"video_frames": torch.tensor(video_tensor),
|
| 196 |
+
"masks": masks,
|
| 197 |
+
"bboxes": bboxes,
|
| 198 |
+
"num_frames": len(video_tensor),
|
| 199 |
+
"visualization_data": visualization_data,
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
# ------------------------------------------------------------------ #
|
| 203 |
+
# Segmentation helpers
|
| 204 |
+
# ------------------------------------------------------------------ #
|
| 205 |
+
def _generate_sam2_masks(
|
| 206 |
+
self, video_tensor: np.ndarray
|
| 207 |
+
) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]], Dict[str, Any]]:
|
| 208 |
+
print("Generating SAM2 masks...")
|
| 209 |
+
if self.mask_generator is None:
|
| 210 |
+
self._initialize_segmentation_models()
|
| 211 |
+
if self.mask_generator is None:
|
| 212 |
+
raise ValueError("SAM2 mask generator not available")
|
| 213 |
+
|
| 214 |
+
masks: Dict[int, Dict[int, torch.Tensor]] = {}
|
| 215 |
+
bboxes: Dict[int, Dict[int, List[int]]] = {}
|
| 216 |
+
|
| 217 |
+
for frame_id, frame in enumerate(video_tensor):
|
| 218 |
+
if isinstance(frame, np.ndarray) and frame.dtype != np.uint8:
|
| 219 |
+
frame = (
|
| 220 |
+
(frame * 255).astype(np.uint8)
|
| 221 |
+
if frame.max() <= 1
|
| 222 |
+
else frame.astype(np.uint8)
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
frame_masks = self.mask_generator.generate(frame)
|
| 226 |
+
|
| 227 |
+
masks[frame_id] = {}
|
| 228 |
+
bboxes[frame_id] = {}
|
| 229 |
+
|
| 230 |
+
for obj_id, mask_data in enumerate(frame_masks):
|
| 231 |
+
mask = mask_data["segmentation"]
|
| 232 |
+
if isinstance(mask, np.ndarray):
|
| 233 |
+
mask = torch.from_numpy(mask)
|
| 234 |
+
|
| 235 |
+
if len(mask.shape) == 2:
|
| 236 |
+
mask = mask.unsqueeze(-1)
|
| 237 |
+
elif len(mask.shape) == 3 and mask.shape[0] == 1:
|
| 238 |
+
mask = mask.permute(1, 2, 0)
|
| 239 |
+
|
| 240 |
+
wrapped_id = obj_id + 1
|
| 241 |
+
masks[frame_id][wrapped_id] = mask
|
| 242 |
+
|
| 243 |
+
mask_np = (
|
| 244 |
+
mask.squeeze().numpy()
|
| 245 |
+
if isinstance(mask, torch.Tensor)
|
| 246 |
+
else mask.squeeze()
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
coords = np.where(mask_np > 0)
|
| 250 |
+
if len(coords[0]) > 0:
|
| 251 |
+
y1, y2 = coords[0].min(), coords[0].max()
|
| 252 |
+
x1, x2 = coords[1].min(), coords[1].max()
|
| 253 |
+
bboxes[frame_id][wrapped_id] = [x1, y1, x2, y2]
|
| 254 |
+
|
| 255 |
+
tracked_masks, tracked_bboxes = self._track_ids_across_frames(masks, bboxes)
|
| 256 |
+
return tracked_masks, tracked_bboxes, {"sam_masks": tracked_masks}
|
| 257 |
+
|
| 258 |
+
def _generate_grounding_dino_sam2_masks(
|
| 259 |
+
self,
|
| 260 |
+
video_tensor: np.ndarray,
|
| 261 |
+
categorical_keywords: List[str],
|
| 262 |
+
box_threshold: float,
|
| 263 |
+
text_threshold: float,
|
| 264 |
+
video_path: Union[str, None],
|
| 265 |
+
) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]], Dict[str, Any]]:
|
| 266 |
+
print("Generating Grounding DINO + SAM2 masks...")
|
| 267 |
+
if self.grounding_model is None or self.sam_predictor is None:
|
| 268 |
+
self._initialize_segmentation_models()
|
| 269 |
+
if self.grounding_model is None or self.sam_predictor is None:
|
| 270 |
+
raise ValueError("GroundingDINO or SAM2 models not available")
|
| 271 |
+
|
| 272 |
+
temp_video_path = None
|
| 273 |
+
if video_path is None or not isinstance(video_path, str):
|
| 274 |
+
temp_video_path = self._create_temp_video(video_tensor)
|
| 275 |
+
video_path = temp_video_path
|
| 276 |
+
|
| 277 |
+
CHUNK = 5
|
| 278 |
+
classes_ls = [
|
| 279 |
+
categorical_keywords[i : i + CHUNK]
|
| 280 |
+
for i in range(0, len(categorical_keywords), CHUNK)
|
| 281 |
+
]
|
| 282 |
+
|
| 283 |
+
base_name = Path(video_path).stem
|
| 284 |
+
fps_tag = f"fps{int(self.target_fps)}"
|
| 285 |
+
path_hash = hashlib.md5(video_path.encode("utf-8")).hexdigest()[:8]
|
| 286 |
+
video_cache_name = f"{base_name}_{fps_tag}_{path_hash}"
|
| 287 |
+
|
| 288 |
+
video_segments, oid_class_pred, _ = generate_masks_grounding_dino(
|
| 289 |
+
self.grounding_model,
|
| 290 |
+
box_threshold,
|
| 291 |
+
text_threshold,
|
| 292 |
+
self.sam_predictor,
|
| 293 |
+
self.mask_generator,
|
| 294 |
+
video_tensor,
|
| 295 |
+
video_path,
|
| 296 |
+
video_cache_name,
|
| 297 |
+
out_dir=tempfile.gettempdir(),
|
| 298 |
+
classes_ls=classes_ls,
|
| 299 |
+
target_fps=self.target_fps,
|
| 300 |
+
visualize=self.debug_visualizations,
|
| 301 |
+
frames=None,
|
| 302 |
+
max_prop_time=2,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
masks: Dict[int, Dict[int, torch.Tensor]] = {}
|
| 306 |
+
bboxes: Dict[int, Dict[int, List[int]]] = {}
|
| 307 |
+
|
| 308 |
+
for frame_id, frame_masks in video_segments.items():
|
| 309 |
+
masks[frame_id] = {}
|
| 310 |
+
bboxes[frame_id] = {}
|
| 311 |
+
|
| 312 |
+
for obj_id, mask in frame_masks.items():
|
| 313 |
+
if not isinstance(mask, torch.Tensor):
|
| 314 |
+
mask = torch.tensor(mask)
|
| 315 |
+
masks[frame_id][obj_id] = mask
|
| 316 |
+
mask_np = mask.numpy()
|
| 317 |
+
if mask_np.ndim == 3 and mask_np.shape[0] == 1:
|
| 318 |
+
mask_np = np.squeeze(mask_np, axis=0)
|
| 319 |
+
|
| 320 |
+
coords = np.where(mask_np > 0)
|
| 321 |
+
if len(coords[0]) > 0:
|
| 322 |
+
y1, y2 = coords[0].min(), coords[0].max()
|
| 323 |
+
x1, x2 = coords[1].min(), coords[1].max()
|
| 324 |
+
bboxes[frame_id][obj_id] = [x1, y1, x2, y2]
|
| 325 |
+
|
| 326 |
+
if temp_video_path and os.path.exists(temp_video_path):
|
| 327 |
+
os.remove(temp_video_path)
|
| 328 |
+
|
| 329 |
+
tracked_masks, tracked_bboxes = self._track_ids_across_frames(masks, bboxes)
|
| 330 |
+
|
| 331 |
+
vis_data: Dict[str, Any] = {
|
| 332 |
+
"sam_masks": tracked_masks,
|
| 333 |
+
"dino_labels": oid_class_pred,
|
| 334 |
+
}
|
| 335 |
+
return tracked_masks, tracked_bboxes, vis_data
|
| 336 |
+
|
| 337 |
+
# ------------------------------------------------------------------ #
|
| 338 |
+
# ID tracking across frames
|
| 339 |
+
# ------------------------------------------------------------------ #
|
| 340 |
+
def _bbox_iou(self, box1: List[int], box2: List[int]) -> float:
|
| 341 |
+
x1, y1, x2, y2 = box1
|
| 342 |
+
x1b, y1b, x2b, y2b = box2
|
| 343 |
+
ix1 = max(x1, x1b)
|
| 344 |
+
iy1 = max(y1, y1b)
|
| 345 |
+
ix2 = min(x2, x2b)
|
| 346 |
+
iy2 = min(y2, y2b)
|
| 347 |
+
iw = max(0, ix2 - ix1)
|
| 348 |
+
ih = max(0, iy2 - iy1)
|
| 349 |
+
inter = iw * ih
|
| 350 |
+
if inter <= 0:
|
| 351 |
+
return 0.0
|
| 352 |
+
area1 = max(0, x2 - x1) * max(0, y2 - y1)
|
| 353 |
+
area2 = max(0, x2b - x1b) * max(0, y2b - y1b)
|
| 354 |
+
union = area1 + area2 - inter
|
| 355 |
+
if union <= 0:
|
| 356 |
+
return 0.0
|
| 357 |
+
return inter / union
|
| 358 |
+
|
| 359 |
+
def _track_ids_across_frames(
|
| 360 |
+
self,
|
| 361 |
+
masks: Dict[int, Dict[int, torch.Tensor]],
|
| 362 |
+
bboxes: Dict[int, Dict[int, List[int]]],
|
| 363 |
+
iou_threshold: float = 0.3,
|
| 364 |
+
) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]]]:
|
| 365 |
+
frame_ids = sorted(masks.keys())
|
| 366 |
+
tracked_masks: Dict[int, Dict[int, torch.Tensor]] = {}
|
| 367 |
+
tracked_bboxes: Dict[int, Dict[int, List[int]]] = {}
|
| 368 |
+
next_track_id = 0
|
| 369 |
+
prev_tracks: Dict[int, List[int]] = {}
|
| 370 |
+
|
| 371 |
+
for frame_id in frame_ids:
|
| 372 |
+
frame_masks = masks.get(frame_id, {})
|
| 373 |
+
frame_boxes = bboxes.get(frame_id, {})
|
| 374 |
+
tracked_masks[frame_id] = {}
|
| 375 |
+
tracked_bboxes[frame_id] = {}
|
| 376 |
+
|
| 377 |
+
if not frame_boxes:
|
| 378 |
+
prev_tracks = {}
|
| 379 |
+
continue
|
| 380 |
+
|
| 381 |
+
det_ids = list(frame_boxes.keys())
|
| 382 |
+
prev_ids = list(prev_tracks.keys())
|
| 383 |
+
|
| 384 |
+
candidates: List[Tuple[float, int, int]] = []
|
| 385 |
+
for tid in prev_ids:
|
| 386 |
+
prev_box = prev_tracks[tid]
|
| 387 |
+
for det_id in det_ids:
|
| 388 |
+
iou = self._bbox_iou(prev_box, frame_boxes[det_id])
|
| 389 |
+
if iou > iou_threshold:
|
| 390 |
+
candidates.append((iou, tid, det_id))
|
| 391 |
+
candidates.sort(reverse=True)
|
| 392 |
+
|
| 393 |
+
matched_prev = set()
|
| 394 |
+
matched_det = set()
|
| 395 |
+
|
| 396 |
+
for iou, tid, det_id in candidates:
|
| 397 |
+
if tid in matched_prev or det_id in matched_det:
|
| 398 |
+
continue
|
| 399 |
+
matched_prev.add(tid)
|
| 400 |
+
matched_det.add(det_id)
|
| 401 |
+
tracked_masks[frame_id][tid] = frame_masks[det_id]
|
| 402 |
+
tracked_bboxes[frame_id][tid] = frame_boxes[det_id]
|
| 403 |
+
|
| 404 |
+
for det_id in det_ids:
|
| 405 |
+
if det_id in matched_det:
|
| 406 |
+
continue
|
| 407 |
+
tid = next_track_id
|
| 408 |
+
next_track_id += 1
|
| 409 |
+
tracked_masks[frame_id][tid] = frame_masks[det_id]
|
| 410 |
+
tracked_bboxes[frame_id][tid] = frame_boxes[det_id]
|
| 411 |
+
|
| 412 |
+
prev_tracks = {
|
| 413 |
+
tid: tracked_bboxes[frame_id][tid]
|
| 414 |
+
for tid in tracked_bboxes[frame_id].keys()
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
return tracked_masks, tracked_bboxes
|
| 418 |
+
|
| 419 |
+
# ------------------------------------------------------------------ #
|
| 420 |
+
# Segmentation model initialization
|
| 421 |
+
# ------------------------------------------------------------------ #
|
| 422 |
+
def _initialize_segmentation_models(self):
|
| 423 |
+
if self.sam_predictor is None or self.mask_generator is None:
|
| 424 |
+
self._initialize_sam2_models()
|
| 425 |
+
if self.grounding_model is None:
|
| 426 |
+
self._initialize_grounding_dino_model()
|
| 427 |
+
|
| 428 |
+
def _initialize_sam2_models(self):
|
| 429 |
+
try:
|
| 430 |
+
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
| 431 |
+
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 432 |
+
except ImportError as e:
|
| 433 |
+
print(f"Warning: Could not import SAM2: {e}")
|
| 434 |
+
return
|
| 435 |
+
|
| 436 |
+
config_path, checkpoint_path = self._resolve_sam2_paths()
|
| 437 |
+
|
| 438 |
+
if self.sam_config_path is not None and not os.path.exists(config_path):
|
| 439 |
+
raise ValueError(f"SAM2 config path not found: {config_path}")
|
| 440 |
+
if self.sam_checkpoint_path is not None and not os.path.exists(checkpoint_path):
|
| 441 |
+
raise ValueError(f"SAM2 checkpoint path not found: {checkpoint_path}")
|
| 442 |
+
|
| 443 |
+
if not os.path.exists(checkpoint_path):
|
| 444 |
+
print(f"Warning: SAM2 checkpoint not found at {checkpoint_path}")
|
| 445 |
+
print("SAM2 functionality will be unavailable")
|
| 446 |
+
return
|
| 447 |
+
|
| 448 |
+
try:
|
| 449 |
+
device = self._device
|
| 450 |
+
self.sam_predictor = build_sam2_video_predictor(
|
| 451 |
+
config_path, checkpoint_path, device=device
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
sam2_model = build_sam2(
|
| 455 |
+
config_path,
|
| 456 |
+
checkpoint_path,
|
| 457 |
+
device=device,
|
| 458 |
+
apply_postprocessing=False,
|
| 459 |
+
)
|
| 460 |
+
self.mask_generator = SAM2AutomaticMaskGenerator(
|
| 461 |
+
model=sam2_model,
|
| 462 |
+
points_per_side=32,
|
| 463 |
+
points_per_batch=32,
|
| 464 |
+
pred_iou_thresh=0.7,
|
| 465 |
+
stability_score_thresh=0.8,
|
| 466 |
+
crop_n_layers=2,
|
| 467 |
+
box_nms_thresh=0.6,
|
| 468 |
+
crop_n_points_downscale_factor=2,
|
| 469 |
+
min_mask_region_area=100,
|
| 470 |
+
use_m2m=True,
|
| 471 |
+
)
|
| 472 |
+
print("✓ SAM2 models initialized successfully")
|
| 473 |
+
|
| 474 |
+
except Exception as e:
|
| 475 |
+
raise ValueError(f"Failed to initialize SAM2 with custom paths: {e}")
|
| 476 |
+
|
| 477 |
+
def _initialize_grounding_dino_model(self):
|
| 478 |
+
try:
|
| 479 |
+
from groundingdino.util.inference import Model as gd_Model
|
| 480 |
+
except ImportError as e:
|
| 481 |
+
print(f"Warning: Could not import GroundingDINO: {e}")
|
| 482 |
+
return
|
| 483 |
+
|
| 484 |
+
config_path, checkpoint_path = self._resolve_grounding_dino_paths()
|
| 485 |
+
|
| 486 |
+
if self.gd_config_path is not None and not os.path.exists(config_path):
|
| 487 |
+
raise ValueError(f"GroundingDINO config path not found: {config_path}")
|
| 488 |
+
if self.gd_checkpoint_path is not None and not os.path.exists(checkpoint_path):
|
| 489 |
+
raise ValueError(
|
| 490 |
+
f"GroundingDINO checkpoint path not found: {checkpoint_path}"
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
if not (os.path.exists(config_path) and os.path.exists(checkpoint_path)):
|
| 494 |
+
print(
|
| 495 |
+
f"Warning: GroundingDINO models not found at {config_path} / {checkpoint_path}"
|
| 496 |
+
)
|
| 497 |
+
print("GroundingDINO functionality will be unavailable")
|
| 498 |
+
return
|
| 499 |
+
|
| 500 |
+
try:
|
| 501 |
+
device = self._device
|
| 502 |
+
self.grounding_model = gd_Model(
|
| 503 |
+
model_config_path=config_path,
|
| 504 |
+
model_checkpoint_path=checkpoint_path,
|
| 505 |
+
device=device,
|
| 506 |
+
)
|
| 507 |
+
print("✓ GroundingDINO model initialized successfully")
|
| 508 |
+
|
| 509 |
+
except Exception as e:
|
| 510 |
+
raise ValueError(f"Failed to initialize GroundingDINO with custom paths: {e}")
|
| 511 |
+
|
| 512 |
+
def _resolve_sam2_paths(self):
|
| 513 |
+
if self.sam_config_path and self.sam_checkpoint_path:
|
| 514 |
+
return self.sam_config_path, self.sam_checkpoint_path
|
| 515 |
+
|
| 516 |
+
def _resolve_grounding_dino_paths(self):
|
| 517 |
+
if self.gd_config_path and self.gd_checkpoint_path:
|
| 518 |
+
return self.gd_config_path, self.gd_checkpoint_path
|
| 519 |
+
|
| 520 |
+
# ------------------------------------------------------------------ #
|
| 521 |
+
# Video writing helpers
|
| 522 |
+
# ------------------------------------------------------------------ #
|
| 523 |
+
def _prepare_visualization_dir(self, name: str, enabled: bool) -> Optional[str]:
|
| 524 |
+
if not enabled:
|
| 525 |
+
return None
|
| 526 |
+
|
| 527 |
+
if self.visualization_dir:
|
| 528 |
+
target_dir = (
|
| 529 |
+
os.path.join(self.visualization_dir, name)
|
| 530 |
+
if name
|
| 531 |
+
else self.visualization_dir
|
| 532 |
+
)
|
| 533 |
+
os.makedirs(target_dir, exist_ok=True)
|
| 534 |
+
return target_dir
|
| 535 |
+
|
| 536 |
+
return tempfile.mkdtemp(prefix=f"vine_{name}_")
|
| 537 |
+
|
| 538 |
+
def _create_temp_video(
|
| 539 |
+
self,
|
| 540 |
+
video_tensor: np.ndarray,
|
| 541 |
+
base_dir: Optional[str] = None,
|
| 542 |
+
prefix: str = "temp_video",
|
| 543 |
+
) -> str:
|
| 544 |
+
import subprocess
|
| 545 |
+
|
| 546 |
+
if base_dir is None:
|
| 547 |
+
base_dir = tempfile.mkdtemp(prefix=f"vine_{prefix}_")
|
| 548 |
+
else:
|
| 549 |
+
os.makedirs(base_dir, exist_ok=True)
|
| 550 |
+
file_name = f"{prefix}_{uuid.uuid4().hex}.mp4"
|
| 551 |
+
temp_path = os.path.join(base_dir, file_name)
|
| 552 |
+
|
| 553 |
+
height, width = video_tensor.shape[1:3]
|
| 554 |
+
processing_fps = max(1, self.target_fps)
|
| 555 |
+
output_fps = processing_fps
|
| 556 |
+
video_tensor_for_output = video_tensor
|
| 557 |
+
|
| 558 |
+
ffmpeg_success = False
|
| 559 |
+
try:
|
| 560 |
+
ffmpeg_success = self._create_video_with_ffmpeg(
|
| 561 |
+
video_tensor_for_output, temp_path, output_fps, width, height
|
| 562 |
+
)
|
| 563 |
+
except Exception as e:
|
| 564 |
+
print(f"FFmpeg method failed: {e}")
|
| 565 |
+
|
| 566 |
+
if not ffmpeg_success:
|
| 567 |
+
print("Using OpenCV fallback")
|
| 568 |
+
self._create_temp_video_opencv(
|
| 569 |
+
video_tensor_for_output, temp_path, output_fps, width, height
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
return temp_path
|
| 573 |
+
|
| 574 |
+
def _create_video_with_ffmpeg(
|
| 575 |
+
self, video_tensor: np.ndarray, output_path: str, fps: int, width: int, height: int
|
| 576 |
+
) -> bool:
|
| 577 |
+
import subprocess
|
| 578 |
+
|
| 579 |
+
try:
|
| 580 |
+
ffmpeg_cmd = [
|
| 581 |
+
"ffmpeg",
|
| 582 |
+
"-y",
|
| 583 |
+
"-f",
|
| 584 |
+
"rawvideo",
|
| 585 |
+
"-vcodec",
|
| 586 |
+
"rawvideo",
|
| 587 |
+
"-s",
|
| 588 |
+
f"{width}x{height}",
|
| 589 |
+
"-pix_fmt",
|
| 590 |
+
"rgb24",
|
| 591 |
+
"-r",
|
| 592 |
+
str(fps),
|
| 593 |
+
"-i",
|
| 594 |
+
"pipe:0",
|
| 595 |
+
"-c:v",
|
| 596 |
+
"libx264",
|
| 597 |
+
"-preset",
|
| 598 |
+
"fast",
|
| 599 |
+
"-crf",
|
| 600 |
+
"23",
|
| 601 |
+
"-pix_fmt",
|
| 602 |
+
"yuv420p",
|
| 603 |
+
"-movflags",
|
| 604 |
+
"+faststart",
|
| 605 |
+
"-loglevel",
|
| 606 |
+
"error",
|
| 607 |
+
output_path,
|
| 608 |
+
]
|
| 609 |
+
|
| 610 |
+
process = subprocess.Popen(
|
| 611 |
+
ffmpeg_cmd,
|
| 612 |
+
stdin=subprocess.PIPE,
|
| 613 |
+
stdout=subprocess.PIPE,
|
| 614 |
+
stderr=subprocess.PIPE,
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
frame_data = b""
|
| 618 |
+
for frame in video_tensor:
|
| 619 |
+
if frame.dtype != np.uint8:
|
| 620 |
+
frame = (
|
| 621 |
+
(frame * 255).astype(np.uint8)
|
| 622 |
+
if frame.max() <= 1
|
| 623 |
+
else frame.astype(np.uint8)
|
| 624 |
+
)
|
| 625 |
+
frame_data += frame.tobytes()
|
| 626 |
+
|
| 627 |
+
stdout, stderr = process.communicate(input=frame_data, timeout=60)
|
| 628 |
+
|
| 629 |
+
if process.returncode == 0:
|
| 630 |
+
print(f"Video created with FFmpeg (H.264) at {fps} FPS")
|
| 631 |
+
return True
|
| 632 |
+
else:
|
| 633 |
+
error_msg = stderr.decode() if stderr else "Unknown error"
|
| 634 |
+
print(f"FFmpeg error: {error_msg}")
|
| 635 |
+
return False
|
| 636 |
+
|
| 637 |
+
except FileNotFoundError:
|
| 638 |
+
print("FFmpeg not found in PATH")
|
| 639 |
+
return False
|
| 640 |
+
except Exception as e:
|
| 641 |
+
print(f"FFmpeg exception: {e}")
|
| 642 |
+
return False
|
| 643 |
+
|
| 644 |
+
def _create_temp_video_opencv(
|
| 645 |
+
self, video_tensor: np.ndarray, temp_path: str, fps: int, width: int, height: int
|
| 646 |
+
) -> str:
|
| 647 |
+
codecs_to_try = ["avc1", "X264", "mp4v"]
|
| 648 |
+
out = None
|
| 649 |
+
used_codec = None
|
| 650 |
+
|
| 651 |
+
for codec in codecs_to_try:
|
| 652 |
+
try:
|
| 653 |
+
fourcc = cv2.VideoWriter_fourcc(*codec)
|
| 654 |
+
temp_out = cv2.VideoWriter(temp_path, fourcc, fps, (width, height))
|
| 655 |
+
|
| 656 |
+
if temp_out.isOpened():
|
| 657 |
+
out = temp_out
|
| 658 |
+
used_codec = codec
|
| 659 |
+
break
|
| 660 |
+
else:
|
| 661 |
+
temp_out.release()
|
| 662 |
+
except Exception as e:
|
| 663 |
+
print(f"Warning: Codec {codec} not available: {e}")
|
| 664 |
+
continue
|
| 665 |
+
|
| 666 |
+
if out is None or not out.isOpened():
|
| 667 |
+
raise RuntimeError(
|
| 668 |
+
f"Failed to initialize VideoWriter with any codec. Tried: {codecs_to_try}"
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
print(f"Using OpenCV with codec: {used_codec}")
|
| 672 |
+
|
| 673 |
+
for frame in video_tensor:
|
| 674 |
+
if len(frame.shape) == 3 and frame.shape[2] == 3:
|
| 675 |
+
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
| 676 |
+
else:
|
| 677 |
+
frame_bgr = frame
|
| 678 |
+
if frame_bgr.dtype != np.uint8:
|
| 679 |
+
frame_bgr = (
|
| 680 |
+
(frame_bgr * 255).astype(np.uint8)
|
| 681 |
+
if frame_bgr.max() <= 1
|
| 682 |
+
else frame_bgr.astype(np.uint8)
|
| 683 |
+
)
|
| 684 |
+
out.write(frame_bgr)
|
| 685 |
+
|
| 686 |
+
out.release()
|
| 687 |
+
return temp_path
|
| 688 |
+
|
| 689 |
+
# ------------------------------------------------------------------ #
|
| 690 |
+
# Forward + postprocess
|
| 691 |
+
# ------------------------------------------------------------------ #
|
| 692 |
+
def _forward(self, model_inputs: Dict[str, Any], **forward_kwargs: Any) -> Dict[str, Any]:
|
| 693 |
+
outputs = self.model.predict(
|
| 694 |
+
video_frames=model_inputs["video_frames"],
|
| 695 |
+
masks=model_inputs["masks"],
|
| 696 |
+
bboxes=model_inputs["bboxes"],
|
| 697 |
+
**forward_kwargs,
|
| 698 |
+
)
|
| 699 |
+
outputs.setdefault("video_frames", model_inputs.get("video_frames"))
|
| 700 |
+
outputs.setdefault("bboxes", model_inputs.get("bboxes"))
|
| 701 |
+
outputs.setdefault("masks", model_inputs.get("masks"))
|
| 702 |
+
outputs.setdefault("visualization_data", model_inputs.get("visualization_data"))
|
| 703 |
+
return outputs
|
| 704 |
+
|
| 705 |
+
def postprocess(
|
| 706 |
+
self,
|
| 707 |
+
model_outputs: Dict[str, Any],
|
| 708 |
+
return_top_k: int = 3,
|
| 709 |
+
visualize: Optional[bool] = None,
|
| 710 |
+
**kwargs: Any,
|
| 711 |
+
) -> Dict[str, Any]:
|
| 712 |
+
results: Dict[str, Any] = {
|
| 713 |
+
"categorical_predictions": model_outputs.get("categorical_predictions", {}),
|
| 714 |
+
"unary_predictions": model_outputs.get("unary_predictions", {}),
|
| 715 |
+
"binary_predictions": model_outputs.get("binary_predictions", {}),
|
| 716 |
+
"confidence_scores": model_outputs.get("confidence_scores", {}),
|
| 717 |
+
"summary": self._generate_summary(model_outputs),
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
print("\n" + "=" * 50)
|
| 721 |
+
print("DEBUG: Raw Model Outputs - Categorical Predictions")
|
| 722 |
+
cat_preds = model_outputs.get("categorical_predictions", {})
|
| 723 |
+
for obj_id, preds in cat_preds.items():
|
| 724 |
+
print(f"Object {obj_id}: {preds}")
|
| 725 |
+
print("=" * 50 + "\n")
|
| 726 |
+
|
| 727 |
+
if "flattened_segments" in model_outputs:
|
| 728 |
+
results["flattened_segments"] = model_outputs["flattened_segments"]
|
| 729 |
+
if "valid_pairs" in model_outputs:
|
| 730 |
+
results["valid_pairs"] = model_outputs["valid_pairs"]
|
| 731 |
+
if "valid_pairs_metadata" in model_outputs:
|
| 732 |
+
results["valid_pairs_metadata"] = model_outputs["valid_pairs_metadata"]
|
| 733 |
+
if "visualization_data" in model_outputs:
|
| 734 |
+
results["visualization_data"] = model_outputs["visualization_data"]
|
| 735 |
+
|
| 736 |
+
if self.visualize and "video_frames" in model_outputs and "bboxes" in model_outputs:
|
| 737 |
+
frames_tensor = model_outputs["video_frames"]
|
| 738 |
+
if isinstance(frames_tensor, torch.Tensor):
|
| 739 |
+
frames_np = frames_tensor.detach().cpu().numpy()
|
| 740 |
+
else:
|
| 741 |
+
frames_np = np.asarray(frames_tensor)
|
| 742 |
+
if frames_np.dtype != np.uint8:
|
| 743 |
+
if np.issubdtype(frames_np.dtype, np.floating):
|
| 744 |
+
max_val = frames_np.max() if frames_np.size else 0.0
|
| 745 |
+
scale = 255.0 if max_val <= 1.0 else 1.0
|
| 746 |
+
frames_np = (frames_np * scale).clip(0, 255).astype(np.uint8)
|
| 747 |
+
else:
|
| 748 |
+
frames_np = frames_np.clip(0, 255).astype(np.uint8)
|
| 749 |
+
|
| 750 |
+
cat_label_lookup: Dict[int, Tuple[str, float]] = {}
|
| 751 |
+
for obj_id, preds in model_outputs.get("categorical_predictions", {}).items():
|
| 752 |
+
if preds:
|
| 753 |
+
prob, label = preds[0]
|
| 754 |
+
cat_label_lookup[obj_id] = (label, prob)
|
| 755 |
+
|
| 756 |
+
unary_preds = model_outputs.get("unary_predictions", {})
|
| 757 |
+
unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]] = {}
|
| 758 |
+
for (frame_id, obj_id), preds in unary_preds.items():
|
| 759 |
+
if preds:
|
| 760 |
+
unary_lookup.setdefault(frame_id, {})[obj_id] = preds[:1]
|
| 761 |
+
|
| 762 |
+
binary_preds = model_outputs.get("binary_predictions", {})
|
| 763 |
+
binary_lookup: Dict[
|
| 764 |
+
int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]
|
| 765 |
+
] = {}
|
| 766 |
+
for (frame_id, obj_pair), preds in binary_preds.items():
|
| 767 |
+
if preds:
|
| 768 |
+
binary_lookup.setdefault(frame_id, []).append((obj_pair, preds[:1]))
|
| 769 |
+
|
| 770 |
+
bboxes = model_outputs["bboxes"]
|
| 771 |
+
visualization_data = model_outputs.get("visualization_data", {})
|
| 772 |
+
visualizations: Dict[str, Dict[str, Any]] = {}
|
| 773 |
+
debug_visualizations = kwargs.get("debug_visualizations")
|
| 774 |
+
if debug_visualizations is None:
|
| 775 |
+
debug_visualizations = self.debug_visualizations
|
| 776 |
+
|
| 777 |
+
vine_frame_sets = render_vine_frame_sets(
|
| 778 |
+
frames_np,
|
| 779 |
+
bboxes,
|
| 780 |
+
cat_label_lookup,
|
| 781 |
+
unary_lookup,
|
| 782 |
+
binary_lookup,
|
| 783 |
+
visualization_data.get("sam_masks"),
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
vine_visuals: Dict[str, Dict[str, Any]] = {}
|
| 787 |
+
final_frames = vine_frame_sets.get("all", [])
|
| 788 |
+
if final_frames:
|
| 789 |
+
final_entry: Dict[str, Any] = {"frames": final_frames, "video_path": None}
|
| 790 |
+
final_dir = self._prepare_visualization_dir(
|
| 791 |
+
"all", enabled=self.visualize
|
| 792 |
+
)
|
| 793 |
+
final_entry["video_path"] = self._create_temp_video(
|
| 794 |
+
np.stack(final_frames, axis=0),
|
| 795 |
+
base_dir=final_dir,
|
| 796 |
+
prefix="all_visualization",
|
| 797 |
+
)
|
| 798 |
+
vine_visuals["all"] = final_entry
|
| 799 |
+
|
| 800 |
+
if debug_visualizations:
|
| 801 |
+
sam_masks = visualization_data.get("sam_masks")
|
| 802 |
+
if sam_masks:
|
| 803 |
+
sam_frames = render_sam_frames(
|
| 804 |
+
frames_np, sam_masks, visualization_data.get("dino_labels")
|
| 805 |
+
)
|
| 806 |
+
sam_entry = {"frames": sam_frames, "video_path": None}
|
| 807 |
+
if sam_frames:
|
| 808 |
+
sam_dir = self._prepare_visualization_dir(
|
| 809 |
+
"sam", enabled=self.visualize
|
| 810 |
+
)
|
| 811 |
+
sam_entry["video_path"] = self._create_temp_video(
|
| 812 |
+
np.stack(sam_frames, axis=0),
|
| 813 |
+
base_dir=sam_dir,
|
| 814 |
+
prefix="sam_visualization",
|
| 815 |
+
)
|
| 816 |
+
visualizations["sam"] = sam_entry
|
| 817 |
+
|
| 818 |
+
dino_labels = visualization_data.get("dino_labels")
|
| 819 |
+
if dino_labels:
|
| 820 |
+
dino_frames = render_dino_frames(frames_np, bboxes, dino_labels)
|
| 821 |
+
dino_entry = {"frames": dino_frames, "video_path": None}
|
| 822 |
+
if dino_frames:
|
| 823 |
+
dino_dir = self._prepare_visualization_dir(
|
| 824 |
+
"dino", enabled=self.visualize
|
| 825 |
+
)
|
| 826 |
+
dino_entry["video_path"] = self._create_temp_video(
|
| 827 |
+
np.stack(dino_frames, axis=0),
|
| 828 |
+
base_dir=dino_dir,
|
| 829 |
+
prefix="dino_visualization",
|
| 830 |
+
)
|
| 831 |
+
visualizations["dino"] = dino_entry
|
| 832 |
+
|
| 833 |
+
for name in ("object", "unary", "binary"):
|
| 834 |
+
frames_list = vine_frame_sets.get(name, [])
|
| 835 |
+
entry: Dict[str, Any] = {"frames": frames_list, "video_path": None}
|
| 836 |
+
if frames_list:
|
| 837 |
+
vine_dir = self._prepare_visualization_dir(
|
| 838 |
+
name, enabled=self.visualize
|
| 839 |
+
)
|
| 840 |
+
entry["video_path"] = self._create_temp_video(
|
| 841 |
+
np.stack(frames_list, axis=0),
|
| 842 |
+
base_dir=vine_dir,
|
| 843 |
+
prefix=f"{name}_visualization",
|
| 844 |
+
)
|
| 845 |
+
vine_visuals[name] = entry
|
| 846 |
+
|
| 847 |
+
if vine_visuals:
|
| 848 |
+
visualizations["vine"] = vine_visuals
|
| 849 |
+
|
| 850 |
+
if visualizations:
|
| 851 |
+
results["visualizations"] = visualizations
|
| 852 |
+
|
| 853 |
+
return results
|
| 854 |
+
|
| 855 |
+
# ------------------------------------------------------------------ #
|
| 856 |
+
# Summary JSON
|
| 857 |
+
# ------------------------------------------------------------------ #
|
| 858 |
+
def _generate_summary(self, model_outputs: Dict[str, Any]) -> Dict[str, Any]:
|
| 859 |
+
"""
|
| 860 |
+
Per-object summary:
|
| 861 |
+
{
|
| 862 |
+
"num_objects_detected": N,
|
| 863 |
+
"objects": {
|
| 864 |
+
"<obj_id>": {
|
| 865 |
+
"top_categories": [{"label": str, "probability": float}, ...],
|
| 866 |
+
"top_unary": [{"frame_id": int, "predicate": str, "probability": float}, ...],
|
| 867 |
+
}
|
| 868 |
+
}
|
| 869 |
+
}
|
| 870 |
+
"""
|
| 871 |
+
categorical_preds = model_outputs.get("categorical_predictions", {})
|
| 872 |
+
unary_preds = model_outputs.get("unary_predictions", {})
|
| 873 |
+
|
| 874 |
+
unary_by_obj: Dict[int, List[Tuple[float, str, int]]] = {}
|
| 875 |
+
for (frame_id, obj_id), preds in unary_preds.items():
|
| 876 |
+
for prob, predicate in preds:
|
| 877 |
+
prob_val = (
|
| 878 |
+
float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
|
| 879 |
+
)
|
| 880 |
+
unary_by_obj.setdefault(obj_id, []).append((prob_val, predicate, frame_id))
|
| 881 |
+
|
| 882 |
+
objects_summary: Dict[str, Dict[str, Any]] = {}
|
| 883 |
+
all_obj_ids = set(categorical_preds.keys()) | set(unary_by_obj.keys())
|
| 884 |
+
|
| 885 |
+
for obj_id in sorted(all_obj_ids):
|
| 886 |
+
cat_list = categorical_preds.get(obj_id, [])
|
| 887 |
+
cat_sorted = sorted(
|
| 888 |
+
[
|
| 889 |
+
(
|
| 890 |
+
float(p.detach().cpu()) if torch.is_tensor(p) else float(p),
|
| 891 |
+
label,
|
| 892 |
+
)
|
| 893 |
+
for p, label in cat_list
|
| 894 |
+
],
|
| 895 |
+
key=lambda x: x[0],
|
| 896 |
+
reverse=True,
|
| 897 |
+
)[:3]
|
| 898 |
+
|
| 899 |
+
top_categories = [
|
| 900 |
+
{"label": label, "probability": prob} for prob, label in cat_sorted
|
| 901 |
+
]
|
| 902 |
+
|
| 903 |
+
unary_list = unary_by_obj.get(obj_id, [])
|
| 904 |
+
unary_sorted = sorted(unary_list, key=lambda x: x[0], reverse=True)[:3]
|
| 905 |
+
top_unary = [
|
| 906 |
+
{
|
| 907 |
+
"frame_id": int(frame_id),
|
| 908 |
+
"predicate": predicate,
|
| 909 |
+
"probability": prob,
|
| 910 |
+
}
|
| 911 |
+
for (prob, predicate, frame_id) in unary_sorted
|
| 912 |
+
]
|
| 913 |
+
|
| 914 |
+
objects_summary[str(obj_id)] = {
|
| 915 |
+
"top_categories": top_categories,
|
| 916 |
+
"top_unary": top_unary,
|
| 917 |
+
}
|
| 918 |
+
|
| 919 |
+
summary = {
|
| 920 |
+
"num_objects_detected": len(objects_summary),
|
| 921 |
+
"objects": objects_summary,
|
| 922 |
+
}
|
| 923 |
+
return summary
|
vine_hf/vis_utils.py
ADDED
|
@@ -0,0 +1,941 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import torch
|
| 6 |
+
import random
|
| 7 |
+
import math
|
| 8 |
+
from matplotlib.patches import Rectangle
|
| 9 |
+
import itertools
|
| 10 |
+
from typing import Any, Dict, List, Tuple, Optional, Union
|
| 11 |
+
|
| 12 |
+
from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox
|
| 13 |
+
|
| 14 |
+
########################################################################################
|
| 15 |
+
########## Visualization Library ########
|
| 16 |
+
########################################################################################
|
| 17 |
+
# This module renders SAM masks, GroundingDINO boxes, and VINE predictions.
|
| 18 |
+
#
|
| 19 |
+
# Conventions (RGB frames, pixel coords):
|
| 20 |
+
# - Frames: list[np.ndarray] with shape (H, W, 3) in RGB, or np.ndarray with shape (T, H, W, 3).
|
| 21 |
+
# - Masks: 2D boolean arrays (H, W) or tensors convertible to that; (H, W, 1) is also accepted.
|
| 22 |
+
# - BBoxes: (x1, y1, x2, y2) integer pixel coordinates with x2 > x1 and y2 > y1.
|
| 23 |
+
#
|
| 24 |
+
# Per-frame stores use one of:
|
| 25 |
+
# - Dict[int(frame_id) -> Dict[int(obj_id) -> value]]
|
| 26 |
+
# - List indexed by frame_id (each item may be a dict of obj_id->value or a list in order)
|
| 27 |
+
#
|
| 28 |
+
# Renderer inputs/outputs:
|
| 29 |
+
# 1) render_sam_frames(frames, sam_masks, dino_labels=None) -> List[np.ndarray]
|
| 30 |
+
# - sam_masks: Dict[frame_id, Dict[obj_id, Mask]] or a list; Mask can be np.ndarray or torch.Tensor.
|
| 31 |
+
# - dino_labels: Optional Dict[obj_id, str] to annotate boxes derived from masks.
|
| 32 |
+
#
|
| 33 |
+
# 2) render_dino_frames(frames, bboxes, dino_labels=None) -> List[np.ndarray]
|
| 34 |
+
# - bboxes: Dict[frame_id, Dict[obj_id, Sequence[float]]] or a list; each bbox as [x1, y1, x2, y2].
|
| 35 |
+
#
|
| 36 |
+
# 3) render_vine_frames(frames, bboxes, cat_label_lookup, unary_lookup, binary_lookup, masks=None)
|
| 37 |
+
# -> List[np.ndarray] (the "all" view)
|
| 38 |
+
# - cat_label_lookup: Dict[obj_id, (label: str, prob: float)]
|
| 39 |
+
# - unary_lookup: Dict[frame_id, Dict[obj_id, List[(prob: float, label: str)]]]
|
| 40 |
+
# - binary_lookup: Dict[frame_id, List[((sub_id: int, obj_id: int), List[(prob: float, relation: str)])]]
|
| 41 |
+
# - masks: Optional; same structure as sam_masks, used for translucent overlays when unary labels exist.
|
| 42 |
+
#
|
| 43 |
+
# Ground-truth helpers used by plotting utilities:
|
| 44 |
+
# - For a single frame, gt_relations is represented as List[(subject_label, object_label, relation_label)].
|
| 45 |
+
#
|
| 46 |
+
# All rendered frames returned by functions are RGB np.ndarray images suitable for saving or video writing.
|
| 47 |
+
########################################################################################
|
| 48 |
+
|
| 49 |
+
def clean_label(label):
|
| 50 |
+
"""Replace underscores and slashes with spaces for uniformity."""
|
| 51 |
+
return label.replace("_", " ").replace("/", " ")
|
| 52 |
+
|
| 53 |
+
# Should be performed somewhere else I believe
|
| 54 |
+
def format_cate_preds(cate_preds):
|
| 55 |
+
# Group object predictions from the model output.
|
| 56 |
+
obj_pred_dict = {}
|
| 57 |
+
for (oid, label), prob in cate_preds.items():
|
| 58 |
+
# Clean the predicted label as well.
|
| 59 |
+
clean_pred = clean_label(label)
|
| 60 |
+
if oid not in obj_pred_dict:
|
| 61 |
+
obj_pred_dict[oid] = []
|
| 62 |
+
obj_pred_dict[oid].append((clean_pred, prob))
|
| 63 |
+
for oid in obj_pred_dict:
|
| 64 |
+
obj_pred_dict[oid].sort(key=lambda x: x[1], reverse=True)
|
| 65 |
+
return obj_pred_dict
|
| 66 |
+
|
| 67 |
+
def format_binary_cate_preds(binary_preds):
|
| 68 |
+
frame_binary_preds = []
|
| 69 |
+
for key, score in binary_preds.items():
|
| 70 |
+
# Expect key format: (frame_id, (subject, object), predicted_relation)
|
| 71 |
+
try:
|
| 72 |
+
f_id, (subj, obj), pred_rel = key
|
| 73 |
+
frame_binary_preds.append((f_id, subj, obj, pred_rel, score))
|
| 74 |
+
except Exception as e:
|
| 75 |
+
print("Skipping key with unexpected format:", key)
|
| 76 |
+
continue
|
| 77 |
+
frame_binary_preds.sort(key=lambda x: x[3], reverse=True)
|
| 78 |
+
return frame_binary_preds
|
| 79 |
+
|
| 80 |
+
_FONT = cv2.FONT_HERSHEY_SIMPLEX
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.ndarray]:
|
| 84 |
+
if mask is None:
|
| 85 |
+
return None
|
| 86 |
+
if isinstance(mask, torch.Tensor):
|
| 87 |
+
mask_np = mask.detach().cpu().numpy()
|
| 88 |
+
else:
|
| 89 |
+
mask_np = np.asarray(mask)
|
| 90 |
+
if mask_np.ndim == 0:
|
| 91 |
+
return None
|
| 92 |
+
if mask_np.ndim == 3:
|
| 93 |
+
mask_np = np.squeeze(mask_np)
|
| 94 |
+
if mask_np.ndim != 2:
|
| 95 |
+
return None
|
| 96 |
+
if mask_np.dtype == bool:
|
| 97 |
+
return mask_np
|
| 98 |
+
return mask_np > 0
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _sanitize_bbox(bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int) -> Optional[Tuple[int, int, int, int]]:
|
| 102 |
+
if bbox is None:
|
| 103 |
+
return None
|
| 104 |
+
if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
|
| 105 |
+
x1, y1, x2, y2 = [float(b) for b in bbox[:4]]
|
| 106 |
+
elif isinstance(bbox, np.ndarray) and bbox.size >= 4:
|
| 107 |
+
x1, y1, x2, y2 = [float(b) for b in bbox.flat[:4]]
|
| 108 |
+
else:
|
| 109 |
+
return None
|
| 110 |
+
x1 = int(np.clip(round(x1), 0, width - 1))
|
| 111 |
+
y1 = int(np.clip(round(y1), 0, height - 1))
|
| 112 |
+
x2 = int(np.clip(round(x2), 0, width - 1))
|
| 113 |
+
y2 = int(np.clip(round(y2), 0, height - 1))
|
| 114 |
+
if x2 <= x1 or y2 <= y1:
|
| 115 |
+
return None
|
| 116 |
+
return (x1, y1, x2, y2)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _object_color_bgr(obj_id: int) -> Tuple[int, int, int]:
|
| 120 |
+
color = get_color(obj_id)
|
| 121 |
+
rgb = [int(np.clip(c, 0.0, 1.0) * 255) for c in color[:3]]
|
| 122 |
+
return (rgb[2], rgb[1], rgb[0])
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _background_color(color: Tuple[int, int, int]) -> Tuple[int, int, int]:
|
| 126 |
+
return tuple(int(0.25 * 255 + 0.75 * channel) for channel in color)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _draw_label_block(
|
| 130 |
+
image: np.ndarray,
|
| 131 |
+
lines: List[str],
|
| 132 |
+
anchor: Tuple[int, int],
|
| 133 |
+
color: Tuple[int, int, int],
|
| 134 |
+
font_scale: float = 0.5,
|
| 135 |
+
thickness: int = 1,
|
| 136 |
+
direction: str = "up",
|
| 137 |
+
) -> None:
|
| 138 |
+
if not lines:
|
| 139 |
+
return
|
| 140 |
+
img_h, img_w = image.shape[:2]
|
| 141 |
+
x, y = anchor
|
| 142 |
+
x = int(np.clip(x, 0, img_w - 1))
|
| 143 |
+
y_cursor = int(np.clip(y, 0, img_h - 1))
|
| 144 |
+
bg_color = _background_color(color)
|
| 145 |
+
|
| 146 |
+
if direction == "down":
|
| 147 |
+
for text in lines:
|
| 148 |
+
text = str(text)
|
| 149 |
+
(tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
|
| 150 |
+
left_x = x
|
| 151 |
+
right_x = min(left_x + tw + 8, img_w - 1)
|
| 152 |
+
top_y = int(np.clip(y_cursor + 6, 0, img_h - 1))
|
| 153 |
+
bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
|
| 154 |
+
if bottom_y <= top_y:
|
| 155 |
+
break
|
| 156 |
+
cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
|
| 157 |
+
text_x = left_x + 4
|
| 158 |
+
text_y = min(bottom_y - baseline - 2, img_h - 1)
|
| 159 |
+
cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
|
| 160 |
+
y_cursor = bottom_y
|
| 161 |
+
else:
|
| 162 |
+
for text in lines:
|
| 163 |
+
text = str(text)
|
| 164 |
+
(tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
|
| 165 |
+
top_y = max(y_cursor - th - baseline - 6, 0)
|
| 166 |
+
left_x = x
|
| 167 |
+
right_x = min(left_x + tw + 8, img_w - 1)
|
| 168 |
+
bottom_y = min(top_y + th + baseline + 6, img_h - 1)
|
| 169 |
+
cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
|
| 170 |
+
text_x = left_x + 4
|
| 171 |
+
text_y = min(bottom_y - baseline - 2, img_h - 1)
|
| 172 |
+
cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
|
| 173 |
+
y_cursor = top_y
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _draw_centered_label(
|
| 177 |
+
image: np.ndarray,
|
| 178 |
+
text: str,
|
| 179 |
+
center: Tuple[int, int],
|
| 180 |
+
color: Tuple[int, int, int],
|
| 181 |
+
font_scale: float = 0.5,
|
| 182 |
+
thickness: int = 1,
|
| 183 |
+
) -> None:
|
| 184 |
+
text = str(text)
|
| 185 |
+
img_h, img_w = image.shape[:2]
|
| 186 |
+
(tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
|
| 187 |
+
cx = int(np.clip(center[0], 0, img_w - 1))
|
| 188 |
+
cy = int(np.clip(center[1], 0, img_h - 1))
|
| 189 |
+
left_x = int(np.clip(cx - tw // 2 - 4, 0, img_w - 1))
|
| 190 |
+
top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
|
| 191 |
+
right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
|
| 192 |
+
bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
|
| 193 |
+
cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1)
|
| 194 |
+
text_x = left_x + 4
|
| 195 |
+
text_y = min(bottom_y - baseline - 2, img_h - 1)
|
| 196 |
+
cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _extract_frame_entities(store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int) -> Dict[int, Any]:
|
| 200 |
+
if isinstance(store, dict):
|
| 201 |
+
frame_entry = store.get(frame_idx, {})
|
| 202 |
+
elif isinstance(store, list) and 0 <= frame_idx < len(store):
|
| 203 |
+
frame_entry = store[frame_idx]
|
| 204 |
+
else:
|
| 205 |
+
frame_entry = {}
|
| 206 |
+
if isinstance(frame_entry, dict):
|
| 207 |
+
return frame_entry
|
| 208 |
+
if isinstance(frame_entry, list):
|
| 209 |
+
return {i: value for i, value in enumerate(frame_entry)}
|
| 210 |
+
return {}
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _label_anchor_and_direction(
|
| 214 |
+
bbox: Tuple[int, int, int, int],
|
| 215 |
+
position: str,
|
| 216 |
+
) -> Tuple[Tuple[int, int], str]:
|
| 217 |
+
x1, y1, x2, y2 = bbox
|
| 218 |
+
if position == "bottom":
|
| 219 |
+
return (x1, y2), "down"
|
| 220 |
+
return (x1, y1), "up"
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def _draw_bbox_with_label(
|
| 224 |
+
image: np.ndarray,
|
| 225 |
+
bbox: Tuple[int, int, int, int],
|
| 226 |
+
obj_id: int,
|
| 227 |
+
title: Optional[str] = None,
|
| 228 |
+
sub_lines: Optional[List[str]] = None,
|
| 229 |
+
label_position: str = "top",
|
| 230 |
+
) -> None:
|
| 231 |
+
color = _object_color_bgr(obj_id)
|
| 232 |
+
cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
|
| 233 |
+
head = title if title else f"#{obj_id}"
|
| 234 |
+
if not head.startswith("#"):
|
| 235 |
+
head = f"#{obj_id} {head}"
|
| 236 |
+
lines = [head]
|
| 237 |
+
if sub_lines:
|
| 238 |
+
lines.extend(sub_lines)
|
| 239 |
+
anchor, direction = _label_anchor_and_direction(bbox, label_position)
|
| 240 |
+
_draw_label_block(image, lines, anchor, color, direction=direction)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def render_sam_frames(
|
| 244 |
+
frames: Union[np.ndarray, List[np.ndarray]],
|
| 245 |
+
sam_masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None],
|
| 246 |
+
dino_labels: Optional[Dict[int, str]] = None,
|
| 247 |
+
) -> List[np.ndarray]:
|
| 248 |
+
results: List[np.ndarray] = []
|
| 249 |
+
frames_iterable = frames if isinstance(frames, list) else list(frames)
|
| 250 |
+
dino_labels = dino_labels or {}
|
| 251 |
+
|
| 252 |
+
for frame_idx, frame in enumerate(frames_iterable):
|
| 253 |
+
if frame is None:
|
| 254 |
+
continue
|
| 255 |
+
frame_rgb = np.asarray(frame)
|
| 256 |
+
frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
|
| 257 |
+
overlay = frame_bgr.astype(np.float32)
|
| 258 |
+
masks_for_frame = _extract_frame_entities(sam_masks, frame_idx)
|
| 259 |
+
|
| 260 |
+
for obj_id, mask in masks_for_frame.items():
|
| 261 |
+
mask_np = _to_numpy_mask(mask)
|
| 262 |
+
if mask_np is None or not np.any(mask_np):
|
| 263 |
+
continue
|
| 264 |
+
color = _object_color_bgr(obj_id)
|
| 265 |
+
alpha = 0.45
|
| 266 |
+
overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(color, dtype=np.float32)
|
| 267 |
+
|
| 268 |
+
annotated = np.clip(overlay, 0, 255).astype(np.uint8)
|
| 269 |
+
frame_h, frame_w = annotated.shape[:2]
|
| 270 |
+
|
| 271 |
+
for obj_id, mask in masks_for_frame.items():
|
| 272 |
+
mask_np = _to_numpy_mask(mask)
|
| 273 |
+
if mask_np is None or not np.any(mask_np):
|
| 274 |
+
continue
|
| 275 |
+
bbox = mask_to_bbox(mask_np)
|
| 276 |
+
bbox = _sanitize_bbox(bbox, frame_w, frame_h)
|
| 277 |
+
if not bbox:
|
| 278 |
+
continue
|
| 279 |
+
label = dino_labels.get(obj_id)
|
| 280 |
+
title = f"{label}" if label else None
|
| 281 |
+
_draw_bbox_with_label(annotated, bbox, obj_id, title=title)
|
| 282 |
+
|
| 283 |
+
results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
|
| 284 |
+
|
| 285 |
+
return results
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def render_dino_frames(
|
| 289 |
+
frames: Union[np.ndarray, List[np.ndarray]],
|
| 290 |
+
bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
|
| 291 |
+
dino_labels: Optional[Dict[int, str]] = None,
|
| 292 |
+
) -> List[np.ndarray]:
|
| 293 |
+
results: List[np.ndarray] = []
|
| 294 |
+
frames_iterable = frames if isinstance(frames, list) else list(frames)
|
| 295 |
+
dino_labels = dino_labels or {}
|
| 296 |
+
|
| 297 |
+
for frame_idx, frame in enumerate(frames_iterable):
|
| 298 |
+
if frame is None:
|
| 299 |
+
continue
|
| 300 |
+
frame_rgb = np.asarray(frame)
|
| 301 |
+
annotated = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
|
| 302 |
+
frame_h, frame_w = annotated.shape[:2]
|
| 303 |
+
frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
|
| 304 |
+
|
| 305 |
+
for obj_id, bbox_values in frame_bboxes.items():
|
| 306 |
+
bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
|
| 307 |
+
if not bbox:
|
| 308 |
+
continue
|
| 309 |
+
label = dino_labels.get(obj_id)
|
| 310 |
+
title = f"{label}" if label else None
|
| 311 |
+
_draw_bbox_with_label(annotated, bbox, obj_id, title=title)
|
| 312 |
+
|
| 313 |
+
results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
|
| 314 |
+
|
| 315 |
+
return results
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def render_vine_frame_sets(
|
| 319 |
+
frames: Union[np.ndarray, List[np.ndarray]],
|
| 320 |
+
bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
|
| 321 |
+
cat_label_lookup: Dict[int, Tuple[str, float]],
|
| 322 |
+
unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
|
| 323 |
+
binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
|
| 324 |
+
masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
|
| 325 |
+
) -> Dict[str, List[np.ndarray]]:
|
| 326 |
+
frame_groups: Dict[str, List[np.ndarray]] = {
|
| 327 |
+
"object": [],
|
| 328 |
+
"unary": [],
|
| 329 |
+
"binary": [],
|
| 330 |
+
"all": [],
|
| 331 |
+
}
|
| 332 |
+
frames_iterable = frames if isinstance(frames, list) else list(frames)
|
| 333 |
+
|
| 334 |
+
for frame_idx, frame in enumerate(frames_iterable):
|
| 335 |
+
if frame is None:
|
| 336 |
+
continue
|
| 337 |
+
frame_rgb = np.asarray(frame)
|
| 338 |
+
base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
|
| 339 |
+
frame_h, frame_w = base_bgr.shape[:2]
|
| 340 |
+
frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
|
| 341 |
+
frame_masks = _extract_frame_entities(masks, frame_idx) if masks is not None else {}
|
| 342 |
+
|
| 343 |
+
objects_bgr = base_bgr.copy()
|
| 344 |
+
unary_bgr = base_bgr.copy()
|
| 345 |
+
binary_bgr = base_bgr.copy()
|
| 346 |
+
all_bgr = base_bgr.copy()
|
| 347 |
+
|
| 348 |
+
bbox_lookup: Dict[int, Tuple[int, int, int, int]] = {}
|
| 349 |
+
unary_lines_lookup: Dict[int, List[str]] = {}
|
| 350 |
+
titles_lookup: Dict[int, Optional[str]] = {}
|
| 351 |
+
|
| 352 |
+
for obj_id, bbox_values in frame_bboxes.items():
|
| 353 |
+
bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
|
| 354 |
+
if not bbox:
|
| 355 |
+
continue
|
| 356 |
+
bbox_lookup[obj_id] = bbox
|
| 357 |
+
cat_label, cat_prob = cat_label_lookup.get(obj_id, (None, None))
|
| 358 |
+
title_parts = []
|
| 359 |
+
if cat_label:
|
| 360 |
+
if cat_prob is not None:
|
| 361 |
+
title_parts.append(f"{cat_label} {cat_prob:.2f}")
|
| 362 |
+
else:
|
| 363 |
+
title_parts.append(cat_label)
|
| 364 |
+
titles_lookup[obj_id] = " ".join(title_parts) if title_parts else None
|
| 365 |
+
unary_preds = unary_lookup.get(frame_idx, {}).get(obj_id, [])
|
| 366 |
+
unary_lines = [f"{label} {prob:.2f}" for prob, label in unary_preds]
|
| 367 |
+
unary_lines_lookup[obj_id] = unary_lines
|
| 368 |
+
|
| 369 |
+
for obj_id, bbox in bbox_lookup.items():
|
| 370 |
+
unary_lines = unary_lines_lookup.get(obj_id, [])
|
| 371 |
+
if not unary_lines:
|
| 372 |
+
continue
|
| 373 |
+
mask_raw = frame_masks.get(obj_id)
|
| 374 |
+
mask_np = _to_numpy_mask(mask_raw)
|
| 375 |
+
if mask_np is None or not np.any(mask_np):
|
| 376 |
+
continue
|
| 377 |
+
color = np.array(_object_color_bgr(obj_id), dtype=np.float32)
|
| 378 |
+
alpha = 0.45
|
| 379 |
+
for target in (unary_bgr, all_bgr):
|
| 380 |
+
target_vals = target[mask_np].astype(np.float32)
|
| 381 |
+
blended = (1.0 - alpha) * target_vals + alpha * color
|
| 382 |
+
target[mask_np] = np.clip(blended, 0, 255).astype(np.uint8)
|
| 383 |
+
|
| 384 |
+
for obj_id, bbox in bbox_lookup.items():
|
| 385 |
+
title = titles_lookup.get(obj_id)
|
| 386 |
+
unary_lines = unary_lines_lookup.get(obj_id, [])
|
| 387 |
+
_draw_bbox_with_label(objects_bgr, bbox, obj_id, title=title, label_position="top")
|
| 388 |
+
_draw_bbox_with_label(unary_bgr, bbox, obj_id, title=title, label_position="top")
|
| 389 |
+
if unary_lines:
|
| 390 |
+
anchor, direction = _label_anchor_and_direction(bbox, "bottom")
|
| 391 |
+
_draw_label_block(unary_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
|
| 392 |
+
_draw_bbox_with_label(binary_bgr, bbox, obj_id, title=title, label_position="top")
|
| 393 |
+
_draw_bbox_with_label(all_bgr, bbox, obj_id, title=title, label_position="top")
|
| 394 |
+
if unary_lines:
|
| 395 |
+
anchor, direction = _label_anchor_and_direction(bbox, "bottom")
|
| 396 |
+
_draw_label_block(all_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
|
| 397 |
+
|
| 398 |
+
for obj_pair, relation_preds in binary_lookup.get(frame_idx, []):
|
| 399 |
+
if len(obj_pair) != 2 or not relation_preds:
|
| 400 |
+
continue
|
| 401 |
+
subj_id, obj_id = obj_pair
|
| 402 |
+
subj_bbox = bbox_lookup.get(subj_id)
|
| 403 |
+
obj_bbox = bbox_lookup.get(obj_id)
|
| 404 |
+
if not subj_bbox or not obj_bbox:
|
| 405 |
+
continue
|
| 406 |
+
start, end = relation_line(subj_bbox, obj_bbox)
|
| 407 |
+
color = tuple(int(c) for c in np.clip(
|
| 408 |
+
(np.array(_object_color_bgr(subj_id), dtype=np.float32) +
|
| 409 |
+
np.array(_object_color_bgr(obj_id), dtype=np.float32)) / 2.0,
|
| 410 |
+
0, 255
|
| 411 |
+
))
|
| 412 |
+
prob, relation = relation_preds[0]
|
| 413 |
+
label_text = f"{relation} {prob:.2f}"
|
| 414 |
+
mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
|
| 415 |
+
cv2.line(binary_bgr, start, end, color, 6, cv2.LINE_AA)
|
| 416 |
+
cv2.line(all_bgr, start, end, color, 6, cv2.LINE_AA)
|
| 417 |
+
_draw_centered_label(binary_bgr, label_text, mid_point, color)
|
| 418 |
+
_draw_centered_label(all_bgr, label_text, mid_point, color)
|
| 419 |
+
|
| 420 |
+
frame_groups["object"].append(cv2.cvtColor(objects_bgr, cv2.COLOR_BGR2RGB))
|
| 421 |
+
frame_groups["unary"].append(cv2.cvtColor(unary_bgr, cv2.COLOR_BGR2RGB))
|
| 422 |
+
frame_groups["binary"].append(cv2.cvtColor(binary_bgr, cv2.COLOR_BGR2RGB))
|
| 423 |
+
frame_groups["all"].append(cv2.cvtColor(all_bgr, cv2.COLOR_BGR2RGB))
|
| 424 |
+
|
| 425 |
+
return frame_groups
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def render_vine_frames(
|
| 429 |
+
frames: Union[np.ndarray, List[np.ndarray]],
|
| 430 |
+
bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
|
| 431 |
+
cat_label_lookup: Dict[int, Tuple[str, float]],
|
| 432 |
+
unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
|
| 433 |
+
binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
|
| 434 |
+
masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
|
| 435 |
+
) -> List[np.ndarray]:
|
| 436 |
+
return render_vine_frame_sets(
|
| 437 |
+
frames,
|
| 438 |
+
bboxes,
|
| 439 |
+
cat_label_lookup,
|
| 440 |
+
unary_lookup,
|
| 441 |
+
binary_lookup,
|
| 442 |
+
masks,
|
| 443 |
+
).get("all", [])
|
| 444 |
+
|
| 445 |
+
def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
|
| 446 |
+
all_colors = []
|
| 447 |
+
all_texts = []
|
| 448 |
+
for (obj_id, bbox, gt_label) in gt_labels:
|
| 449 |
+
preds = obj_pred_dict.get(obj_id, [])
|
| 450 |
+
if len(preds) == 0:
|
| 451 |
+
top1 = "N/A"
|
| 452 |
+
box_color = (0, 0, 255) # bright red if no prediction
|
| 453 |
+
else:
|
| 454 |
+
top1, prob1 = preds[0]
|
| 455 |
+
topk_labels = [p[0] for p in preds[:topk_object]]
|
| 456 |
+
# Compare cleaned labels.
|
| 457 |
+
if top1.lower() == gt_label.lower():
|
| 458 |
+
box_color = (0, 255, 0) # bright green for correct
|
| 459 |
+
elif gt_label.lower() in [p.lower() for p in topk_labels]:
|
| 460 |
+
box_color = (0, 165, 255) # bright orange for partial match
|
| 461 |
+
else:
|
| 462 |
+
box_color = (0, 0, 255) # bright red for incorrect
|
| 463 |
+
|
| 464 |
+
label_text = f"ID:{obj_id}/P:{top1}/GT:{gt_label}"
|
| 465 |
+
all_colors.append(box_color)
|
| 466 |
+
all_texts.append(label_text)
|
| 467 |
+
return all_colors, all_texts
|
| 468 |
+
|
| 469 |
+
def plot_unary(frame_img, gt_labels, all_colors, all_texts):
|
| 470 |
+
|
| 471 |
+
for (obj_id, bbox, gt_label), box_color, label_text in zip(gt_labels, all_colors, all_texts):
|
| 472 |
+
x1, y1, x2, y2 = map(int, bbox)
|
| 473 |
+
cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
|
| 474 |
+
(tw, th), baseline = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
| 475 |
+
cv2.rectangle(frame_img, (x1, y1 - th - baseline - 4), (x1 + tw, y1), box_color, -1)
|
| 476 |
+
cv2.putText(frame_img, label_text, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX,
|
| 477 |
+
0.5, (0, 0, 0), 1, cv2.LINE_AA)
|
| 478 |
+
|
| 479 |
+
return frame_img
|
| 480 |
+
|
| 481 |
+
def get_white_pane(pane_height,
|
| 482 |
+
pane_width=600,
|
| 483 |
+
header_height = 50,
|
| 484 |
+
header_font = cv2.FONT_HERSHEY_SIMPLEX,
|
| 485 |
+
header_font_scale = 0.7,
|
| 486 |
+
header_thickness = 2,
|
| 487 |
+
header_color = (0, 0, 0)):
|
| 488 |
+
# Create an expanded white pane to display text info.
|
| 489 |
+
white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)
|
| 490 |
+
|
| 491 |
+
# --- Adjust pane split: make predictions column wider (60% vs. 40%) ---
|
| 492 |
+
left_width = int(pane_width * 0.6)
|
| 493 |
+
right_width = pane_width - left_width
|
| 494 |
+
left_pane = white_pane[:, :left_width, :].copy()
|
| 495 |
+
right_pane = white_pane[:, left_width:, :].copy()
|
| 496 |
+
|
| 497 |
+
cv2.putText(left_pane, "Binary Predictions", (10, header_height - 30),
|
| 498 |
+
header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
|
| 499 |
+
cv2.putText(right_pane, "Ground Truth", (10, header_height - 30),
|
| 500 |
+
header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
|
| 501 |
+
|
| 502 |
+
return white_pane
|
| 503 |
+
|
| 504 |
+
# This is for ploting binary prediction results with frame-based scene graphs
|
| 505 |
+
def plot_binary_sg(frame_img,
|
| 506 |
+
white_pane,
|
| 507 |
+
bin_preds,
|
| 508 |
+
gt_relations,
|
| 509 |
+
topk_binary,
|
| 510 |
+
header_height=50,
|
| 511 |
+
indicator_size=20,
|
| 512 |
+
pane_width=600):
|
| 513 |
+
# Leave vertical space for the headers.
|
| 514 |
+
line_height = 30 # vertical spacing per line
|
| 515 |
+
x_text = 10 # left margin for text
|
| 516 |
+
y_text_left = header_height + 10 # starting y for left pane text
|
| 517 |
+
y_text_right = header_height + 10 # starting y for right pane text
|
| 518 |
+
|
| 519 |
+
# Left section: top-k binary predictions.
|
| 520 |
+
left_width = int(pane_width * 0.6)
|
| 521 |
+
right_width = pane_width - left_width
|
| 522 |
+
left_pane = white_pane[:, :left_width, :].copy()
|
| 523 |
+
right_pane = white_pane[:, left_width:, :].copy()
|
| 524 |
+
|
| 525 |
+
for (subj, pred_rel, obj, score) in bin_preds[:topk_binary]:
|
| 526 |
+
correct = any((subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
|
| 527 |
+
for gt in gt_relations)
|
| 528 |
+
indicator_color = (0, 255, 0) if correct else (0, 0, 255)
|
| 529 |
+
cv2.rectangle(left_pane, (x_text, y_text_left - indicator_size + 5),
|
| 530 |
+
(x_text + indicator_size, y_text_left + 5), indicator_color, -1)
|
| 531 |
+
text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
|
| 532 |
+
cv2.putText(left_pane, text, (x_text + indicator_size + 5, y_text_left + 5),
|
| 533 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
|
| 534 |
+
y_text_left += line_height
|
| 535 |
+
|
| 536 |
+
# Right section: ground truth binary relations.
|
| 537 |
+
for gt in gt_relations:
|
| 538 |
+
if len(gt) != 3:
|
| 539 |
+
continue
|
| 540 |
+
text = f"{gt[0]} - {gt[2]} - {gt[1]}"
|
| 541 |
+
cv2.putText(right_pane, text, (x_text, y_text_right + 5),
|
| 542 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
|
| 543 |
+
y_text_right += line_height
|
| 544 |
+
|
| 545 |
+
# Combine the two text panes and then with the frame image.
|
| 546 |
+
combined_pane = np.hstack((left_pane, right_pane))
|
| 547 |
+
combined_image = np.hstack((frame_img, combined_pane))
|
| 548 |
+
return combined_image
|
| 549 |
+
|
| 550 |
+
def visualized_frame(frame_img,
|
| 551 |
+
bboxes,
|
| 552 |
+
object_ids,
|
| 553 |
+
gt_labels,
|
| 554 |
+
cate_preds,
|
| 555 |
+
binary_preds,
|
| 556 |
+
gt_relations,
|
| 557 |
+
topk_object,
|
| 558 |
+
topk_binary,
|
| 559 |
+
phase="unary"):
|
| 560 |
+
|
| 561 |
+
"""Return the combined annotated frame for frame index i as an image (in BGR)."""
|
| 562 |
+
# Get the frame image (assuming batched_data['batched_reshaped_raw_videos'] is a list of frames)
|
| 563 |
+
|
| 564 |
+
# --- Process Object Predictions (for overlaying bboxes) ---
|
| 565 |
+
if phase == "unary":
|
| 566 |
+
objs = []
|
| 567 |
+
for ((_, f_id, obj_id), bbox, gt_label) in zip(object_ids, bboxes, gt_labels):
|
| 568 |
+
gt_label = clean_label(gt_label)
|
| 569 |
+
objs.append((obj_id, bbox, gt_label))
|
| 570 |
+
|
| 571 |
+
formatted_cate_preds = format_cate_preds(cate_preds)
|
| 572 |
+
all_colors, all_texts = color_for_cate_correctness(formatted_cate_preds, gt_labels, topk_object)
|
| 573 |
+
updated_frame_img = plot_unary(frame_img, gt_labels, all_colors, all_texts)
|
| 574 |
+
return updated_frame_img
|
| 575 |
+
|
| 576 |
+
else:
|
| 577 |
+
# --- Process Binary Predictions & Ground Truth for the Text Pane ---
|
| 578 |
+
formatted_binary_preds = format_binary_cate_preds(binary_preds)
|
| 579 |
+
|
| 580 |
+
# Ground truth binary relations for the frame.
|
| 581 |
+
# Clean ground truth relations.
|
| 582 |
+
gt_relations = [(clean_label(str(s)), clean_label(str(o)), clean_label(rel)) for s, o, rel in gt_relations]
|
| 583 |
+
|
| 584 |
+
pane_width = 600 # increased pane width for more horizontal space
|
| 585 |
+
pane_height = frame_img.shape[0]
|
| 586 |
+
|
| 587 |
+
# --- Add header labels to each text pane with extra space ---
|
| 588 |
+
header_height = 50 # increased header space
|
| 589 |
+
white_pane = get_white_pane(pane_height, pane_width, header_height=header_height)
|
| 590 |
+
|
| 591 |
+
combined_image = plot_binary_sg(frame_img, white_pane, formatted_binary_preds, gt_relations, topk_binary)
|
| 592 |
+
|
| 593 |
+
return combined_image
|
| 594 |
+
|
| 595 |
+
def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
|
| 596 |
+
# Ensure mask is a numpy array
|
| 597 |
+
mask = np.array(mask)
|
| 598 |
+
# Handle different mask shapes
|
| 599 |
+
if mask.ndim == 3:
|
| 600 |
+
# (1, H, W) -> (H, W)
|
| 601 |
+
if mask.shape[0] == 1:
|
| 602 |
+
mask = mask.squeeze(0)
|
| 603 |
+
# (H, W, 1) -> (H, W)
|
| 604 |
+
elif mask.shape[2] == 1:
|
| 605 |
+
mask = mask.squeeze(2)
|
| 606 |
+
# Now mask should be (H, W)
|
| 607 |
+
assert mask.ndim == 2, f"Mask must be 2D after squeezing, got shape {mask.shape}"
|
| 608 |
+
|
| 609 |
+
if random_color:
|
| 610 |
+
color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
|
| 611 |
+
else:
|
| 612 |
+
cmap = plt.get_cmap("gist_rainbow")
|
| 613 |
+
cmap_idx = 0 if obj_id is None else obj_id
|
| 614 |
+
color = list(cmap((cmap_idx * 47) % 256))
|
| 615 |
+
color[3] = 0.5
|
| 616 |
+
color = np.array(color)
|
| 617 |
+
|
| 618 |
+
# Expand mask to (H, W, 1) for broadcasting
|
| 619 |
+
mask_expanded = mask[..., None]
|
| 620 |
+
mask_image = mask_expanded * color.reshape(1, 1, -1)
|
| 621 |
+
|
| 622 |
+
# draw a box around the mask with the det_class as the label
|
| 623 |
+
if not det_class is None:
|
| 624 |
+
# Find the bounding box coordinates
|
| 625 |
+
y_indices, x_indices = np.where(mask > 0)
|
| 626 |
+
if y_indices.size > 0 and x_indices.size > 0:
|
| 627 |
+
x_min, x_max = x_indices.min(), x_indices.max()
|
| 628 |
+
y_min, y_max = y_indices.min(), y_indices.max()
|
| 629 |
+
rect = Rectangle(
|
| 630 |
+
(x_min, y_min),
|
| 631 |
+
x_max - x_min,
|
| 632 |
+
y_max - y_min,
|
| 633 |
+
linewidth=1.5,
|
| 634 |
+
edgecolor=color[:3],
|
| 635 |
+
facecolor="none",
|
| 636 |
+
alpha=color[3]
|
| 637 |
+
)
|
| 638 |
+
ax.add_patch(rect)
|
| 639 |
+
ax.text(
|
| 640 |
+
x_min,
|
| 641 |
+
y_min - 5,
|
| 642 |
+
f"{det_class}",
|
| 643 |
+
color="white",
|
| 644 |
+
fontsize=6,
|
| 645 |
+
backgroundcolor=np.array(color),
|
| 646 |
+
alpha=1
|
| 647 |
+
)
|
| 648 |
+
ax.imshow(mask_image)
|
| 649 |
+
|
| 650 |
+
def save_mask_one_image(frame_image, masks, save_path):
|
| 651 |
+
"""Render masks on top of a frame and store the visualization on disk."""
|
| 652 |
+
fig, ax = plt.subplots(1, figsize=(6, 6))
|
| 653 |
+
|
| 654 |
+
frame_np = (
|
| 655 |
+
frame_image.detach().cpu().numpy()
|
| 656 |
+
if torch.is_tensor(frame_image)
|
| 657 |
+
else np.asarray(frame_image)
|
| 658 |
+
)
|
| 659 |
+
frame_np = np.ascontiguousarray(frame_np)
|
| 660 |
+
|
| 661 |
+
if isinstance(masks, dict):
|
| 662 |
+
mask_iter = masks.items()
|
| 663 |
+
else:
|
| 664 |
+
mask_iter = enumerate(masks)
|
| 665 |
+
|
| 666 |
+
prepared_masks = {
|
| 667 |
+
obj_id: (
|
| 668 |
+
mask.detach().cpu().numpy()
|
| 669 |
+
if torch.is_tensor(mask)
|
| 670 |
+
else np.asarray(mask)
|
| 671 |
+
)
|
| 672 |
+
for obj_id, mask in mask_iter
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
ax.imshow(frame_np)
|
| 676 |
+
ax.axis("off")
|
| 677 |
+
|
| 678 |
+
for obj_id, mask_np in prepared_masks.items():
|
| 679 |
+
show_mask(mask_np, ax, obj_id=obj_id, det_class=None, random_color=False)
|
| 680 |
+
|
| 681 |
+
fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
|
| 682 |
+
plt.close(fig)
|
| 683 |
+
return save_path
|
| 684 |
+
|
| 685 |
+
def get_video_masks_visualization(video_tensor,
|
| 686 |
+
video_masks,
|
| 687 |
+
video_id,
|
| 688 |
+
video_save_base_dir,
|
| 689 |
+
oid_class_pred=None,
|
| 690 |
+
sample_rate = 1):
|
| 691 |
+
|
| 692 |
+
video_save_dir = os.path.join(video_save_base_dir, video_id)
|
| 693 |
+
if not os.path.exists(video_save_dir):
|
| 694 |
+
os.makedirs(video_save_dir, exist_ok=True)
|
| 695 |
+
|
| 696 |
+
for frame_id, image in enumerate(video_tensor):
|
| 697 |
+
if frame_id not in video_masks:
|
| 698 |
+
print("No mask for Frame", frame_id)
|
| 699 |
+
continue
|
| 700 |
+
|
| 701 |
+
masks = video_masks[frame_id]
|
| 702 |
+
save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
|
| 703 |
+
get_mask_one_image(image, masks, oid_class_pred)
|
| 704 |
+
|
| 705 |
+
def get_mask_one_image(frame_image, masks, oid_class_pred=None):
|
| 706 |
+
# Create a figure and axis
|
| 707 |
+
fig, ax = plt.subplots(1, figsize=(6, 6))
|
| 708 |
+
|
| 709 |
+
# Display the frame image
|
| 710 |
+
ax.imshow(frame_image)
|
| 711 |
+
ax.axis('off')
|
| 712 |
+
|
| 713 |
+
if type(masks) == list:
|
| 714 |
+
masks = {i: m for i, m in enumerate(masks)}
|
| 715 |
+
|
| 716 |
+
# Add the masks
|
| 717 |
+
for obj_id, mask in masks.items():
|
| 718 |
+
det_class = f"{obj_id}. {oid_class_pred[obj_id]}" if not oid_class_pred is None else None
|
| 719 |
+
show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)
|
| 720 |
+
|
| 721 |
+
# Show the plot
|
| 722 |
+
return fig, ax
|
| 723 |
+
|
| 724 |
+
def save_video(frames, output_filename, output_fps):
|
| 725 |
+
|
| 726 |
+
# --- Create a video from all frames ---
|
| 727 |
+
num_frames = len(frames)
|
| 728 |
+
frame_h, frame_w = frames.shape[:2]
|
| 729 |
+
|
| 730 |
+
# Use a codec supported by VS Code (H.264 via 'avc1').
|
| 731 |
+
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
| 732 |
+
out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))
|
| 733 |
+
|
| 734 |
+
print(f"Processing {num_frames} frames...")
|
| 735 |
+
for i in range(num_frames):
|
| 736 |
+
vis_frame = get_visualized_frame(i)
|
| 737 |
+
out.write(vis_frame)
|
| 738 |
+
if i % 10 == 0:
|
| 739 |
+
print(f"Processed frame {i+1}/{num_frames}")
|
| 740 |
+
|
| 741 |
+
out.release()
|
| 742 |
+
print(f"Video saved as {output_filename}")
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
def list_depth(lst):
|
| 746 |
+
"""Calculates the depth of a nested list."""
|
| 747 |
+
if not (isinstance(lst, list) or isinstance(lst, torch.Tensor)):
|
| 748 |
+
return 0
|
| 749 |
+
elif (isinstance(lst, torch.Tensor) and lst.shape == torch.Size([])) or (isinstance(lst, list) and len(lst) == 0):
|
| 750 |
+
return 1
|
| 751 |
+
else:
|
| 752 |
+
return 1 + max(list_depth(item) for item in lst)
|
| 753 |
+
|
| 754 |
+
def normalize_prompt(points, labels):
|
| 755 |
+
if list_depth(points) == 3:
|
| 756 |
+
points = torch.stack([p.unsqueeze(0) for p in points])
|
| 757 |
+
labels = torch.stack([l.unsqueeze(0) for l in labels])
|
| 758 |
+
return points, labels
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
def show_box(box, ax, object_id):
|
| 762 |
+
if len(box) == 0:
|
| 763 |
+
return
|
| 764 |
+
|
| 765 |
+
cmap = plt.get_cmap("gist_rainbow")
|
| 766 |
+
cmap_idx = 0 if object_id is None else object_id
|
| 767 |
+
color = list(cmap((cmap_idx * 47) % 256))
|
| 768 |
+
|
| 769 |
+
x0, y0 = box[0], box[1]
|
| 770 |
+
w, h = box[2] - box[0], box[3] - box[1]
|
| 771 |
+
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0,0,0,0), lw=2))
|
| 772 |
+
|
| 773 |
+
def show_points(coords, labels, ax, object_id=None, marker_size=375):
|
| 774 |
+
if len(labels) == 0:
|
| 775 |
+
return
|
| 776 |
+
|
| 777 |
+
pos_points = coords[labels==1]
|
| 778 |
+
neg_points = coords[labels==0]
|
| 779 |
+
|
| 780 |
+
cmap = plt.get_cmap("gist_rainbow")
|
| 781 |
+
cmap_idx = 0 if object_id is None else object_id
|
| 782 |
+
color = list(cmap((cmap_idx * 47) % 256))
|
| 783 |
+
|
| 784 |
+
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='P', s=marker_size, edgecolor=color, linewidth=1.25)
|
| 785 |
+
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='s', s=marker_size, edgecolor=color, linewidth=1.25)
|
| 786 |
+
|
| 787 |
+
def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
|
| 788 |
+
# Create a figure and axis
|
| 789 |
+
fig, ax = plt.subplots(1, figsize=(6, 6))
|
| 790 |
+
|
| 791 |
+
# Display the frame image
|
| 792 |
+
ax.imshow(frame_image)
|
| 793 |
+
ax.axis('off')
|
| 794 |
+
|
| 795 |
+
points, labels = normalize_prompt(points, labels)
|
| 796 |
+
if type(boxes) == torch.Tensor:
|
| 797 |
+
for object_id, box in enumerate(boxes):
|
| 798 |
+
# Add the bounding boxes
|
| 799 |
+
if not box is None:
|
| 800 |
+
show_box(box.cpu(), ax, object_id=object_id)
|
| 801 |
+
elif type(boxes) == dict:
|
| 802 |
+
for object_id, box in boxes.items():
|
| 803 |
+
# Add the bounding boxes
|
| 804 |
+
if not box is None:
|
| 805 |
+
show_box(box.cpu(), ax, object_id=object_id)
|
| 806 |
+
elif type(boxes) == list and len(boxes) == 0:
|
| 807 |
+
pass
|
| 808 |
+
else:
|
| 809 |
+
raise Exception()
|
| 810 |
+
|
| 811 |
+
for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
|
| 812 |
+
if not len(point_ls) == 0:
|
| 813 |
+
show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)
|
| 814 |
+
|
| 815 |
+
# Show the plot
|
| 816 |
+
plt.savefig(save_path)
|
| 817 |
+
plt.close()
|
| 818 |
+
|
| 819 |
+
def save_video_prompts_visualization(video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir):
|
| 820 |
+
video_save_dir = os.path.join(video_save_base_dir, video_id)
|
| 821 |
+
if not os.path.exists(video_save_dir):
|
| 822 |
+
os.makedirs(video_save_dir, exist_ok=True)
|
| 823 |
+
|
| 824 |
+
for frame_id, image in enumerate(video_tensor):
|
| 825 |
+
boxes, points, labels = [], [], []
|
| 826 |
+
|
| 827 |
+
if frame_id in video_boxes:
|
| 828 |
+
boxes = video_boxes[frame_id]
|
| 829 |
+
|
| 830 |
+
if frame_id in video_points:
|
| 831 |
+
points = video_points[frame_id]
|
| 832 |
+
if frame_id in video_labels:
|
| 833 |
+
labels = video_labels[frame_id]
|
| 834 |
+
|
| 835 |
+
save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
|
| 836 |
+
save_prompts_one_image(image, boxes, points, labels, save_path)
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
def save_video_masks_visualization(video_tensor, video_masks, video_id, video_save_base_dir, oid_class_pred=None, sample_rate = 1):
|
| 840 |
+
video_save_dir = os.path.join(video_save_base_dir, video_id)
|
| 841 |
+
if not os.path.exists(video_save_dir):
|
| 842 |
+
os.makedirs(video_save_dir, exist_ok=True)
|
| 843 |
+
|
| 844 |
+
for frame_id, image in enumerate(video_tensor):
|
| 845 |
+
if random.random() > sample_rate:
|
| 846 |
+
continue
|
| 847 |
+
if frame_id not in video_masks:
|
| 848 |
+
print("No mask for Frame", frame_id)
|
| 849 |
+
continue
|
| 850 |
+
masks = video_masks[frame_id]
|
| 851 |
+
save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
|
| 852 |
+
save_mask_one_image(image, masks, save_path)
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
def get_color(obj_id, cmap_name="gist_rainbow",alpha=0.5):
|
| 857 |
+
cmap = plt.get_cmap(cmap_name)
|
| 858 |
+
cmap_idx = 0 if obj_id is None else obj_id
|
| 859 |
+
color = list(cmap((cmap_idx * 47) % 256))
|
| 860 |
+
color[3] = 0.5
|
| 861 |
+
color = np.array(color)
|
| 862 |
+
return color
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
|
| 866 |
+
return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0)
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
def relation_line(
|
| 870 |
+
bbox1: Tuple[int, int, int, int],
|
| 871 |
+
bbox2: Tuple[int, int, int, int],
|
| 872 |
+
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
|
| 873 |
+
"""
|
| 874 |
+
Returns integer pixel centers suitable for drawing a relation line. For
|
| 875 |
+
coincident boxes, nudges the target center to ensure the segment has span.
|
| 876 |
+
"""
|
| 877 |
+
center1 = _bbox_center(bbox1)
|
| 878 |
+
center2 = _bbox_center(bbox2)
|
| 879 |
+
if math.isclose(center1[0], center2[0], abs_tol=1e-3) and math.isclose(center1[1], center2[1], abs_tol=1e-3):
|
| 880 |
+
offset = max(1.0, (bbox2[2] - bbox2[0]) * 0.05)
|
| 881 |
+
center2 = (center2[0] + offset, center2[1])
|
| 882 |
+
start = (int(round(center1[0])), int(round(center1[1])))
|
| 883 |
+
end = (int(round(center2[0])), int(round(center2[1])))
|
| 884 |
+
if start == end:
|
| 885 |
+
end = (end[0] + 1, end[1])
|
| 886 |
+
return start, end
|
| 887 |
+
|
| 888 |
+
def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
|
| 889 |
+
# Create a figure and axis
|
| 890 |
+
fig, ax = plt.subplots(1, figsize=(6, 6))
|
| 891 |
+
|
| 892 |
+
# Display the frame image
|
| 893 |
+
ax.imshow(frame_image)
|
| 894 |
+
ax.axis('off')
|
| 895 |
+
|
| 896 |
+
all_objs_to_show = set()
|
| 897 |
+
all_lines_to_show = []
|
| 898 |
+
|
| 899 |
+
# print(rel_pred_ls[0])
|
| 900 |
+
for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
|
| 901 |
+
all_objs_to_show.add(from_obj_id)
|
| 902 |
+
all_objs_to_show.add(to_obj_id)
|
| 903 |
+
|
| 904 |
+
from_mask = masks[from_obj_id]
|
| 905 |
+
bbox1 = mask_to_bbox(from_mask)
|
| 906 |
+
to_mask = masks[to_obj_id]
|
| 907 |
+
bbox2 = mask_to_bbox(to_mask)
|
| 908 |
+
|
| 909 |
+
c1, c2 = shortest_line_between_bboxes(bbox1, bbox2)
|
| 910 |
+
|
| 911 |
+
line_color = get_color(from_obj_id)
|
| 912 |
+
face_color = get_color(to_obj_id)
|
| 913 |
+
line = c1, c2, face_color, line_color, rel_text
|
| 914 |
+
all_lines_to_show.append(line)
|
| 915 |
+
|
| 916 |
+
masks_to_show = {}
|
| 917 |
+
for oid in all_objs_to_show:
|
| 918 |
+
masks_to_show[oid] = masks[oid]
|
| 919 |
+
|
| 920 |
+
# Add the masks
|
| 921 |
+
for obj_id, mask in masks_to_show.items():
|
| 922 |
+
show_mask(mask, ax, obj_id=obj_id, random_color=False)
|
| 923 |
+
|
| 924 |
+
for (from_pt_x, from_pt_y), (to_pt_x, to_pt_y), face_color, line_color, rel_text in all_lines_to_show:
|
| 925 |
+
|
| 926 |
+
plt.plot([from_pt_x, to_pt_x], [from_pt_y, to_pt_y], color=line_color, linestyle='-', linewidth=3)
|
| 927 |
+
mid_pt_x = (from_pt_x + to_pt_x) / 2
|
| 928 |
+
mid_pt_y = (from_pt_y + to_pt_y) / 2
|
| 929 |
+
ax.text(
|
| 930 |
+
mid_pt_x - 5,
|
| 931 |
+
mid_pt_y,
|
| 932 |
+
rel_text,
|
| 933 |
+
color="white",
|
| 934 |
+
fontsize=6,
|
| 935 |
+
backgroundcolor=np.array(line_color),
|
| 936 |
+
bbox=dict(facecolor=face_color, edgecolor=line_color, boxstyle='round,pad=1'),
|
| 937 |
+
alpha=1
|
| 938 |
+
)
|
| 939 |
+
|
| 940 |
+
# Show the plot
|
| 941 |
+
return fig, ax
|