ASethi04 committed · Commit f9a6349 · 0 Parent(s)
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
GroundingDINO_SwinT_OGC.py ADDED
@@ -0,0 +1,43 @@
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_T_224_1k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
README.md ADDED
@@ -0,0 +1,12 @@
1
+ ---
2
+ title: LASER Demo
3
+ emoji: 🐠
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 6.0.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,288 @@
1
+ from pathlib import Path
2
+ from collections.abc import Mapping, Sequence
3
+ from functools import lru_cache
4
+ import inspect
5
+ import shutil
6
+ import tempfile
7
+ import os
8
+ import sys
9
+
10
+ import spaces # <-- ZeroGPU integration
11
+ import gradio as gr
12
+ import torch
13
+ from transformers import pipeline  # imported but not used directly in this app
14
+
15
+
16
+ # -----------------------------
17
+ # Environment / diagnostics
18
+ # -----------------------------
19
+ os.environ["GRADIO_TEMP_DIR"] = str(Path(__file__).parent / "gradio_temp")
20
+ os.environ["OPENAI_API_KEY"] = "test"
21
+ os.environ["OMP_NUM_THREADS"] = "4"
22
+
23
+ print("All imports finished")
24
+ print(f"Python version: {sys.version}")
25
+ print(f"PyTorch version: {torch.__version__}")
26
+ print(f"CUDA available: {torch.cuda.is_available()}")
27
+ print(f"CUDA version: {torch.version.cuda}")
28
+ print(f"cuDNN version: {torch.backends.cudnn.version()}")
29
+ print(f"Number of GPUs: {torch.cuda.device_count()}")
30
+
31
+ if torch.cuda.is_available():
32
+ for i in range(torch.cuda.device_count()):
33
+ print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
34
+ print(
35
+ f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB"
36
+ )
37
+
38
+ torch.backends.cuda.matmul.allow_tf32 = False
39
+ torch.backends.cudnn.allow_tf32 = False
40
+ os.environ["TORCH_DTYPE"] = "float32"
41
+ torch.set_default_dtype(torch.float32)
42
+
43
+ current_dir = Path(__file__).resolve().parent
44
+ # For Spaces, assume checkpoints live alongside app.py or in a "checkpoints" subdir.
45
+ # If you keep them next to app.py locally, this still works.
46
+ sam_config_path = str(current_dir / "sam2_hiera_t.yaml")
47
+ sam_checkpoint_path = str(current_dir / "sam2_hiera_tiny.pt")
48
+ gd_config_path = str(current_dir / "GroundingDINO_SwinT_OGC.py")
49
+ gd_checkpoint_path = str(current_dir / "groundingdino_swint_ogc.pth")
50
+ visualization_dir = str(current_dir / "outputs")
51
+ print(
52
+ f"Setting up paths: {sam_config_path}, {sam_checkpoint_path}, {gd_config_path}, {gd_checkpoint_path}"
53
+ )
54
+
55
+
56
+ @lru_cache(maxsize=1)
57
+ def _load_vine_pipeline():
58
+ """
59
+ Lazy-load and cache the Vine pipeline so we don't re-download/rebuild it on every request.
60
+ """
61
+ from vine_hf import VineConfig, VineModel, VinePipeline
62
+
63
+ config = VineConfig(
64
+ segmentation_method="grounding_dino_sam2",
65
+ model_name="openai/clip-vit-base-patch32",
66
+ use_hf_repo=True,
67
+ model_repo="KevinX-Penn28/testing",
68
+ box_threshold=0.35,
69
+ text_threshold=0.25,
70
+ target_fps=1, # default 1 FPS
71
+ topk_cate=5,
72
+ white_alpha=0.3,
73
+ visualization_dir=visualization_dir,
74
+ visualize=True,
75
+ debug_visualizations=False,
76
+ device="cuda",
77
+ categorical_pool="max",
78
+ )
79
+ model = VineModel(config)
80
+ return VinePipeline(
81
+ model=model,
82
+ tokenizer=None,
83
+ sam_config_path=sam_config_path,
84
+ sam_checkpoint_path=sam_checkpoint_path,
85
+ gd_config_path=gd_config_path,
86
+ gd_checkpoint_path=gd_checkpoint_path,
87
+ device="cuda",
88
+ trust_remote_code=True,
89
+ )
90
+
91
+
92
+ @spaces.GPU(duration=300) # Up to ~5 minutes of H200 ZeroGPU time per call
93
+ def process_video(
94
+ video_file,
95
+ categorical_keywords,
96
+ unary_keywords,
97
+ binary_keywords,
98
+ object_pairs,
99
+ output_fps,
100
+ box_threshold,
101
+ text_threshold,
102
+ ):
103
+ vine_pipe = _load_vine_pipeline()
104
+
105
+ # Normalize incoming video input to a file path
106
+ if isinstance(video_file, dict):
107
+ video_file = (
108
+ video_file.get("name")
109
+ or video_file.get("filepath")
110
+ or video_file.get("data")
111
+ )
112
+ if not isinstance(video_file, (str, Path)):
113
+ raise ValueError(f"Unsupported video input type: {type(video_file)}")
114
+
115
+ categorical_keywords = (
116
+ [kw.strip() for kw in categorical_keywords.split(",")]
117
+ if categorical_keywords
118
+ else []
119
+ )
120
+ unary_keywords = (
121
+ [kw.strip() for kw in unary_keywords.split(",")] if unary_keywords else []
122
+ )
123
+ binary_keywords = (
124
+ [kw.strip() for kw in binary_keywords.split(",")] if binary_keywords else []
125
+ )
126
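+ # e.g. the UI string "0-1,0-2" becomes [(0, 1), (0, 2)]; indices refer to detected objects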
+ object_pairs = (
127
+ [tuple(map(int, pair.split("-"))) for pair in object_pairs.split(",")]
128
+ if object_pairs
129
+ else []
130
+ )
131
+
132
+ results = vine_pipe(
133
+ inputs=video_file,
134
+ categorical_keywords=categorical_keywords,
135
+ unary_keywords=unary_keywords,
136
+ binary_keywords=binary_keywords,
137
+ object_pairs=object_pairs,
138
+ segmentation_method="grounding_dino_sam2",
139
+ return_top_k=5,
140
+ include_visualizations=True,
141
+ debug_visualizations=False,
142
+ device="cuda",
143
+ box_threshold=box_threshold,
144
+ text_threshold=text_threshold,
145
+ target_fps=output_fps,
146
+ )
147
+
148
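+ # Persist the latest UI settings on the pipeline object (this call already received them as kwargs above)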
+ vine_pipe.box_threshold = box_threshold
149
+ vine_pipe.text_threshold = text_threshold
150
+ vine_pipe.target_fps = output_fps
151
+
152
+ if isinstance(results, Mapping):
153
+ results_dict = results
154
+ elif isinstance(results, Sequence) and results and isinstance(results[0], Mapping):
155
+ results_dict = results[0]
156
+ else:
157
+ results_dict = {}
158
+
159
+ visualizations = results_dict.get("visualizations") or {}
160
+ vine = visualizations.get("vine") or {}
161
+ all_vis = vine.get("all") or {}
162
+ result_video_path = all_vis.get("video_path")
163
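+ # Fall back to the newest .mp4 under visualization_dir if the pipeline did not report a video path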
+ if not result_video_path:
164
+ candidates = sorted(
165
+ Path(visualization_dir).rglob("*.mp4"),
166
+ key=lambda p: p.stat().st_mtime,
167
+ reverse=True,
168
+ )
169
+ result_video_path = str(candidates[0]) if candidates else None
170
+ summary = results_dict.get("summary") or {}
171
+
172
+ if result_video_path and os.path.exists(result_video_path):
173
+ gradio_tmp = Path(
174
+ os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
175
+ ) / "vine_outputs"
176
+ gradio_tmp.mkdir(parents=True, exist_ok=True)
177
+ dest_path = gradio_tmp / Path(result_video_path).name
178
+ try:
179
+ shutil.copyfile(result_video_path, dest_path)
180
+ video_path_for_ui = str(dest_path)
181
+ except Exception as e:
182
+ print(f"Warning: failed to copy video to Gradio temp dir: {e}")
183
+ video_path_for_ui = str(result_video_path)
184
+ else:
185
+ video_path_for_ui = None
186
+ print(
187
+ "Warning: annotated video not found or empty; check visualization settings."
188
+ )
189
+
190
+ return video_path_for_ui, summary
191
+
192
+
193
+ def _video_component(label: str, *, is_output: bool = False):
194
+ """
195
+ Build a Gradio Video component that is compatible with older Gradio versions
196
+ (no `type`/`sources`/`format` kwargs) and newer ones when available.
197
+ """
198
+ kwargs = {"label": label}
199
+ sig = inspect.signature(gr.Video.__init__)
200
+
201
+ # Only set format for OUTPUT components
202
+ if is_output and "format" in sig.parameters:
203
+ kwargs["format"] = "mp4"
204
+
205
+ if not is_output:
206
+ if "type" in sig.parameters:
207
+ kwargs["type"] = "filepath"
208
+ if "sources" in sig.parameters:
209
+ kwargs["sources"] = ["upload"]
210
+
211
+ if is_output and "autoplay" in sig.parameters:
212
+ kwargs["autoplay"] = True
213
+
214
+ return gr.Video(**kwargs)
215
+
216
+
217
+ def _create_blocks():
218
+ """
219
+ Build a Blocks context that works across Gradio versions.
220
+ """
221
+ blocks_kwargs = {"title": "VINE Demo"}
222
+ soft_theme = None
223
+
224
+ if hasattr(gr, "themes") and hasattr(gr.themes, "Soft"):
225
+ try:
226
+ soft_theme = gr.themes.Soft()
227
+ except Exception:
228
+ soft_theme = None
229
+
230
+ if "theme" in inspect.signature(gr.Blocks).parameters and soft_theme is not None:
231
+ blocks_kwargs["theme"] = soft_theme
232
+
233
+ return gr.Blocks(**blocks_kwargs)
234
+
235
+
236
+ # Create Gradio interface
237
+ with _create_blocks() as demo:
238
+ video_input = _video_component("Upload Video", is_output=False)
239
+ categorical_input = gr.Textbox(
240
+ label="Categorical Keywords (comma-separated)",
241
+ value="person, car, tree, background",
242
+ )
243
+ unary_input = gr.Textbox(
244
+ label="Unary Keywords (comma-separated)", value="walking, running, standing"
245
+ )
246
+ binary_input = gr.Textbox(
247
+ label="Binary Keywords (comma-separated)",
248
+ placeholder="e.g., chasing, carrying",
249
+ )
250
+ pairs_input = gr.Textbox(
251
+ label="Object Pairs (comma-separated indices)",
252
+ placeholder="e.g., 0-1,0-2 for pairs of objects",
253
+ )
254
+ fps_input = gr.Number(
255
+ label="Output FPS (affects processing speed)", value=1 # default 1 FPS
256
+ )
257
+
258
+ with gr.Accordion("Advanced Settings", open=False):
259
+ box_threshold_input = gr.Slider(
260
+ label="Box Threshold", minimum=0.1, maximum=0.9, value=0.35, step=0.05
261
+ )
262
+ text_threshold_input = gr.Slider(
263
+ label="Text Threshold", minimum=0.1, maximum=0.9, value=0.25, step=0.05
264
+ )
265
+
266
+ submit_btn = gr.Button("Process Video", variant="primary")
267
+
268
+ video_output = _video_component("Output Video with Annotations", is_output=True)
269
+ json_output = gr.JSON(label="Summary of Detected Events")
270
+
271
+ submit_btn.click(
272
+ fn=process_video,
273
+ inputs=[
274
+ video_input,
275
+ categorical_input,
276
+ unary_input,
277
+ binary_input,
278
+ pairs_input,
279
+ fps_input,
280
+ box_threshold_input,
281
+ text_threshold_input,
282
+ ],
283
+ outputs=[video_output, json_output],
284
+ )
285
+
286
+ if __name__ == "__main__":
287
+ print("Got to main")
288
+ demo.launch(share=True, debug=True)
groundingdino_swint_ogc.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b3ca2563c77c69f651d7bd133e97139c186df06231157a64c507099c52bc799
3
+ size 693997677
requirements.txt ADDED
@@ -0,0 +1,24 @@
1
+ gradio>=4.0.0
2
+ spaces>=0.24.0
3
+
4
+ transformers>=4.40.0
5
+ huggingface-hub>=0.23.0
6
+ safetensors>=0.4.2
7
+ accelerate>=0.30.0
8
+
9
+ --extra-index-url https://download.pytorch.org/whl/cu121
10
+ torch==2.2.1
11
+ torchvision==0.17.1
12
+
13
+ numpy
14
+ opencv-python
15
+ pillow
16
+ matplotlib
17
+ seaborn
18
+ pandas
19
+ tqdm
20
+ scikit-learn
21
+
22
+ -e git+https://github.com/video-fm/video-sam2.git#egg=video_sam2
23
+ -e git+https://github.com/IDEA-Research/GroundingDINO.git#egg=GroundingDINO
24
+ -e git+https://github.com/kevinxuez/LASER.git#egg=laser
sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [64, 64]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [64, 64]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ # HieraT does not currently support compilation, should always be set to False
118
+ compile_image_encoder: False
sam2_hiera_tiny.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65b50056e05bcb13694174f51bb6da89c894b57b75ccdf0ba6352c597c5d1125
3
+ size 155906050
vine_hf/OVERVIEW.md ADDED
@@ -0,0 +1,218 @@
1
+ # VINE HuggingFace Interface - Complete Overview
2
+
3
+ This directory contains a complete HuggingFace-compatible interface for the VINE (Video Understanding with Natural Language) model. The interface allows you to easily use, share, and deploy your VINE model through the HuggingFace ecosystem.
4
+
5
+ ## 📁 Directory Structure
6
+
7
+ ```
8
+ vine_hf/
9
+ ├── __init__.py # Package initialization and exports
10
+ ├── vine_config.py # VineConfig class (PretrainedConfig)
11
+ ├── vine_model.py # VineModel class (PreTrainedModel)
12
+ ├── vine_pipeline.py # VinePipeline class (Pipeline)
13
+ ├── example_usage.py # Comprehensive usage examples
14
+ ├── convert_inference.py # Migration guide from inference.py
15
+ ├── push_to_hub.py # Script to push model to HF Hub
16
+ ├── setup.py # Package setup configuration
17
+ ├── README.md # Detailed documentation
18
+ └── OVERVIEW.md # This file
19
+ ```
20
+
21
+ ## 🏗️ Architecture Components
22
+
23
+ ### 1. VineConfig (`vine_config.py`)
24
+ - Inherits from `PretrainedConfig`
25
+ - Configures model parameters, segmentation methods, and processing options
26
+ - Compatible with HuggingFace configuration system
27
+
28
+ ### 2. VineModel (`vine_model.py`)
29
+ - Inherits from `PreTrainedModel`
30
+ - Implements the core VINE model with three CLIP backbones
31
+ - Supports categorical, unary, and binary predictions
32
+ - Provides both `forward()` and `predict()` methods
33
+
34
+ ### 3. VinePipeline (`vine_pipeline.py`)
35
+ - Inherits from `Pipeline`
36
+ - Handles end-to-end video processing workflow
37
+ - Integrates segmentation (SAM2, Grounding DINO + SAM2)
38
+ - Provides a user-friendly interface for video understanding (see the composition sketch below)
39
+
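+ The three classes compose as config -> model -> pipeline. A minimal sketch (the checkpoint paths are placeholders, not files shipped with this package):
+
+ ```python
+ from vine_hf import VineConfig, VineModel, VinePipeline
+
+ # Build the configuration, wrap it in the model, then hand the model to the pipeline.
+ config = VineConfig(segmentation_method="grounding_dino_sam2")
+ model = VineModel(config)
+ vine_pipeline = VinePipeline(
+     model=model,
+     tokenizer=None,
+     sam_config_path="/path/to/sam2_config.yaml",
+     sam_checkpoint_path="/path/to/sam2_checkpoint.pt",
+     gd_config_path="/path/to/grounding_dino_config.py",
+     gd_checkpoint_path="/path/to/grounding_dino_checkpoint.pth",
+ )
+ ```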
40
+ ## 🚀 Key Features
41
+
42
+ ✅ **Full HuggingFace Compatibility**
43
+ - Compatible with `transformers` library
44
+ - Supports `AutoModel` and `pipeline` interfaces
45
+ - Can be pushed to and loaded from HuggingFace Hub
46
+
47
+ ✅ **Flexible Segmentation**
48
+ - Support for SAM2 automatic segmentation
49
+ - Support for Grounding DINO + SAM2 text-guided segmentation
50
+ - Configurable thresholds and parameters
51
+
52
+ ✅ **Multi-Modal Understanding**
53
+ - Categorical classification (object types)
54
+ - Unary predicates (single object actions)
55
+ - Binary relations (object-object relationships)
56
+
57
+ ✅ **Easy Integration**
58
+ - Simple pipeline interface for end users
59
+ - Direct model access for researchers
60
+ - Comprehensive configuration options
61
+
62
+ ## 📖 Usage Examples
63
+
64
+ ### Quick Start with Pipeline
65
+ ```python
66
+ from transformers import pipeline
67
+ from vine_hf import VineModel, VinePipeline
68
+
69
+ # Create pipeline
70
+ vine_pipeline = pipeline(
71
+ "vine-video-understanding",
72
+ model="your-username/vine-model",
73
+ trust_remote_code=True
74
+ )
75
+
76
+ # Process video
77
+ results = vine_pipeline(
78
+ "video.mp4",
79
+ categorical_keywords=['human', 'dog', 'frisbee'],
80
+ unary_keywords=['running', 'jumping'],
81
+ binary_keywords=['chasing', 'behind']
82
+ )
83
+ ```
84
+
85
+ ### Direct Model Usage
86
+ ```python
87
+ from vine_hf import VineConfig, VineModel
88
+
89
+ config = VineConfig(segmentation_method="grounding_dino_sam2")
90
+ model = VineModel(config)
91
+
92
+ results = model.predict(
93
+ video_frames=video_tensor,
94
+ masks=masks_dict,
95
+ bboxes=bboxes_dict,
96
+ categorical_keywords=['human', 'dog'],
97
+ unary_keywords=['running', 'sitting'],
98
+ binary_keywords=['chasing', 'near']
99
+ )
100
+ ```
101
+
102
+ ## 🔧 Migration from Original Code
103
+
104
+ The `convert_inference.py` script shows how to migrate from the original `inference.py` workflow:
105
+
106
+ **Original Approach:**
107
+ - Manual model loading and configuration
108
+ - Direct handling of segmentation pipeline
109
+ - Custom result processing
110
+ - Complex setup requirements
111
+
112
+ **New HuggingFace Interface:**
113
+ - Standardized model configuration
114
+ - Automatic preprocessing/postprocessing
115
+ - Simple pipeline interface
116
+ - Easy sharing via HuggingFace Hub
117
+
118
+ ## 📤 Sharing Your Model
119
+
120
+ Use the `push_to_hub.py` script to share your trained model:
121
+
122
+ ```bash
123
+ python vine_hf/push_to_hub.py \
124
+ --weights path/to/your/model.pth \
125
+ --repo your-username/vine-model \
126
+ --login
127
+ ```
128
+
129
+ ## 🛠️ Installation & Setup
130
+
131
+ 1. **Install Dependencies:**
132
+ ```bash
133
+ pip install transformers torch torchvision opencv-python pillow numpy
134
+ ```
135
+
136
+ 2. **Install Segmentation Models (Optional):**
137
+ - SAM2: https://github.com/facebookresearch/sam2
138
+ - Grounding DINO: https://github.com/IDEA-Research/GroundingDINO
139
+
140
+ 3. **Install VINE HF Interface:**
141
+ ```bash
142
+ cd vine_hf
143
+ pip install -e .
144
+ ```
145
+
146
+ ## 🎯 Configuration Options
147
+
148
+ The `VineConfig` class supports extensive configuration:
149
+
150
+ - **Model Settings:** CLIP backbone, hidden dimensions
151
+ - **Segmentation:** Method, thresholds, target FPS
152
+ - **Processing:** Alpha values, top-k results, video length limits
153
+ - **Performance:** Multi-class mode, output format options
154
+
155
+ ## 📊 Output Format
156
+
157
+ The interface returns structured predictions:
158
+
159
+ ```python
160
+ {
161
+ "categorical_predictions": {obj_id: [(prob, category), ...]},
162
+ "unary_predictions": {(frame, obj): [(prob, action), ...]},
163
+ "binary_predictions": {(frame, pair): [(prob, relation), ...]},
164
+ "confidence_scores": {"categorical": float, "unary": float, "binary": float},
165
+ "summary": {
166
+ "num_objects_detected": int,
167
+ "top_categories": [(category, prob), ...],
168
+ "top_actions": [(action, prob), ...],
169
+ "top_relations": [(relation, prob), ...]
170
+ }
171
+ }
172
+ ```
173
+
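+ As a quick sketch of consuming this structure (assuming `results` is the dictionary above and each per-object list is sorted by probability, as the `summary` fields suggest):
+
+ ```python
+ # Print the top-ranked category for each detected object.
+ for obj_id, ranked in results["categorical_predictions"].items():
+     top_prob, top_category = ranked[0]
+     print(f"object {obj_id}: {top_category} ({top_prob:.2f})")
+ ```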
174
+ ## 🔍 Testing & Validation
175
+
176
+ Run the example scripts to test your setup:
177
+
178
+ ```bash
179
+ # Test basic functionality
180
+ python vine_hf/example_usage.py
181
+
182
+ # Test migration from original code
183
+ python vine_hf/convert_inference.py
184
+ ```
185
+
186
+ ## 🤝 Contributing
187
+
188
+ To contribute or customize:
189
+
190
+ 1. **Modify Configuration:** Edit `vine_config.py` for new parameters
191
+ 2. **Extend Model:** Add functionality to `vine_model.py`
192
+ 3. **Enhance Pipeline:** Improve preprocessing/postprocessing in `vine_pipeline.py`
193
+ 4. **Add Features:** Create additional utility scripts
194
+
195
+ ## 📝 Next Steps
196
+
197
+ 1. **Load Your Weights:** Use your trained VINE model weights
198
+ 2. **Test Segmentation:** Set up Grounding DINO and SAM2 models
199
+ 3. **Validate Results:** Compare with original inference.py output
200
+ 4. **Share Model:** Push to HuggingFace Hub for community use
201
+ 5. **Deploy:** Use in applications, demos, or research projects
202
+
203
+ ## 🐛 Troubleshooting
204
+
205
+ **Common Issues:**
206
+ - **Import Errors:** Check PYTHONPATH and package installation
207
+ - **Segmentation Failures:** Verify Grounding DINO/SAM2 setup
208
+ - **Weight Loading:** Adjust weight loading logic in `convert_inference.py`
209
+ - **CUDA Issues:** Check GPU availability and PyTorch installation
210
+
211
+ **Support:**
212
+ - Check the README.md for detailed documentation
213
+ - Review example_usage.py for working code examples
214
+ - Examine convert_inference.py for migration guidance
215
+
216
+ ---
217
+
218
+ This HuggingFace interface makes VINE accessible to the broader ML community while maintaining all the powerful video understanding capabilities of the original model. The standardized interface enables easy sharing, deployment, and integration with existing HuggingFace workflows.
vine_hf/README.md ADDED
@@ -0,0 +1,355 @@
1
+ # VINE HuggingFace Interface
2
+
3
+ VINE (Video Understanding with Natural Language) is a model that processes videos along with categorical, unary, and binary keywords to return probability distributions over those keywords for detected objects and their relationships.
4
+
5
+ This package provides a HuggingFace-compatible interface for the VINE model, making it easy to use for video understanding tasks.
6
+
7
+ ## Features
8
+
9
+ - **Categorical Classification**: Classify objects in videos (e.g., "human", "dog", "frisbee")
10
+ - **Unary Predicates**: Detect actions on single objects (e.g., "running", "jumping", "sitting")
11
+ - **Binary Relations**: Detect relationships between object pairs (e.g., "behind", "in front of", "chasing")
12
+ - **Multiple Segmentation Methods**: Support for SAM2 and Grounding DINO + SAM2
13
+ - **HuggingFace Integration**: Full compatibility with HuggingFace transformers and pipelines
14
+ - **Visualization Hooks**: Optional high-level visualizations plus lightweight debug mask dumps for quick sanity checks
15
+
16
+ ## Installation
17
+
18
+ ```bash
19
+ # Install the package (assuming it's in your Python path)
20
+ pip install transformers torch torchvision
21
+ pip install opencv-python pillow numpy
22
+
23
+ # For segmentation functionality, you'll also need:
24
+ # - SAM2: https://github.com/facebookresearch/sam2
25
+ # - Grounding DINO: https://github.com/IDEA-Research/GroundingDINO
26
+ ```
27
+
28
+ ## Segmentation Model Configuration
29
+
30
+ `VinePipeline` lazily initializes the segmentation stack the first time a call needs masks. Thresholds, FPS, visualization toggles, and device selection live in `VineConfig`; the pipeline constructor supplies the paths for the SAM2 / GroundingDINO weights, or lets you inject already-instantiated modules.
31
+
32
+ ### Provide file paths at construction (most common)
33
+
34
+ ```python
35
+ from vine_hf import VineConfig, VineModel, VinePipeline
36
+
37
+ vine_config = VineConfig(
38
+ segmentation_method="grounding_dino_sam2", # or "sam2"
39
+ box_threshold=0.35,
40
+ text_threshold=0.25,
41
+ target_fps=5,
42
+ visualization_dir="output/visualizations", # where to write visualizations (and debug visualizations if enabled)
43
+ debug_visualizations=True, # write videos of intermediate GroundingDINO/SAM2/unary/binary outputs
44
+ pretrained_vine_path="/abs/path/to/laser_model_v1.pkl",
45
+ device="cuda:0", # accepts int, str, or torch.device
46
+ )
47
+
48
+ vine_model = VineModel(vine_config)
49
+
50
+ vine_pipeline = VinePipeline(
51
+ model=vine_model,
52
+ tokenizer=None,
53
+ sam_config_path="/abs/path/to/sam2/sam2.1_hiera_t.yaml",
54
+ sam_checkpoint_path="/abs/path/to/sam2/sam2_hiera_tiny.pt",
55
+ gd_config_path="/abs/path/to/groundingdino/config/GroundingDINO_SwinT_OGC.py",
56
+ gd_checkpoint_path="/abs/path/to/groundingdino/weights/groundingdino_swint_ogc.pth",
57
+ device=vine_config._device,
58
+ )
59
+ ```
60
+
61
+ When `segmentation_method="grounding_dino_sam2"`, both SAM2 and GroundingDINO must be reachable. The pipeline validates the paths; missing files raise a `ValueError`. If you pick `"sam2"`, only the SAM2 config and checkpoint are required.
62
+
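+ For the `"sam2"`-only method, a minimal sketch (assuming the GroundingDINO arguments can simply be left unset, as the lazy-initialization note below implies):
+
+ ```python
+ from vine_hf import VineConfig, VineModel, VinePipeline
+
+ sam2_only_config = VineConfig(
+     segmentation_method="sam2",
+     target_fps=5,
+     visualization_dir="output/visualizations",
+     device="cuda:0",
+ )
+
+ sam2_only_pipeline = VinePipeline(
+     model=VineModel(sam2_only_config),
+     tokenizer=None,
+     sam_config_path="/abs/path/to/sam2/sam2.1_hiera_t.yaml",
+     sam_checkpoint_path="/abs/path/to/sam2/sam2_hiera_tiny.pt",
+     device=sam2_only_config._device,
+ )
+ ```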
63
+ ### Reuse pre-initialized segmentation modules
64
+
65
+ If you build the segmentation stack elsewhere, inject the components with `set_segmentation_models` before running the pipeline:
66
+
67
+ ```python
68
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
69
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
70
+ from groundingdino.util.inference import Model as GroundingDINOModel
71
+
72
+ sam_predictor = build_sam2_video_predictor(..., device=vine_config._device)
73
+ mask_generator = SAM2AutomaticMaskGenerator(build_sam2(..., device=vine_config._device))
74
+ grounding_model = GroundingDINOModel(..., device=vine_config._device)
75
+
76
+ vine_pipeline.set_segmentation_models(
77
+ sam_predictor=sam_predictor,
78
+ mask_generator=mask_generator,
79
+ grounding_model=grounding_model,
80
+ )
81
+ ```
82
+
83
+ Any argument left as `None` is initialized lazily from the file paths when the pipeline first needs that backend.
84
+
85
+ ## Quick Start
86
+
87
+ ## Requirements
+ - torch
+ - torchvision
+ - transformers
+ - opencv-python
+ - matplotlib
+ - seaborn
+ - pandas
+ - numpy
+ - ipywidgets
+ - tqdm
+ - scikit-learn
+ - sam2 (from Facebook Research): https://github.com/video-fm/video-sam2
+ - sam2 weights (downloaded separately, e.g. https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)
+ - groundingdino (from IDEA Research)
+ - groundingdino weights (downloaded separately, e.g. https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth)
+ - spacy-fastlang
+ - en-core-web-sm (for spacy-fastlang)
+ - ffmpeg (for video processing)
+ - (optional) LASER weights / full model checkpoint (downloaded separately, e.g. https://huggingface.co/video-fm/vine_v0)
107
+
108
+ Most of these dependencies are installed by creating a conda environment from `laser/environments/laser_env.yml` in the LASER repo; sam2 and groundingdino still need to be installed manually, following their own instructions.
109
+
110
+ ### Using the Pipeline (Recommended)
111
+ ```python
112
+ from transformers.pipelines import PIPELINE_REGISTRY
113
+ from vine_hf import VineConfig, VineModel, VinePipeline
114
+
115
+ PIPELINE_REGISTRY.register_pipeline(
116
+ "vine-video-understanding",
117
+ pipeline_class=VinePipeline,
118
+ pt_model=VineModel,
119
+ type="multimodal",
120
+ )
121
+
122
+ config = VineConfig(
123
+ segmentation_method="grounding_dino_sam2",
124
+ pretrained_vine_path="/abs/path/to/laser_model_v1.pkl",
125
+ visualization_dir="output",
126
+ visualize=True,
127
+ device="cuda:0",
128
+ )
129
+
130
+ model = VineModel(config)
131
+
132
+ vine_pipeline = VinePipeline(
133
+ model=model,
134
+ tokenizer=None,
135
+ sam_config_path="/abs/path/to/sam2/sam2.1_hiera_t.yaml",
136
+ sam_checkpoint_path="/abs/path/to/sam2/sam2_hiera_tiny.pt",
137
+ gd_config_path="/abs/path/to/groundingdino/config/GroundingDINO_SwinT_OGC.py",
138
+ gd_checkpoint_path="/abs/path/to/groundingdino/weights/groundingdino_swint_ogc.pth",
139
+ device=config._device,
140
+ )
141
+
142
+ results = vine_pipeline(
143
+ "/path/to/video.mp4",
144
+ categorical_keywords=["dog", "human"],
145
+ unary_keywords=["running"],
146
+ binary_keywords=["chasing"],
147
+ object_pairs=[(0, 1)],
148
+ return_top_k=3,
149
+ include_visualizations=True,
150
+ )
151
+ print(results["summary"])
152
+ ```
153
+
154
+ ### Using the Model Directly (Advanced)
155
+
156
+ For advanced users who want to provide their own segmentation:
157
+
158
+ ```python
159
+ from vine_hf import VineConfig, VineModel
160
+ import torch
161
+
162
+ # Create configuration
163
+ config = VineConfig(
164
+ pretrained_vine_path="/path/to/your/vine/weights" # Optional: your fine-tuned weights
165
+ )
166
+
167
+ # Initialize model
168
+ model = VineModel(config)
169
+
170
+ # If you have your own video frames, masks, and bboxes from external segmentation
171
+ video_frames = torch.randn(3, 224, 224, 3) * 255 # Your video frames
172
+ masks = {0: {1: torch.ones(224, 224, 1)}} # Your segmentation masks
173
+ bboxes = {0: {1: [50, 50, 150, 150]}} # Your bounding boxes
174
+
175
+ # Run prediction
176
+ results = model.predict(
177
+ video_frames=video_frames,
178
+ masks=masks,
179
+ bboxes=bboxes,
180
+ categorical_keywords=['human', 'dog', 'frisbee'],
181
+ unary_keywords=['running', 'jumping'],
182
+ binary_keywords=['chasing', 'following'],
183
+ object_pairs=[(1, 2)],
184
+ return_top_k=3
185
+ )
186
+ ```
187
+
188
+ **Note**: For most users, the pipeline approach above is recommended as it handles video loading and segmentation automatically.
189
+
190
+ ## Configuration Options
191
+
192
+ The `VineConfig` class supports the following parameters (non-exhaustive); a combined example follows the list:
193
+
194
+ - `model_name`: CLIP model backbone (default: `"openai/clip-vit-large-patch14-336"`)
195
+ - `pretrained_vine_path`: Optional path or Hugging Face repo with pretrained VINE weights
196
+ - `segmentation_method`: `"sam2"` or `"grounding_dino_sam2"` (default: `"grounding_dino_sam2"`)
197
+ - `box_threshold` / `text_threshold`: Grounding DINO thresholds
198
+ - `target_fps`: Target FPS for video processing (default: `1`)
199
+ - `alpha`, `white_alpha`: Rendering parameters used when extracting masked crops
200
+ - `topk_cate`: Top-k categories to return per object (default: `3`)
201
+ - `max_video_length`: Maximum frames to process (default: `100`)
202
+ - `visualize`: When `True`, pipeline post-processing attempts to create stitched visualizations
203
+ - `visualization_dir`: Optional base directory where visualization assets are written
204
+ - `debug_visualizations`: When `True`, the model saves a single first-frame mask composite for quick inspection
205
+ - `debug_visualization_path`: Target filepath for the debug mask composite (must point to a writable file)
206
+ - `return_flattened_segments`, `return_valid_pairs`, `interested_object_pairs`: Advanced geometry outputs for downstream consumers
207
+
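+ A short combined example using only the parameters documented above (values are illustrative):
+
+ ```python
+ from vine_hf import VineConfig
+
+ config = VineConfig(
+     model_name="openai/clip-vit-large-patch14-336",
+     segmentation_method="grounding_dino_sam2",
+     box_threshold=0.35,
+     text_threshold=0.25,
+     target_fps=1,
+     topk_cate=3,
+     max_video_length=100,
+     visualize=True,
+     visualization_dir="output/visualizations",
+     debug_visualizations=False,
+ )
+ ```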
208
+ ## Output Format
209
+
210
+ The model returns a dictionary with the following structure:
211
+
212
+ ```python
213
+ {
214
+ "masks" : {},
215
+
216
+ "boxes" : {},
217
+
218
+ "categorical_predictions": {
219
+ object_id: [(probability, category), ...]
220
+ },
221
+ "unary_predictions": {
222
+ (frame_id, object_id): [(probability, action), ...]
223
+ },
224
+ "binary_predictions": {
225
+ (frame_id, (obj1_id, obj2_id)): [(probability, relation), ...]
226
+ },
227
+ "confidence_scores": {
228
+ "categorical": max_categorical_confidence,
229
+ "unary": max_unary_confidence,
230
+ "binary": max_binary_confidence
231
+ },
232
+ "summary": {
233
+ "num_objects_detected": int,
234
+ "top_categories": [(category, probability), ...],
235
+ "top_actions": [(action, probability), ...],
236
+ "top_relations": [(relation, probability), ...]
237
+ }
238
+ }
239
+ ```
240
+
241
+ ## Visualization & Debugging
242
+
243
+ There are two complementary visualization layers:
244
+
245
+ - **Post-process visualizations** (`include_visualizations=True` in the pipeline call) produces a high-level stitched video summarizing detections, actions, and relations over time.
246
+
247
+ - **Debug visualizations** (`debug_visualizations=True` in `VineConfig`) dumps videos of intermediate outputs (GroundingDINO detections, SAM2 masks, unary/binary predictions) for quick sanity checks.
248
+
249
+ If you plan to enable either option, ensure the relevant output directories exist before running the pipeline.
250
+
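+ A small pre-flight sketch (the directory name is illustrative and should match your `visualization_dir`):
+
+ ```python
+ from pathlib import Path
+
+ # Create the visualization output directory before running the pipeline.
+ Path("output/visualizations").mkdir(parents=True, exist_ok=True)
+ ```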
251
+ ## Segmentation Methods
252
+
253
+ ### Grounding DINO + SAM2 (Recommended)
254
+
255
+ Uses Grounding DINO for object detection based on text prompts, then SAM2 for precise segmentation.
256
+
257
+ Requirements:
258
+ - Grounding DINO model and weights
259
+ - SAM2 model and weights
260
+ - Properly configured paths to model checkpoints
261
+
262
+ ### SAM2 Only
263
+
264
+ Uses SAM2's automatic mask generation without text-based object detection.
265
+
266
+ Requirements:
267
+ - SAM2 model and weights
268
+
269
+ ## Model Architecture
270
+
271
+ VINE is built on top of CLIP and uses three separate CLIP models for different tasks:
272
+ - **Categorical Model**: For object classification
273
+ - **Unary Model**: For single-object action recognition
274
+ - **Binary Model**: For relationship detection between object pairs
275
+
276
+ Each model processes both visual and textual features to compute similarity scores and probability distributions.
277
+
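+ To make the scoring idea concrete, here is an illustrative sketch of the underlying image-text similarity pattern using the plain `transformers` CLIP classes; it is not VINE's exact code, just the computation each head builds on:
+
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import CLIPModel, CLIPProcessor
+
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ crop = Image.new("RGB", (224, 224))  # stand-in for a masked object crop
+ keywords = ["human", "dog", "frisbee"]
+
+ inputs = processor(text=keywords, images=crop, return_tensors="pt", padding=True)
+ with torch.no_grad():
+     logits = model(**inputs).logits_per_image  # shape (1, num_keywords)
+ probs = logits.softmax(dim=-1)[0]
+ print(dict(zip(keywords, probs.tolist())))
+ ```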
278
+ ## Pushing to HuggingFace Hub
279
+
280
+ ```python
281
+ from vine_hf import VineConfig, VineModel
282
+
283
+ # Create and configure your model
284
+ config = VineConfig()
285
+ model = VineModel(config)
286
+
287
+ # Load your pretrained weights
288
+ # model.load_state_dict(torch.load('path/to/your/weights.pth'))
289
+
290
+ # Register for auto classes
291
+ config.register_for_auto_class()
292
+ model.register_for_auto_class("AutoModel")
293
+
294
+ # Push to Hub
295
+ config.push_to_hub('your-username/vine-model')
296
+ model.push_to_hub('your-username/vine-model')
297
+ ```
298
+
299
+ ## Loading from HuggingFace Hub
300
+
301
+ ```python
302
+ from transformers import AutoModel, pipeline
303
+
304
+ # Load model
305
+ model = AutoModel.from_pretrained('your-username/vine-model', trust_remote_code=True)
306
+
307
+ # Or use with pipeline
308
+ vine_pipeline = pipeline(
309
+ 'vine-video-understanding',
310
+ model='your-username/vine-model',
311
+ trust_remote_code=True
312
+ )
313
+ ```
314
+
315
+ ## Examples
316
+
317
+ See `example_usage.py` for comprehensive examples including:
318
+ - Direct model usage
319
+ - Pipeline usage
320
+ - HuggingFace Hub integration
321
+ - Real video processing
322
+
323
+ ## Requirements
324
+
325
+ - Python 3.7+
326
+ - PyTorch 1.9+
327
+ - transformers 4.20+
328
+ - OpenCV
329
+ - PIL/Pillow
330
+ - NumPy
331
+
332
+ For segmentation:
333
+ - SAM2 (Facebook Research)
334
+ - Grounding DINO (IDEA Research)
335
+
336
+ ## Citation
337
+
338
+ If you use VINE in your research, please cite:
339
+
340
+ ```bibtex
341
+ @article{vine2024,
342
+ title={VINE: Video Understanding with Natural Language},
343
+ author={Your Authors},
344
+ journal={Your Journal},
345
+ year={2024}
346
+ }
347
+ ```
348
+
349
+ ## License
350
+
351
+ [Your License Here]
352
+
353
+ ## Contact
354
+
355
+ [Your Contact Information Here]
vine_hf/README_HF.md ADDED
@@ -0,0 +1,345 @@
1
+ # VINE: Video Understanding with Natural Language
2
+
3
+ [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-video--fm%2Fvine-blue)](https://huggingface.co/video-fm/vine)
4
+ [![GitHub](https://img.shields.io/badge/GitHub-LASER-green)](https://github.com/kevinxuez/LASER)
5
+
6
+ VINE is a video understanding model that processes videos along with categorical, unary, and binary keywords to return probability distributions over those keywords for detected objects and their relationships.
7
+
8
+ ## Quick Start
9
+
10
+ ```python
11
+ from transformers import AutoModel
12
+ from vine_hf import VineConfig, VineModel, VinePipeline
13
+
14
+ # Load VINE model from HuggingFace
15
+ model = AutoModel.from_pretrained('video-fm/vine', trust_remote_code=True)
16
+
17
+ # Create pipeline with your checkpoint paths
18
+ vine_pipeline = VinePipeline(
19
+ model=model,
20
+ tokenizer=None,
21
+ sam_config_path="/path/to/sam2_config.yaml",
22
+ sam_checkpoint_path="/path/to/sam2_checkpoint.pt",
23
+ gd_config_path="/path/to/grounding_dino_config.py",
24
+ gd_checkpoint_path="/path/to/grounding_dino_checkpoint.pth",
25
+ device="cuda",
26
+ trust_remote_code=True
27
+ )
28
+
29
+ # Process a video
30
+ results = vine_pipeline(
31
+ 'path/to/video.mp4',
32
+ categorical_keywords=['human', 'dog', 'frisbee'],
33
+ unary_keywords=['running', 'jumping'],
34
+ binary_keywords=['chasing', 'behind'],
35
+ return_top_k=3
36
+ )
37
+ ```
38
+
39
+ ## Installation
40
+
41
+ ### Option 1: Automated Setup (Recommended)
42
+
43
+ ```bash
44
+ # Download the setup script
45
+ wget https://raw.githubusercontent.com/kevinxuez/vine_hf/main/setup_vine_demo.sh
46
+
47
+ # Run the setup
48
+ bash setup_vine_demo.sh
49
+
50
+ # Activate environment
51
+ conda activate vine_demo
52
+ ```
53
+
54
+ ### Option 2: Manual Installation
55
+
56
+ ```bash
57
+ # 1. Create conda environment
58
+ conda create -n vine_demo python=3.10 -y
59
+ conda activate vine_demo
60
+
61
+ # 2. Install PyTorch with CUDA support
62
+ pip install torch==2.7.1 torchvision==0.22.1 --index-url https://download.pytorch.org/whl/cu126
63
+
64
+ # 3. Install core dependencies
65
+ pip install transformers huggingface-hub safetensors
66
+
67
+ # 4. Clone and install required repositories
68
+ git clone https://github.com/video-fm/video-sam2.git
69
+ git clone https://github.com/video-fm/GroundingDINO.git
70
+ git clone https://github.com/kevinxuez/LASER.git
71
+ git clone https://github.com/kevinxuez/vine_hf.git
72
+
73
+ # Install in editable mode
74
+ pip install -e ./video-sam2
75
+ pip install -e ./GroundingDINO
76
+ pip install -e ./LASER
77
+ pip install -e ./vine_hf
78
+
79
+ # Build GroundingDINO extensions
80
+ cd GroundingDINO && python setup.py build_ext --force --inplace && cd ..
81
+ ```
82
+
83
+ ## Required Checkpoints
84
+
85
+ VINE requires SAM2 and GroundingDINO checkpoints for segmentation. Download these separately:
86
+
87
+ ### SAM2 Checkpoint
88
+ ```bash
89
+ wget https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt
90
+ wget https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
91
+ ```
92
+
93
+ ### GroundingDINO Checkpoint
94
+ ```bash
95
+ wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
96
+ wget https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py
97
+ ```
98
+
99
+ ## Architecture
100
+
101
+ ```
102
+ video-fm/vine (HuggingFace Hub)
103
+ ├── VINE Model Weights (~1.8GB)
104
+ │ ├── Categorical CLIP model (fine-tuned)
105
+ │ ├── Unary CLIP model (fine-tuned)
106
+ │ └── Binary CLIP model (fine-tuned)
107
+ └── Architecture Files
108
+ ├── vine_config.py
109
+ ├── vine_model.py
110
+ ├── vine_pipeline.py
111
+ └── utilities
112
+
113
+ User Provides:
114
+ ├── Dependencies (via pip/conda)
115
+ │ ├── laser (video processing utilities)
116
+ │ ├── sam2 (segmentation)
117
+ │ └── groundingdino (object detection)
118
+ └── Checkpoints (downloaded separately)
119
+ ├── SAM2 model files
120
+ └── GroundingDINO model files
121
+ ```
122
+
123
+ ## Why This Architecture?
124
+
125
+ This separation of concerns provides several benefits:
126
+
127
+ 1. **Lightweight Distribution**: Only VINE-specific weights (~1.8GB) are on HuggingFace
128
+ 2. **Version Control**: Users can choose their preferred SAM2/GroundingDINO versions
129
+ 3. **Licensing**: Keeps different model licenses separate
130
+ 4. **Flexibility**: Easy to swap segmentation backends
131
+ 5. **Standard Practice**: Similar to models like LLaVA, BLIP-2, etc.
132
+
133
+ ## Full Usage Example
134
+
135
+ ```python
136
+ import os
137
+ from pathlib import Path
138
+ from transformers import AutoModel
139
+ from vine_hf import VinePipeline
140
+
141
+ # Set up paths
142
+ checkpoint_dir = Path("/path/to/checkpoints")
143
+ sam_config = checkpoint_dir / "sam2_hiera_t.yaml"
144
+ sam_checkpoint = checkpoint_dir / "sam2_hiera_tiny.pt"
145
+ gd_config = checkpoint_dir / "GroundingDINO_SwinT_OGC.py"
146
+ gd_checkpoint = checkpoint_dir / "groundingdino_swint_ogc.pth"
147
+
148
+ # Load VINE from HuggingFace
149
+ model = AutoModel.from_pretrained('video-fm/vine', trust_remote_code=True)
150
+
151
+ # Create pipeline
152
+ vine_pipeline = VinePipeline(
153
+ model=model,
154
+ tokenizer=None,
155
+ sam_config_path=str(sam_config),
156
+ sam_checkpoint_path=str(sam_checkpoint),
157
+ gd_config_path=str(gd_config),
158
+ gd_checkpoint_path=str(gd_checkpoint),
159
+ device="cuda:0",
160
+ trust_remote_code=True
161
+ )
162
+
163
+ # Process video
164
+ results = vine_pipeline(
165
+ "path/to/video.mp4",
166
+ categorical_keywords=['person', 'dog', 'ball'],
167
+ unary_keywords=['running', 'jumping', 'sitting'],
168
+ binary_keywords=['chasing', 'next to', 'holding'],
169
+ object_pairs=[(0, 1), (0, 2)], # person-dog, person-ball
170
+ return_top_k=5,
171
+ include_visualizations=True
172
+ )
173
+
174
+ # Access results
175
+ print(f"Detected {results['summary']['num_objects_detected']} objects")
176
+ print(f"Top categories: {results['summary']['top_categories']}")
177
+ print(f"Top actions: {results['summary']['top_actions']}")
178
+ print(f"Top relations: {results['summary']['top_relations']}")
179
+
180
+ # Access detailed predictions
181
+ for obj_id, predictions in results['categorical_predictions'].items():
182
+ print(f"\nObject {obj_id}:")
183
+ for prob, category in predictions:
184
+ print(f" {category}: {prob:.3f}")
185
+ ```
186
+
187
+ ## Output Format
188
+
189
+ ```python
190
+ {
191
+ "categorical_predictions": {
192
+ object_id: [(probability, category), ...]
193
+ },
194
+ "unary_predictions": {
195
+ (frame_id, object_id): [(probability, action), ...]
196
+ },
197
+ "binary_predictions": {
198
+ (frame_id, (obj1_id, obj2_id)): [(probability, relation), ...]
199
+ },
200
+ "confidence_scores": {
201
+ "categorical": float,
202
+ "unary": float,
203
+ "binary": float
204
+ },
205
+ "summary": {
206
+ "num_objects_detected": int,
207
+ "top_categories": [(category, probability), ...],
208
+ "top_actions": [(action, probability), ...],
209
+ "top_relations": [(relation, probability), ...]
210
+ },
211
+ "visualizations": { # if include_visualizations=True
212
+ "vine": {
213
+ "all": {"frames": [...], "video_path": "..."},
214
+ ...
215
+ }
216
+ }
217
+ }
218
+ ```
219
+
220
+ ## Configuration Options
221
+
222
+ ```python
223
+ from vine_hf import VineConfig
224
+
225
+ config = VineConfig(
226
+ model_name="openai/clip-vit-base-patch32", # CLIP backbone
227
+ segmentation_method="grounding_dino_sam2", # or "sam2"
228
+ box_threshold=0.35, # GroundingDINO threshold
229
+ text_threshold=0.25, # GroundingDINO threshold
230
+ target_fps=5, # Video sampling rate
231
+ visualize=True, # Enable visualizations
232
+ visualization_dir="outputs/", # Output directory
233
+ debug_visualizations=False, # Debug mode
234
+ device="cuda:0" # Device
235
+ )
236
+ ```
237
+
238
+ ## Deployment Examples
239
+
240
+ ### Local Script
241
+ ```python
242
+ # test_vine.py
243
+ from transformers import AutoModel
244
+ from vine_hf import VinePipeline
245
+
246
+ model = AutoModel.from_pretrained('video-fm/vine', trust_remote_code=True)
247
+ pipeline = VinePipeline(model=model, ...)
248
+ results = pipeline("video.mp4", ...)
249
+ ```
250
+
251
+ ### HuggingFace Spaces
252
+ ```python
253
+ # app.py for Gradio Space
254
+ import gradio as gr
255
+ from transformers import AutoModel
256
+ from vine_hf import VinePipeline
257
+
258
+ model = AutoModel.from_pretrained('video-fm/vine', trust_remote_code=True)
259
+ # ... set up pipeline and Gradio interface
260
+ ```
261
+
262
+ ### API Server
263
+ ```python
264
+ # FastAPI server
265
+ from fastapi import FastAPI
266
+ from transformers import AutoModel
267
+ from vine_hf import VinePipeline
268
+
269
+ app = FastAPI()
270
+ model = AutoModel.from_pretrained('video-fm/vine', trust_remote_code=True)
271
+ pipeline = VinePipeline(model=model, ...)
272
+
273
+ @app.post("/process")
274
+ async def process_video(video_path: str):
275
+ return pipeline(video_path, ...)
276
+ ```
277
+
278
+ ## Troubleshooting
279
+
280
+ ### Import Errors
281
+ ```bash
282
+ # Make sure all dependencies are installed
283
+ pip list | grep -E "laser|sam2|groundingdino"
284
+
285
+ # Reinstall if needed
286
+ pip install -e ./LASER
287
+ pip install -e ./video-sam2
288
+ pip install -e ./GroundingDINO
289
+ ```
290
+
291
+ ### CUDA Errors
292
+ ```python
293
+ # Check CUDA availability
294
+ import torch
295
+ print(torch.cuda.is_available())
296
+ print(torch.version.cuda)
297
+
298
+ # Use CPU if needed
299
+ pipeline = VinePipeline(model=model, device="cpu", ...)
300
+ ```
301
+
302
+ ### Checkpoint Not Found
303
+ ```bash
304
+ # Verify checkpoint paths
305
+ ls -lh /path/to/sam2_hiera_tiny.pt
306
+ ls -lh /path/to/groundingdino_swint_ogc.pth
307
+ ```
308
+
309
+ ## System Requirements
310
+
311
+ - **Python**: 3.10+
312
+ - **CUDA**: 11.8+ (for GPU)
313
+ - **GPU**: 8GB+ VRAM recommended (T4, V100, A100, etc.)
314
+ - **RAM**: 16GB+ recommended
315
+ - **Storage**: ~3GB for checkpoints
316
+
317
+ ## Citation
318
+
319
+ ```bibtex
320
+ @article{laser2024,
321
+ title={LASER: Language-guided Object Grounding and Relation Understanding in Videos},
322
+ author={Your Authors},
323
+ journal={Your Conference/Journal},
324
+ year={2024}
325
+ }
326
+ ```
327
+
328
+ ## License
329
+
330
+ This model and code are released under the MIT License. Note that SAM2 and GroundingDINO have their own respective licenses.
331
+
332
+ ## Links
333
+
334
+ - **Model**: https://huggingface.co/video-fm/vine
335
+ - **Code**: https://github.com/kevinxuez/LASER
336
+ - **vine_hf Package**: https://github.com/kevinxuez/vine_hf
337
+ - **SAM2**: https://github.com/facebookresearch/sam2
338
+ - **GroundingDINO**: https://github.com/IDEA-Research/GroundingDINO
339
+
340
+ ## Support
341
+
342
+ For issues or questions:
343
+ - **Model/Architecture**: [HuggingFace Discussions](https://huggingface.co/video-fm/vine/discussions)
344
+ - **LASER Framework**: [GitHub Issues](https://github.com/kevinxuez/LASER/issues)
345
+ - **vine_hf Package**: [GitHub Issues](https://github.com/kevinxuez/vine_hf/issues)
vine_hf/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ """
2
+ VINE HuggingFace Interface
3
+
4
+ VINE (Video Understanding with Natural Language) is a model that processes videos
5
+ along with categorical, unary, and binary keywords to return probability
6
+ distributions over those keywords for detected objects and their relationships.
7
+
8
+ This package provides a HuggingFace-compatible interface for the VINE model,
9
+ including configuration, model, and pipeline classes.
10
+ """
11
+
12
+ from .vine_config import VineConfig
13
+ from .vine_model import VineModel
14
+ from .vine_pipeline import VinePipeline
15
+
16
+ __version__ = "1.0.0"
17
+ __author__ = "LASER Team"
18
+
19
+ __all__ = [
20
+ "VineConfig",
21
+ "VineModel",
22
+ "VinePipeline"
23
+ ]
vine_hf/convert_inference.py ADDED
@@ -0,0 +1,288 @@
1
+ """
2
+ Script to convert existing inference.py workflow to use VINE HuggingFace interface
3
+
4
+ This script demonstrates how to migrate from the original inference.py approach
5
+ to the new HuggingFace-compatible interface.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import torch
11
+ import numpy as np
12
+ from typing import Dict, List, Tuple, Any
13
+
14
+ # Add paths for imports
15
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16
+
17
+ from vine_hf import VineConfig, VineModel, VinePipeline
18
+ from laser.loading import load_video
19
+
20
+
21
+ def load_pretrained_vine_model(model_dir: str, model_name: str, epoch: int = 0) -> VineModel:
22
+ """
23
+ Load a pretrained VINE model from the original format into HuggingFace format.
24
+
25
+ Args:
26
+ model_dir: Directory containing the model
27
+ model_name: Name of the model file (without .{epoch}.model extension)
28
+ epoch: Epoch number to load
29
+
30
+ Returns:
31
+ VineModel instance with loaded weights
32
+ """
33
+ print(f"Loading pretrained VINE model from {model_dir}")
34
+
35
+ # Create configuration (adjust parameters as needed)
36
+ # We expect local ensemble weights in `model_dir`, so configure
37
+ # VineConfig to load from local directory/filename.
38
+ model_file = f"{model_name}.{epoch}.model"
39
+ config = VineConfig(
40
+ model_name="openai/clip-vit-base-patch32",
41
+ segmentation_method="grounding_dino_sam2",
42
+ target_fps=1,
43
+ box_threshold=0.35,
44
+ text_threshold=0.25,
45
+ use_hf_repo=False,
46
+ local_dir=model_dir,
47
+ local_filename=model_file,
48
+ )
49
+
50
+ # Initialize model (VineModel will consult the config when loading)
51
+ vine_model = VineModel(config)
52
+
53
+ # Load original weights
54
+ model_file = f"{model_name}.{epoch}.model"
55
+ model_path = os.path.join(model_dir, model_file)
56
+
57
+ if os.path.exists(model_path):
58
+ print(f"Loading weights from: {model_path}")
59
+ try:
60
+ # Add safe globals for PyTorch 2.6+
61
+ import torch.serialization
62
+ from laser.models.llava_clip_model_v3 import PredicateModel
63
+ torch.serialization.add_safe_globals([PredicateModel])
64
+
65
+ # Load the original model
66
+ original_model = torch.load(model_path, map_location='cpu', weights_only=False)
67
+
68
+ # Transfer weights to HuggingFace model
69
+ # This assumes the original model has the same structure
70
+ # You may need to adjust this based on your specific model structure
71
+
72
+ if hasattr(original_model, 'clip_cate_model'):
73
+ vine_model.clip_cate_model.load_state_dict(original_model.clip_cate_model.state_dict())
74
+ if hasattr(original_model, 'clip_unary_model'):
75
+ vine_model.clip_unary_model.load_state_dict(original_model.clip_unary_model.state_dict())
76
+ if hasattr(original_model, 'clip_binary_model'):
77
+ vine_model.clip_binary_model.load_state_dict(original_model.clip_binary_model.state_dict())
78
+ if hasattr(original_model, 'clip_tokenizer'):
79
+ vine_model.clip_tokenizer = original_model.clip_tokenizer
80
+ if hasattr(original_model, 'clip_processor'):
81
+ vine_model.clip_processor = original_model.clip_processor
82
+
83
+ print("✓ Weights transferred successfully")
84
+
85
+ except Exception as e:
86
+ print(f"✗ Error loading weights: {e}")
87
+ print("You may need to adjust the weight loading logic for your specific model")
88
+
89
+ else:
90
+ print(f"✗ Model file not found: {model_path}")
91
+
92
+ return vine_model
93
+
94
+
95
+ def convert_inference_workflow():
96
+ """
97
+ Convert the original inference.py workflow to use HuggingFace interface.
98
+
99
+ This function demonstrates how to replicate the original inference workflow
100
+ using the new HuggingFace-compatible components.
101
+ """
102
+ print("=== Converting Inference Workflow ===")
103
+
104
+ # Original parameters from inference.py
105
+ video_id = 'v1'
106
+ target_fps = 1
107
+ classes = ['human', 'dog', 'frisbee']
108
+ unary_keywords = ['running', 'jumping', 'sitting', 'standing']
109
+ binary_keywords = ['behind', 'bite', 'front', 'jump over', 'right', 'left']
110
+
111
+ # Paths (adjust these to match your setup)
112
+ demo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../demo"))
113
+ video_dir = os.path.join(demo_dir, "videos")
114
+ video_path = os.path.join(video_dir, f"{video_id}.mp4")
115
+
116
+ # Model paths (adjust these to match your setup)
117
+ data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
118
+ model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
119
+ model_name = "ensemble-2025-02-10-14-57-22"
120
+
121
+ # Segmentation model paths (adjust these to your actual paths)
122
+ sam_config_path = "/path/to/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml"
123
+ sam_checkpoint_path = "/path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
124
+ gd_config_path = "/path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py"
125
+ gd_checkpoint_path = "/path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth"
126
+
127
+ print(f"Video path: {video_path}")
128
+ print(f"Model dir: {model_dir}")
129
+ print(f"SAM2 config: {sam_config_path}")
130
+ print(f"GroundingDINO config: {gd_config_path}")
131
+
132
+ # Check if video exists
133
+ if not os.path.exists(video_path):
134
+ print(f"✗ Video not found: {video_path}")
135
+ print("Please adjust the video path or use your own video file")
136
+ return
137
+
138
+ # 1. Load video (same as original)
139
+ print(f"Loading video: {video_id}")
140
+ video_tensor = load_video(video_path, target_fps=target_fps)
141
+ print(f"Video shape: {video_tensor.shape}")
142
+
143
+ # 2. Load VINE model with HuggingFace interface
144
+ print("Loading VINE model...")
145
+ if os.path.exists(model_dir):
146
+ vine_model = load_pretrained_vine_model(model_dir, model_name, epoch=0)
147
+ else:
148
+ print(f"Model directory not found: {model_dir}")
149
+ print("Creating new model with random weights for demonstration")
150
+ config = VineConfig()
151
+ vine_model = VineModel(config)
152
+
153
+ # 3. Create pipeline for easier use
154
+ print("Creating VINE pipeline...")
155
+ from transformers.pipelines import PIPELINE_REGISTRY
156
+
157
+ # Register pipeline if not already registered
158
+ try:
159
+ PIPELINE_REGISTRY.register_pipeline(
160
+ "vine-video-understanding",
161
+ pipeline_class=VinePipeline,
162
+ pt_model=VineModel,
163
+ type="multimodal",
164
+ )
165
+ except Exception:
166
+ pass # Already registered
167
+
168
+ # Create pipeline instance with segmentation model paths
169
+ vine_pipeline = VinePipeline(
170
+ model=vine_model,
171
+ tokenizer=None,
172
+ # SAM2 configuration
173
+ sam_config_path=sam_config_path,
174
+ sam_checkpoint_path=sam_checkpoint_path,
175
+ # GroundingDINO configuration
176
+ gd_config_path=gd_config_path,
177
+ gd_checkpoint_path=gd_checkpoint_path
178
+ )
179
+
180
+ # 4. Process video with new interface
181
+ print("Processing video with VINE HuggingFace interface...")
182
+
183
+ try:
184
+ # Use the pipeline to process the video
185
+ results = vine_pipeline(
186
+ video_path,
187
+ categorical_keywords=classes,
188
+ unary_keywords=unary_keywords,
189
+ binary_keywords=binary_keywords,
190
+ object_pairs=[(1, 2), (2, 3)], # Example object pairs
191
+ segmentation_method='grounding_dino_sam2',
192
+ target_fps=target_fps,
193
+ return_top_k=3,
194
+ include_visualizations=False
195
+ )
196
+
197
+ # 5. Display results (similar to original format)
198
+ print("\n=== VINE Results (HuggingFace Interface) ===")
199
+
200
+ # Categorical predictions
201
+ print("\nCategorical Predictions:")
202
+ for obj_id, predictions in results['categorical_predictions'].items():
203
+ print(f" Object {obj_id}:")
204
+ for prob, category in predictions:
205
+ print(f" {prob:.3f}: {category}")
206
+
207
+ # Unary predictions
208
+ print("\nUnary Predictions:")
209
+ for (frame_id, obj_id), predictions in results['unary_predictions'].items():
210
+ print(f" Frame {frame_id}, Object {obj_id}:")
211
+ for prob, action in predictions:
212
+ print(f" {prob:.3f}: {action}")
213
+
214
+ # Binary predictions
215
+ print("\nBinary Predictions:")
216
+ for (frame_id, obj_pair), predictions in results['binary_predictions'].items():
217
+ print(f" Frame {frame_id}, Objects {obj_pair}:")
218
+ for prob, relation in predictions:
219
+ print(f" {prob:.3f}: {relation}")
220
+
221
+ # Summary
222
+ print(f"\nSummary:")
223
+ print(f" Objects detected: {results['summary']['num_objects_detected']}")
224
+ print(f" Top categories: {results['summary']['top_categories']}")
225
+ print(f" Top actions: {results['summary']['top_actions']}")
226
+ print(f" Top relations: {results['summary']['top_relations']}")
227
+
228
+ print("\n✓ Successfully processed video with VINE HuggingFace interface!")
229
+
230
+ except Exception as e:
231
+ print(f"✗ Error processing video: {e}")
232
+ print("This may be due to missing segmentation models or other dependencies")
233
+ print("The interface is set up correctly, but full functionality requires:")
234
+ print(" 1. Properly installed Grounding DINO and SAM2")
235
+ print(" 2. Correct model weights")
236
+ print(" 3. Proper configuration paths")
237
+
238
+
239
+ def compare_interfaces():
240
+ """
241
+ Compare the original inference.py approach with the new HuggingFace interface.
242
+ """
243
+ print("\n=== Interface Comparison ===")
244
+
245
+ print("\nOriginal inference.py approach:")
246
+ print("✓ Direct access to model internals")
247
+ print("✓ Full control over segmentation pipeline")
248
+ print("✗ Complex setup and configuration")
249
+ print("✗ Not compatible with HuggingFace ecosystem")
250
+ print("✗ Requires manual handling of all components")
251
+
252
+ print("\nNew HuggingFace interface:")
253
+ print("✓ Easy to use pipeline interface")
254
+ print("✓ Compatible with HuggingFace Hub")
255
+ print("✓ Standardized configuration")
256
+ print("✓ Automatic handling of preprocessing/postprocessing")
257
+ print("✓ Easy sharing and distribution")
258
+ print("✓ Configurable segmentation model paths")
259
+ print("✗ Slightly less direct control (can still access model directly)")
260
+
261
+ print("\nMigration benefits:")
262
+ print("• Share your model easily on HuggingFace Hub")
263
+ print("• Users can load your model with a single line")
264
+ print("• Standardized interface for video understanding")
265
+ print("• Better integration with other HuggingFace tools")
266
+ print("• Simplified deployment and inference")
267
+ print("• Flexible segmentation model configuration")
268
+
269
+
270
+ if __name__ == "__main__":
271
+ print("VINE HuggingFace Interface Conversion")
272
+ print("=" * 50)
273
+
274
+ # Run conversion demonstration
275
+ convert_inference_workflow()
276
+
277
+ # Show comparison
278
+ compare_interfaces()
279
+
280
+ print("\n" + "=" * 50)
281
+ print("Next steps:")
282
+ print("1. Install SAM2 and GroundingDINO dependencies")
283
+ print("2. Download the required model checkpoints")
284
+ print("3. Update the paths in this script to point to your models")
285
+ print("4. Test the interface with your specific model weights")
286
+ print("5. Adjust configuration parameters as needed")
287
+ print("6. Push your model to HuggingFace Hub using push_to_hub.py")
288
+ print("7. Share with the community!")
vine_hf/example_ensemble_weights.py ADDED
@@ -0,0 +1,333 @@
1
+ """
2
+ Example demonstrating how to load and use VINE ensemble weights
3
+
4
+ This script shows the correct way to load your pretrained VINE ensemble weights
5
+ and use them with the HuggingFace interface, based on the actual inference.py workflow.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import torch
11
+ import numpy as np
12
+ from transformers.pipelines import PIPELINE_REGISTRY
13
+
14
+ #os.environ["OPENAI_API_KEY"]="dummy-key" # Set your OpenAI API key here or via environment variable
15
+
16
+ # Add the parent directory to the path to import vine_hf
17
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18
+
19
+ from vine_hf import VineConfig, VineModel, VinePipeline
20
+ from laser.loading import load_video
21
+
22
+
23
+ def example_load_ensemble_weights():
24
+ """Example of loading ensemble weights correctly."""
25
+ print("=== Loading Ensemble VINE Weights ===")
26
+
27
+ # Path to your ensemble model (adjust this to your actual path)
28
+ data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
29
+ model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
30
+
31
+ print(f"Looking for ensemble weights in: {model_dir}")
32
+
33
+ if os.path.exists(model_dir):
34
+ print("✓ Model directory found")
35
+
36
+ # List available model files
37
+ model_files = [f for f in os.listdir(model_dir) if f.endswith('.model')]
38
+ print(f"Available model files: {model_files}")
39
+
40
+ if model_files:
41
+ # Create configuration with ensemble path (local directory with .model files)
42
+ config = VineConfig(
43
+ segmentation_method="grounding_dino_sam2",
44
+ use_hf_repo=False,
45
+ local_dir=model_dir,
46
+ local_filename=None,
47
+ )
48
+
49
+ print("Creating VINE model with ensemble weights...")
50
+ vine_model = VineModel(config)
51
+
52
+ print("✓ VINE model created with ensemble weights!")
53
+ return vine_model
54
+ else:
55
+ print("✗ No .model files found in directory")
56
+ return None
57
+ else:
58
+ print(f"✗ Model directory not found: {model_dir}")
59
+ print("Please adjust the path to point to your ensemble weights")
60
+ return None
61
+
62
+
63
+ def example_direct_ensemble_loading():
64
+ """Example of loading ensemble weights using from_pretrained_vine."""
65
+ print("\n=== Direct Ensemble Loading ===")
66
+
67
+ # Path to specific ensemble file
68
+ data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
69
+ model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
70
+
71
+ if os.path.exists(model_dir):
72
+ try:
73
+ # Use the class method for direct loading
74
+ vine_model = VineModel.from_pretrained_vine(
75
+ model_path=model_dir,
76
+ epoch=0 # Load epoch 0
77
+ )
78
+
79
+ print("✓ Model loaded using from_pretrained_vine!")
80
+ return vine_model
81
+
82
+ except Exception as e:
83
+ print(f"✗ Error loading with from_pretrained_vine: {e}")
84
+ return None
85
+ else:
86
+ print(f"✗ Model directory not found: {model_dir}")
87
+ return None
88
+
89
+
90
+ def example_compare_original_vs_hf():
91
+ """Compare the original inference.py approach with HuggingFace interface."""
92
+ print("\n=== Comparing Original vs HuggingFace Interface ===")
93
+
94
+ data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
95
+ model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
96
+ model_name = "ensemble-2025-02-10-14-57-22"
97
+ epoch = 0
98
+
99
+ if not os.path.exists(model_dir):
100
+ print(f"Model directory not found: {model_dir}")
101
+ return
102
+
103
+ print("Original approach (from inference.py):")
104
+ print("```python")
105
+ print("def load_model(model_dir, model_name, epoch, device):")
106
+ print(" model_name = model_name + f'.{epoch}.model'")
107
+ print(" predicate_model = torch.load(os.path.join(model_dir, model_name), map_location=device, weights_only=False)")
108
+ print(" return predicate_model")
109
+ print("")
110
+ print("predicate_model = load_model(model_dir, model_name, epoch, device)")
111
+ print("```")
112
+
113
+ print("\nNew HuggingFace approach:")
114
+ print("```python")
115
+ print("config = VineConfig(pretrained_vine_path=model_dir)")
116
+ print("vine_model = VineModel(config)")
117
+ print("# or")
118
+ print("vine_model = VineModel.from_pretrained_vine(model_dir, epoch=0)")
119
+ print("```")
120
+
121
+ # Try to load with both approaches if possible
122
+ try:
123
+ # Original approach
124
+ def load_model(model_dir, model_name, epoch, device):
125
+ model_name = model_name + f'.{epoch}.model'
126
+ model_path = os.path.join(model_dir, model_name)
127
+ if os.path.exists(model_path):
128
+ return torch.load(model_path, map_location=device, weights_only=False)
129
+ else:
130
+ print(f"Model file not found: {model_path}")
131
+ return None
132
+
133
+ device = "cuda" if torch.cuda.is_available() else "cpu"
134
+ original_model = load_model(model_dir, model_name, epoch, device)
135
+
136
+ if original_model:
137
+ print(f"✓ Original model loaded: {type(original_model)}")
138
+ print(f" Has clip_cate_model: {hasattr(original_model, 'clip_cate_model')}")
139
+ print(f" Has clip_unary_model: {hasattr(original_model, 'clip_unary_model')}")
140
+ print(f" Has clip_binary_model: {hasattr(original_model, 'clip_binary_model')}")
141
+
142
+ # HuggingFace approach
143
+ vine_model = VineModel.from_pretrained_vine(model_dir, epoch=epoch)
144
+
145
+ if vine_model:
146
+ print(f"✓ HuggingFace model loaded: {type(vine_model)}")
147
+ print(f" Has clip_cate_model: {hasattr(vine_model, 'clip_cate_model')}")
148
+ print(f" Has clip_unary_model: {hasattr(vine_model, 'clip_unary_model')}")
149
+ print(f" Has clip_binary_model: {hasattr(vine_model, 'clip_binary_model')}")
150
+
151
+ print("\n✓ Both approaches work! HuggingFace interface successfully loads ensemble weights.")
152
+
153
+ except Exception as e:
154
+ print(f"Error in comparison: {e}")
155
+
156
+
157
+ def example_ensemble_with_pipeline():
158
+ """Example using ensemble weights with the pipeline."""
159
+ print("\n=== Using Ensemble Weights with Pipeline ===")
160
+
161
+ data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
162
+ model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
163
+
164
+ if not os.path.exists(model_dir):
165
+ print(f"Model directory not found: {model_dir}")
166
+ return
167
+
168
+ # Register pipeline
169
+ PIPELINE_REGISTRY.register_pipeline(
170
+ "vine-video-understanding",
171
+ pipeline_class=VinePipeline,
172
+ pt_model=VineModel,
173
+ type="multimodal",
174
+ )
175
+
176
+ # Create model with ensemble weights (local directory)
177
+ config = VineConfig(
178
+ segmentation_method="grounding_dino_sam2",
179
+ use_hf_repo=False,
180
+ local_dir=model_dir,
181
+ local_filename=None,
182
+ )
183
+
184
+ vine_model = VineModel(config)
185
+ # Create pipeline with segmentation model paths
186
+ vine_pipeline = VinePipeline(
187
+ model=vine_model,
188
+ tokenizer=None,
189
+ # SAM2 configuration
190
+ sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
191
+ sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
192
+ # GroundingDINO configuration
193
+ gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
194
+ gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
195
+ device="cuda" if torch.cuda.is_available() else "cpu",
196
+ )
197
+
198
+ print("✓ Pipeline created with ensemble VINE weights")
199
+
200
+ # Check for demo video
201
+ demo_video = os.path.join(os.path.dirname(__file__), "../demo/videos/v1.mp4")
202
+
203
+ if os.path.exists(demo_video):
204
+ print(f"Found demo video: {demo_video}")
205
+
206
+ # Use the same keywords as in the original inference.py
207
+ categorical_keywords = ['human', 'dog', 'frisbee']
208
+ unary_keywords = ['running', 'jumping', 'catching', 'throwing']
209
+ binary_keywords = ['behind', 'bite', 'front', 'jump over', 'right', 'left']
210
+
211
+ print("Example pipeline usage:")
212
+ print("```python")
213
+ print("results = vine_pipeline(")
214
+ print(f" '{demo_video}',")
215
+ print(f" categorical_keywords={categorical_keywords},")
216
+ print(f" unary_keywords={unary_keywords},")
217
+ print(f" binary_keywords={binary_keywords},")
218
+ print(" segmentation_method='grounding_dino_sam2'")
219
+ print(")")
220
+ print("```")
221
+
222
+ # Uncomment to actually run (requires segmentation models)
223
+ # try:
224
+ # results = vine_pipeline(
225
+ # demo_video,
226
+ # categorical_keywords=categorical_keywords,
227
+ # unary_keywords=unary_keywords,
228
+ # binary_keywords=binary_keywords,
229
+ # segmentation_method='grounding_dino_sam2'
230
+ # )
231
+ # print("Results:", results['summary'])
232
+ # except Exception as e:
233
+ # print(f"Pipeline execution failed: {e}")
234
+ # print("This is expected if segmentation models are not set up")
235
+
236
+ return vine_pipeline
237
+
238
+
239
+
240
+ def demonstrate_weight_transfer():
241
+ """Demonstrate how weights are transferred from ensemble to HuggingFace format."""
242
+ print("\n=== Weight Transfer Demonstration ===")
243
+
244
+ print("The ensemble model structure (PredicateModel):")
245
+ print("- clip_cate_model: CLIP model for categorical classification")
246
+ print("- clip_unary_model: CLIP model for unary predicates")
247
+ print("- clip_binary_model: CLIP model for binary relations")
248
+ print("- clip_tokenizer: Tokenizer for text processing")
249
+ print("- clip_processor: Processor for image processing")
250
+
251
+ print("\nWeight transfer process:")
252
+ print("1. Load ensemble model with torch.load()")
253
+ print("2. Initialize base CLIP models in HuggingFace format")
254
+ print("3. Transfer state_dict from ensemble to HuggingFace models:")
255
+ print(" - ensemble.clip_cate_model → hf.clip_cate_model")
256
+ print(" - ensemble.clip_unary_model → hf.clip_unary_model")
257
+ print(" - ensemble.clip_binary_model → hf.clip_binary_model")
258
+ print("4. Transfer tokenizer and processor")
259
+
260
+ print("\nThis preserves all your fine-tuned weights while making them HuggingFace compatible!")
261
+
262
+
263
+ def troubleshooting_guide():
264
+ """Provide troubleshooting guide for common issues."""
265
+ print("\n=== Troubleshooting Guide ===")
266
+
267
+ print("Common Issues:")
268
+ print("1. 'No model file found for epoch X'")
269
+ print(" → Check that .model files exist in the directory")
270
+ print(" → Verify the epoch number is correct")
271
+ print(" → List files: ls /path/to/model/dir/*.model")
272
+
273
+ print("\n2. 'Error loading VINE weights'")
274
+ print(" → Check file permissions")
275
+ print(" → Verify the model file is not corrupted")
276
+ print(" → Try loading with torch.load() directly first")
277
+
278
+ print("\n3. 'CLIP model mismatch'")
279
+ print(" → Ensure config.model_name matches the base model used in training")
280
+
281
+ print("\n4. 'Device mismatch errors'")
282
+ print(" → Models are loaded to CPU first, then moved to device")
283
+ print(" → Check CUDA availability with torch.cuda.is_available()")
284
+
285
+ print("\nDebugging steps:")
286
+ print("1. Test loading ensemble model directly:")
287
+ print(" model = torch.load('path/to/model.0.model', map_location='cpu')")
288
+ print("2. Check model attributes:")
289
+ print(" print(dir(model))")
290
+ print("3. Verify state_dict keys:")
291
+ print(" print(model.clip_cate_model.state_dict().keys())")
292
+
293
+
294
+ if __name__ == "__main__":
295
+ print("VINE Ensemble Weights Loading Examples")
296
+ print("=" * 50)
297
+
298
+ # Test ensemble weight loading
299
+ try:
300
+ model1 = example_load_ensemble_weights()
301
+ except Exception as e:
302
+ print(f"Ensemble loading example failed: {e}")
303
+
304
+ try:
305
+ model2 = example_direct_ensemble_loading()
306
+ except Exception as e:
307
+ print(f"Direct loading example failed: {e}")
308
+
309
+ # Compare approaches
310
+ try:
311
+ example_compare_original_vs_hf()
312
+ except Exception as e:
313
+ print(f"Comparison example failed: {e}")
314
+
315
+ # Test pipeline with ensemble weights
316
+ try:
317
+ pipeline = example_ensemble_with_pipeline()
318
+ except Exception as e:
319
+ print(f"Pipeline example failed: {e}")
320
+
321
+ # Educational content
322
+ demonstrate_weight_transfer()
323
+ troubleshooting_guide()
324
+
325
+ print("\n" + "=" * 50)
326
+ print("Key Points:")
327
+ print("1. AutoModel.from_pretrained() won't work with .pt ensemble weights")
328
+ print("2. Use torch.load() to load the ensemble, then transfer weights")
329
+ print("3. The HuggingFace interface preserves your fine-tuned weights")
330
+ print("4. Specify pretrained_vine_path in VineConfig to auto-load weights")
331
+ print("5. Use VineModel.from_pretrained_vine() for direct loading")
332
+
333
+
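The "Key Points" above compress the same weight-transfer logic that convert_inference.py applies. A minimal sketch of that transfer, assuming the ensemble checkpoint is a pickled PredicateModel exposing clip_cate_model, clip_unary_model, clip_binary_model plus its tokenizer and processor (the checkpoint path is a placeholder):

import torch
from vine_hf import VineConfig, VineModel

# 1. Load the ensemble checkpoint (a full pickled model, not a plain state_dict).
ensemble = torch.load("/path/to/ensemble-name.0.model", map_location="cpu", weights_only=False)

# 2. Initialize the HuggingFace-format model with base CLIP weights.
hf_model = VineModel(VineConfig())

# 3. Copy the fine-tuned sub-model weights across.
hf_model.clip_cate_model.load_state_dict(ensemble.clip_cate_model.state_dict())
hf_model.clip_unary_model.load_state_dict(ensemble.clip_unary_model.state_dict())
hf_model.clip_binary_model.load_state_dict(ensemble.clip_binary_model.state_dict())

# 4. Reuse the tokenizer and processor from the ensemble.
hf_model.clip_tokenizer = ensemble.clip_tokenizer
hf_model.clip_processor = ensemble.clip_processor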
vine_hf/example_sam2_masks.py ADDED
@@ -0,0 +1,331 @@
1
+ """
2
+ Example demonstrating SAM2 mask generation in VINE HuggingFace interface
3
+
4
+ This script shows how to use both SAM2-only and Grounding DINO + SAM2
5
+ segmentation methods with the VINE model.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import torch
11
+ import numpy as np
12
+ from transformers.pipelines import PIPELINE_REGISTRY
13
+
14
+ # Add the parent directory to the path to import vine_hf
15
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16
+
17
+
18
+ # Either uncomment the line below or set the OPENAI_API_KEY environment variable yourself; it isn't required to run this example.
19
+ #os.environ['OPENAI_API_KEY'] = 'dummy-key'
20
+
21
+ from vine_hf import VineConfig, VineModel, VinePipeline
22
+ from laser.loading import load_video
23
+
24
+
25
+ def example_sam2_only_segmentation():
26
+ """Example using SAM2 automatic mask generation only."""
27
+ print("=== SAM2-Only Segmentation Example ===")
28
+
29
+ # Create configuration for SAM2-only
30
+ config = VineConfig(
31
+ use_hf_repo=True,
32
+ model_repo="video-fm/vine_v0",
33
+ segmentation_method="sam2", # Use SAM2 only
34
+ target_fps=1,
35
+ debug_visualizations=True,
36
+ )
37
+
38
+ # Register pipeline
39
+ PIPELINE_REGISTRY.register_pipeline(
40
+ "vine-video-understanding",
41
+ pipeline_class=VinePipeline,
42
+ pt_model=VineModel,
43
+ type="multimodal",
44
+ )
45
+
46
+ # Create model and pipeline with SAM2 paths
47
+ vine_model = VineModel(config)
48
+ vine_pipeline = VinePipeline(
49
+ model=vine_model,
50
+ tokenizer=None,
51
+ sam_config_path="path/to/your/sam2/sam_config.yaml",
52
+ sam_checkpoint_path="path/to/your/sam2/sam_checkpoint.pth",
53
+ gd_config_path="path/to/your/groundingdino/config.py",
54
+ gd_checkpoint_path="path/to/your/groundingdino/checkpoint.pth",
55
+ )
56
+
57
+ # Check for demo video
58
+ demo_video = os.path.join(os.path.dirname(__file__), "../demo/videos/output.mp4")
59
+
60
+ if os.path.exists(demo_video):
61
+ print(f"Processing video: {demo_video}")
62
+
63
+ # Define keywords (SAM2 will find all objects, then classify them)
64
+ categorical_keywords = ['human', 'dog', 'frisbee', 'object', 'person', 'animal']
65
+ unary_keywords = ['running', 'jumping', 'sitting', 'standing', 'moving', 'static']
66
+ binary_keywords = ['behind', 'in front of', 'next to', 'chasing', 'following']
67
+ object_pairs = [(0,1), (0, 2), (1, 2), (1, 3), (2, 3), (0,4)]
68
+
69
+
70
+ print("Using SAM2 automatic mask generation...")
71
+ print("This will find all objects in the video automatically")
72
+
73
+ try:
74
+ # Process with SAM2 only
75
+ results = vine_pipeline(
76
+ demo_video,
77
+ categorical_keywords=categorical_keywords,
78
+ unary_keywords=unary_keywords,
79
+ binary_keywords=binary_keywords,
80
+ object_pairs=object_pairs,
81
+ segmentation_method="sam2",
82
+ return_top_k=3,
83
+ debug_visualizations=True,
84
+ debug_visualization_path=os.path.join(os.getcwd(), "sam2_debug_masks.png"),
85
+ )
86
+
87
+ print("\n✓ SAM2 segmentation completed!")
88
+ print("Results summary:")
89
+ print(f" Objects detected: {results['summary']['num_objects_detected']}")
90
+ print(f" Top categories: {results['summary']['top_categories']}")
91
+ print(f" Top actions: {results['summary']['top_actions']}")
92
+
93
+ return results
94
+
95
+ except Exception as e:
96
+ print(f"SAM2 segmentation failed: {e}")
97
+ print("Make sure SAM2 models are properly installed")
98
+ return None
99
+ else:
100
+ print(f"Demo video not found: {demo_video}")
101
+ return None
102
+
103
+ def example_grounding_dino_sam2_segmentation():
104
+ """Example using Grounding DINO + SAM2 text-guided segmentation."""
105
+ print("\n=== Grounding DINO + SAM2 Segmentation Example ===")
106
+
107
+ # Create configuration for Grounding DINO + SAM2
108
+ config = VineConfig(
109
+ use_hf_repo=True,
110
+ model_repo="video-fm/vine_v0",
111
+ segmentation_method="grounding_dino_sam2", # Use text-guided segmentation
112
+ box_threshold=0.35,
113
+ text_threshold=0.25,
114
+ target_fps=1,
115
+ debug_visualizations=True,
116
+ )
117
+
118
+ # Create model and pipeline with both SAM2 and GroundingDINO paths
119
+ vine_model = VineModel(config)
120
+ vine_pipeline = VinePipeline(
121
+ model=vine_model,
122
+ tokenizer=None,
123
+ # SAM2 configuration
124
+ sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
125
+ sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
126
+ gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
127
+ gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
128
+ device=0,
129
+ )
130
+
131
+ # Check for demo video
132
+ demo_video = os.path.join(os.path.dirname(__file__), "../demo/videos/output.mp4")
133
+
134
+ if os.path.exists(demo_video):
135
+ print(f"Processing video: {demo_video}")
136
+
137
+ # Define keywords (Grounding DINO will look specifically for these)
138
+ categorical_keywords = ['human', 'dog', 'frisbee'] # Specific objects to find
139
+ unary_keywords = ['running', 'jumping', 'catching', 'throwing']
140
+ binary_keywords = ['behind', 'chasing', 'next to', 'throwing to']
141
+ object_pairs = [(0,1), (0, 2), (1, 2), (1, 3), (2, 3), (0,4)]
142
+ print("Using Grounding DINO + SAM2 text-guided segmentation...")
143
+ print(f"Looking specifically for: {categorical_keywords}")
144
+
145
+ try:
146
+ # Process with Grounding DINO + SAM2
147
+ results = vine_pipeline(
148
+ demo_video,
149
+ categorical_keywords=categorical_keywords,
150
+ unary_keywords=unary_keywords,
151
+ binary_keywords=binary_keywords,
152
+ object_pairs=object_pairs,
153
+ segmentation_method="grounding_dino_sam2",
154
+ box_threshold=0.35,
155
+ text_threshold=0.25,
156
+ return_top_k=3,
157
+ debug_visualizations=True,
158
+ )
159
+
160
+ print("\n✓ Grounding DINO + SAM2 segmentation completed!")
161
+ print("Results summary:")
162
+ print(f" Objects detected: {results['summary']['num_objects_detected']}")
163
+ print(f" Top categories: {results['summary']['top_categories']}")
164
+ print(f" Top actions: {results['summary']['top_actions']}")
165
+ print(f" Top relations: {results['summary']['top_relations']}")
166
+
167
+ return results
168
+
169
+ except Exception as e:
170
+ print(f"Grounding DINO + SAM2 segmentation failed: {e}")
171
+ print("Make sure both Grounding DINO and SAM2 models are properly installed")
172
+ return None
173
+ else:
174
+ print(f"Demo video not found: {demo_video}")
175
+ return None
176
+
177
+
178
+ def compare_segmentation_methods():
179
+ """Compare SAM2-only vs Grounding DINO + SAM2 approaches."""
180
+ print("\n=== Comparing Segmentation Methods ===")
181
+
182
+ print("\nSAM2-Only Approach:")
183
+ print("✓ Finds all objects automatically")
184
+ print("✓ No need to specify what to look for")
185
+ print("✓ Good for exploratory analysis")
186
+ print("✗ May find too many irrelevant objects")
187
+ print("✗ Less precise for specific object types")
188
+
189
+ print("\nGrounding DINO + SAM2 Approach:")
190
+ print("✓ Finds specific objects based on text prompts")
191
+ print("✓ More precise and targeted")
192
+ print("✓ Better for known object categories")
193
+ print("✓ Integrates object detection with segmentation")
194
+ print("✗ Limited to specified categories")
195
+ print("✗ Requires knowing what objects to look for")
196
+
197
+
198
+ def demonstrate_mask_processing():
199
+ """Demonstrate how masks are processed internally."""
200
+ print("\n=== Mask Processing Demonstration ===")
201
+
202
+ # Load a video to show the processing pipeline
203
+ demo_video = os.path.join(os.path.dirname(__file__), "../demo/videos/output.mp4")
204
+
205
+ if os.path.exists(demo_video):
206
+ print("Loading video for mask processing demo...")
207
+
208
+ # Load video tensor
209
+ video_tensor = np.asarray(load_video(demo_video, target_fps=1))
210
+ print(f"Video shape: {video_tensor.shape}")
211
+
212
+ # Create pipeline with segmentation model paths
213
+ config = VineConfig(segmentation_method="sam2")
214
+ vine_model = VineModel(config)
215
+ vine_pipeline = VinePipeline(
216
+ model=vine_model,
217
+ tokenizer=None,
218
+ # SAM2 configuration
219
+ sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
220
+ sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
221
+ gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
222
+ gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
223
+ )
224
+
225
+ try:
226
+ # Process just the first few frames to show the pipeline
227
+ print("\nProcessing first 2 frames with SAM2...")
228
+
229
+ # Manually call the preprocessing to show the steps
230
+ processed_data = vine_pipeline.preprocess(
231
+ video_tensor[:2], # Just first 2 frames
232
+ segmentation_method="sam2",
233
+ categorical_keywords=['object']
234
+ )
235
+
236
+ print("Mask processing results:")
237
+ print(f" Number of frames processed: {processed_data['num_frames']}")
238
+ print(f" Frames with masks: {list(processed_data['masks'].keys())}")
239
+
240
+ # Show mask details
241
+ for frame_id, frame_masks in processed_data['masks'].items():
242
+ print(f" Frame {frame_id}: {len(frame_masks)} objects detected")
243
+ for obj_id, mask in frame_masks.items():
244
+ print(f" Object {obj_id}: mask shape {mask.shape}")
245
+
246
+ print("\nBounding box extraction:")
247
+ for frame_id, frame_bboxes in processed_data['bboxes'].items():
248
+ print(f" Frame {frame_id}: {len(frame_bboxes)} bounding boxes")
249
+ for obj_id, bbox in frame_bboxes.items():
250
+ print(f" Object {obj_id}: bbox {bbox}")
251
+
252
+ except Exception as e:
253
+ print(f"Mask processing failed: {e}")
254
+ print("This is expected if SAM2 models are not properly set up")
255
+ else:
256
+ print(f"Demo video not found: {demo_video}")
257
+
258
+
259
+ def test_mask_formats():
260
+ """Test different mask input formats."""
261
+ print("\n=== Testing Mask Formats ===")
262
+
263
+ # Create dummy data to test mask processing
264
+ height, width = 224, 224
265
+
266
+ # Test different mask formats
267
+ print("Testing mask format conversions...")
268
+
269
+ # Format 1: NumPy boolean array
270
+ mask_np = np.random.rand(height, width) > 0.5
271
+ print(f"NumPy mask: {mask_np.shape}, dtype: {mask_np.dtype}")
272
+
273
+ # Format 2: PyTorch tensor
274
+ mask_torch = torch.from_numpy(mask_np)
275
+ print(f"PyTorch mask: {mask_torch.shape}, dtype: {mask_torch.dtype}")
276
+
277
+ # Format 3: 3D mask with singleton dimension
278
+ mask_3d = mask_torch.unsqueeze(-1)
279
+ print(f"3D mask: {mask_3d.shape}")
280
+
281
+ # Test bounding box extraction
282
+ from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox
283
+
284
+ try:
285
+ bbox = mask_to_bbox(mask_torch)
286
+ print(f"Extracted bbox: {bbox}")
287
+ print("✓ Mask format testing successful")
288
+ except Exception as e:
289
+ print(f"Mask format testing failed: {e}")
290
+
291
+
292
+ if __name__ == "__main__":
293
+ print("VINE SAM2 Mask Generation Examples")
294
+ print("=" * 50)
295
+
296
+ # Test SAM2-only approach
297
+ try:
298
+ sam2_results = example_sam2_only_segmentation()
299
+ except Exception as e:
300
+ print(f"SAM2-only example failed: {e}")
301
+
302
+ # Test Grounding DINO + SAM2 approach
303
+ try:
304
+ gd_sam2_results = example_grounding_dino_sam2_segmentation()
305
+ except Exception as e:
306
+ print(f"Grounding DINO + SAM2 example failed: {e}")
307
+
308
+ # Compare approaches
309
+ compare_segmentation_methods()
310
+
311
+ # Demonstrate mask processing
312
+ try:
313
+ demonstrate_mask_processing()
314
+ except Exception as e:
315
+ print(f"Mask processing demo failed: {e}")
316
+
317
+ # Test mask formats
318
+ try:
319
+ test_mask_formats()
320
+ except Exception as e:
321
+ print(f"Mask format testing failed: {e}")
322
+
323
+ print("\n" + "=" * 50)
324
+ print("Examples completed!")
325
+ print("\nKey takeaways:")
326
+ print("1. SAM2-only: Automatic object detection and segmentation")
327
+ print("2. Grounding DINO + SAM2: Text-guided object detection and segmentation")
328
+ print("3. Both methods provide masks and bounding boxes for VINE model")
329
+ print("4. Choose method based on whether you know what objects to look for")
330
+
331
+
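The trade-off compared above comes down to a single argument at call time: segmentation_method selects between automatic and text-guided mask generation on the same pipeline object. A short sketch, reusing the vine_pipeline built in the examples in this file (the video path is a placeholder):

# Automatic: SAM2 proposes masks for everything it finds; VINE then classifies them.
auto_results = vine_pipeline(
    "demo/videos/output.mp4",
    categorical_keywords=["human", "dog", "frisbee", "object"],
    unary_keywords=["running", "jumping", "sitting"],
    binary_keywords=["behind", "next to"],
    segmentation_method="sam2",
)

# Text-guided: Grounding DINO detects only the named categories; SAM2 segments those boxes.
guided_results = vine_pipeline(
    "demo/videos/output.mp4",
    categorical_keywords=["human", "dog", "frisbee"],
    unary_keywords=["running", "jumping", "catching"],
    binary_keywords=["behind", "chasing"],
    segmentation_method="grounding_dino_sam2",
    box_threshold=0.35,
    text_threshold=0.25,
)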
vine_hf/example_usage.ipynb ADDED
@@ -0,0 +1,310 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "44d53281",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/home/kevinx/miniconda3/envs/laser_env/lib/python3.10/site-packages/pydantic/_internal/_config.py:383: UserWarning: Valid config keys have changed in V2:\n",
14
+ "* 'schema_extra' has been renamed to 'json_schema_extra'\n",
15
+ " warnings.warn(message, UserWarning)\n",
16
+ "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n",
17
+ "Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n"
18
+ ]
19
+ }
20
+ ],
21
+ "source": [
22
+ "import os\n",
23
+ "import sys\n",
24
+ "import torch\n",
25
+ "from transformers import pipeline, AutoModel\n",
26
+ "from transformers.pipelines import PIPELINE_REGISTRY\n",
27
+ "\n",
28
+ "# Uncomment or set your own\n",
29
+ "#os.environ['OPENAI_API_KEY'] = 'dummy-key'\n",
30
+ "from vine_hf import VineConfig, VineModel, VinePipeline"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 2,
36
+ "id": "174e479f",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "PIPELINE_REGISTRY.register_pipeline(\n",
41
+ " \"vine-video-understanding\",\n",
42
+ " pipeline_class=VinePipeline,\n",
43
+ " pt_model=VineModel,\n",
44
+ " type=\"multimodal\",\n",
45
+ ")"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "id": "a9af2770",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "vine_config = VineConfig(\n",
56
+ " model_name=\"openai/clip-vit-base-patch32\",\n",
57
+ " # Local file example: set use_hf_repo=False and provide local_dir/local_filename\n",
58
+ " use_hf_repo=False,\n",
59
+ " local_dir=os.path.dirname('/path/to/your/pretrained/model.pt'),\n",
60
+ " local_filename=os.path.basename('/path/to/your/pretrained/model.pt'), # Local file path\n",
61
+ " segmentation_method=\"grounding_dino_sam2\",\n",
62
+ " visualize=True,\n",
63
+ " visualization_dir=\"path/to/visualization/dir\",\n",
64
+ " debug_visualizations=True,\n",
65
+ " device=0, # Change to your desired device\n",
66
+ ")"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "id": "274e6515",
73
+ "metadata": {},
74
+ "outputs": [
75
+ {
76
+ "name": "stdout",
77
+ "output_type": "stream",
78
+ "text": [
79
+ "Loaded state type: <class 'collections.OrderedDict'>\n"
80
+ ]
81
+ }
82
+ ],
83
+ "source": [
84
+ "vine_pipeline = VinePipeline(\n",
85
+ " model=VineModel(vine_config), \n",
86
+ " tokenizer=None,\n",
87
+ " sam_config_path=\"path/to/sam2/configs/sam2_hiera_base_plus.yaml\",\n",
88
+ " sam_checkpoint_path=\"path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt\",\n",
89
+ " gd_config_path=\"path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py\",\n",
90
+ " gd_checkpoint_path=\"path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth\",\n",
91
+ ")"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 6,
97
+ "id": "123a090d",
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "categorical_keywords = ['human', 'dog', 'frisbee']\n",
102
+ "unary_keywords = ['running', 'jumping', 'catching', 'throwing']\n",
103
+ "binary_keywords = ['behind', 'in front of', 'next to', 'chasing']\n",
104
+ "object_pairs = [(0, 1), (0, 2), (1, 2)] # human-dog, dog-frisbee relationships "
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 7,
110
+ "id": "0b42f032",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "demo_video_path = \"/home/kevinx/LASER/LASER/demo/videos/v1.mp4\" # Replace with your video file path"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 8,
120
+ "id": "8202c654",
121
+ "metadata": {},
122
+ "outputs": [
123
+ {
124
+ "name": "stdout",
125
+ "output_type": "stream",
126
+ "text": [
127
+ "Segmentation method: grounding_dino_sam2\n",
128
+ "Generating Grounding DINO + SAM2 masks...\n",
129
+ "<class 'int'>\n",
130
+ "✓ SAM2 models initialized successfully\n",
131
+ "<class 'int'>\n"
132
+ ]
133
+ },
134
+ {
135
+ "name": "stderr",
136
+ "output_type": "stream",
137
+ "text": [
138
+ "UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /pytorch/aten/src/ATen/native/TensorShape.cpp:4314.)\n"
139
+ ]
140
+ },
141
+ {
142
+ "name": "stdout",
143
+ "output_type": "stream",
144
+ "text": [
145
+ "final text_encoder_type: bert-base-uncased\n",
146
+ "✓ GroundingDINO model initialized successfully\n",
147
+ "Start detecting objects at time 05:08:58.178592\n"
148
+ ]
149
+ },
150
+ {
151
+ "name": "stderr",
152
+ "output_type": "stream",
153
+ "text": [
154
+ "Detecting objects: 0%| | 0/3 [00:00<?, ?it/s]FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
155
+ "UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
156
+ "UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
157
+ "FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
158
+ "Detecting objects: 100%|██████████| 3/3 [00:01<00:00, 2.82it/s]\n"
159
+ ]
160
+ },
161
+ {
162
+ "name": "stdout",
163
+ "output_type": "stream",
164
+ "text": [
165
+ "Finished detecting objects at time 05:08:59.250419\n",
166
+ "Loading inference state at time 05:08:59.544425\n",
167
+ "Number of frames: 3\n",
168
+ "None\n"
169
+ ]
170
+ },
171
+ {
172
+ "name": "stderr",
173
+ "output_type": "stream",
174
+ "text": [
175
+ "Processing frames: 100%|██████████| 3/3 [00:00<00:00, 11.77it/s]\n"
176
+ ]
177
+ },
178
+ {
179
+ "name": "stdout",
180
+ "output_type": "stream",
181
+ "text": [
182
+ "Annotated frames: []\n",
183
+ "Find the most dense prompt at time 05:09:01.413703\n",
184
+ "Most dense frame: 0\n",
185
+ "\n",
186
+ "\n",
187
+ "Start propagating objects at time 05:09:01.416367\n",
188
+ "Pass count: 0\n"
189
+ ]
190
+ },
191
+ {
192
+ "name": "stderr",
193
+ "output_type": "stream",
194
+ "text": [
195
+ "propagate in video: 100%|██████████| 3/3 [00:00<00:00, 20.20it/s]\n",
196
+ "propagate in video: 0it [00:00, ?it/s]\n"
197
+ ]
198
+ },
199
+ {
200
+ "name": "stdout",
201
+ "output_type": "stream",
202
+ "text": [
203
+ "Most dense frame: 1\n",
204
+ "\n",
205
+ "\n",
206
+ "Pass count: 1\n"
207
+ ]
208
+ },
209
+ {
210
+ "name": "stderr",
211
+ "output_type": "stream",
212
+ "text": [
213
+ "propagate in video: 100%|██████████| 3/3 [00:00<00:00, 19.25it/s]\n",
214
+ "propagate in video: 0it [00:00, ?it/s]\n"
215
+ ]
216
+ },
217
+ {
218
+ "name": "stdout",
219
+ "output_type": "stream",
220
+ "text": [
221
+ "Most dense frame: 2\n",
222
+ "\n",
223
+ "\n",
224
+ "Pass count: 2\n"
225
+ ]
226
+ },
227
+ {
228
+ "name": "stderr",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "propagate in video: 100%|██████████| 3/3 [00:00<00:00, 25.92it/s]\n",
232
+ "propagate in video: 0it [00:00, ?it/s]\n"
233
+ ]
234
+ },
235
+ {
236
+ "name": "stdout",
237
+ "output_type": "stream",
238
+ "text": [
239
+ "Most dense frame: -1\n",
240
+ "\n",
241
+ "\n",
242
+ "\n",
243
+ "Results:\n",
244
+ "Summary: {'num_objects_detected': 4, 'num_unary_predictions': 10, 'num_binary_predictions': 3, 'top_categories': [('frisbee', 0.9989640712738037), ('dog', 0.957672655582428), ('dog', 0.957672655582428)], 'top_actions': [('running', 0.8483631610870361), ('running', 0.832377016544342), ('running', 0.8178836107254028)], 'top_relations': [('chasing', 0.9616015553474426), ('chasing', 0.9478002786636353), ('chasing', 0.6380977630615234)]}\n"
245
+ ]
246
+ }
247
+ ],
248
+ "source": [
249
+ "try:\n",
250
+ " results = vine_pipeline(\n",
251
+ " demo_video_path,\n",
252
+ " categorical_keywords=categorical_keywords,\n",
253
+ " unary_keywords=unary_keywords,\n",
254
+ " binary_keywords=binary_keywords,\n",
255
+ " object_pairs=object_pairs,\n",
256
+ " segmentation_method='grounding_dino_sam2',\n",
257
+ " return_top_k=3,\n",
258
+ " include_visualizations=False,\n",
259
+ " debug_visualizations=False,\n",
260
+ " )\n",
261
+ " \n",
262
+ " print(\"\\nResults:\")\n",
263
+ " print(f\"Summary: {results['summary']}\")\n",
264
+ " \n",
265
+ "except Exception as e:\n",
266
+ " print(f\"Note: Full execution requires segmentation models to be properly set up.\")\n",
267
+ " print(f\"Error: {e}\")"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 9,
273
+ "id": "414ede9b",
274
+ "metadata": {},
275
+ "outputs": [
276
+ {
277
+ "name": "stdout",
278
+ "output_type": "stream",
279
+ "text": [
280
+ "Summary: {'num_objects_detected': 4, 'num_unary_predictions': 10, 'num_binary_predictions': 3, 'top_categories': [('frisbee', 0.9989640712738037), ('dog', 0.957672655582428), ('dog', 0.957672655582428)], 'top_actions': [('running', 0.8483631610870361), ('running', 0.832377016544342), ('running', 0.8178836107254028)], 'top_relations': [('chasing', 0.9616015553474426), ('chasing', 0.9478002786636353), ('chasing', 0.6380977630615234)]}\n"
281
+ ]
282
+ }
283
+ ],
284
+ "source": [
285
+ "print(f\"Summary: {results['summary']}\")"
286
+ ]
287
+ }
288
+ ],
289
+ "metadata": {
290
+ "kernelspec": {
291
+ "display_name": "laser_env",
292
+ "language": "python",
293
+ "name": "python3"
294
+ },
295
+ "language_info": {
296
+ "codemirror_mode": {
297
+ "name": "ipython",
298
+ "version": 3
299
+ },
300
+ "file_extension": ".py",
301
+ "mimetype": "text/x-python",
302
+ "name": "python",
303
+ "nbconvert_exporter": "python",
304
+ "pygments_lexer": "ipython3",
305
+ "version": "3.10.0"
306
+ }
307
+ },
308
+ "nbformat": 4,
309
+ "nbformat_minor": 5
310
+ }
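The notebook only prints the condensed summary, but the returned dictionary also carries the per-object and per-frame distributions, keyed the same way convert_inference.py reads them. A short sketch of traversing a result, assuming that layout:

# `results` is the dictionary returned by vine_pipeline(...) in the cells above.
for obj_id, predictions in results["categorical_predictions"].items():
    print(f"Object {obj_id}: {predictions[:3]}")                     # top (probability, category) pairs

for (frame_id, obj_id), predictions in results["unary_predictions"].items():
    print(f"Frame {frame_id}, object {obj_id}: {predictions[:3]}")

for (frame_id, obj_pair), predictions in results["binary_predictions"].items():
    print(f"Frame {frame_id}, objects {obj_pair}: {predictions[:3]}")

print(results["summary"]["num_objects_detected"])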
vine_hf/example_usage.py ADDED
@@ -0,0 +1,283 @@
1
+ """
2
+ Example usage of VINE HuggingFace interface
3
+
4
+ This script demonstrates how to use the VINE model through the HuggingFace interface
5
+ for video understanding with categorical, unary, and binary keyword predictions.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import torch
11
+ from transformers import pipeline, AutoModel
12
+ from transformers.pipelines import PIPELINE_REGISTRY
13
+
14
+ # Add the parent directory to the path to import vine_hf
15
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16
+
17
+ # Uncomment or set your own
18
+ #os.environ['OPENAI_API_KEY'] = 'dummy-key'
19
+ from vine_hf import VineConfig, VineModel, VinePipeline
20
+
21
+ def example_direct_model_usage():
22
+ """Example of using the VINE model directly."""
23
+ print("=== Direct Model Usage ===")
24
+
25
+ # Create configuration
26
+ config = VineConfig(
27
+ model_name="openai/clip-vit-base-patch32",
28
+ segmentation_method="grounding_dino_sam2",
29
+ use_hf_repo=True,
30
+ model_repo="video-fm/vine_v0", # Your HF Hub model
31
+ debug_visualizations=True,
32
+ debug_visualization_path=os.path.join(os.getcwd(), "debug_masks.png"),
33
+ target_fps=30,
34
+ box_threshold=0.35,
35
+ text_threshold=0.25
36
+ )
37
+
38
+ # Initialize model
39
+ model = VineModel(config)
40
+
41
+ print(f"Model initialized with CLIP backbone: {config.model_name}")
42
+ print(f"Segmentation method: {config.segmentation_method}")
43
+ print(f"Device: {model.device}")
44
+
45
+ # Example video data (placeholder - in real usage, load from video file)
46
+ num_frames, height, width = 3, 224, 224
47
+ video_frames = torch.randn(num_frames, height, width, 3) * 255
48
+ video_frames = video_frames.clamp(0, 255).byte()
49
+
50
+ # Example masks and bboxes (placeholder - in real usage, generated by segmentation)
51
+ masks = {
52
+ 0: {1: torch.ones(height, width, 1), 2: torch.ones(height, width, 1)},
53
+ 1: {1: torch.ones(height, width, 1), 2: torch.ones(height, width, 1)},
54
+ 2: {1: torch.ones(height, width, 1), 2: torch.ones(height, width, 1)}
55
+ }
56
+
57
+ bboxes = {
58
+ 0: {1: [50, 50, 150, 150], 2: [100, 100, 200, 200]},
59
+ 1: {1: [52, 52, 152, 152], 2: [102, 102, 202, 202]},
60
+ 2: {1: [54, 54, 154, 154], 2: [104, 104, 204, 204]}
61
+ }
62
+
63
+ # Define keywords
64
+ categorical_keywords = ["human", "dog", "frisbee"]
65
+ unary_keywords = ["running", "jumping", "sitting", "standing"]
66
+ binary_keywords = ["behind", "in front of", "next to", "throwing to", "catching from"]
67
+ object_pairs = [(1, 2)] # Object 1 relates to Object 2
68
+
69
+ # Run prediction
70
+ print("\nRunning prediction...")
71
+ results = model.predict(
72
+ video_frames=video_frames,
73
+ masks=masks,
74
+ bboxes=bboxes,
75
+ categorical_keywords=categorical_keywords,
76
+ unary_keywords=unary_keywords,
77
+ binary_keywords=binary_keywords,
78
+ object_pairs=object_pairs,
79
+ return_top_k=3
80
+ )
81
+
82
+ print("\nResults:")
83
+ print(f"Categorical predictions: {len(results['categorical_predictions'])} objects")
84
+ print(f"Unary predictions: {len(results['unary_predictions'])} actions")
85
+ print(f"Binary predictions: {len(results['binary_predictions'])} relations")
86
+ print(f"Confidence scores: {results['confidence_scores']}")
87
+
88
+
89
+ def example_pipeline_usage():
90
+ """Example of using the VINE pipeline."""
91
+ print("\n=== Pipeline Usage ===")
92
+
93
+ # Register the pipeline
94
+ PIPELINE_REGISTRY.register_pipeline(
95
+ "vine-video-understanding",
96
+ pipeline_class=VinePipeline,
97
+ pt_model=VineModel,
98
+ type="multimodal",
99
+ )
100
+ vine_config = VineConfig(
101
+ model_name="openai/clip-vit-base-patch32",
102
+ use_hf_repo=True,
103
+ model_repo="video-fm/vine_v0", # Your HF Hub model
104
+ segmentation_method="grounding_dino_sam2",
105
+ debug_visualizations=True,
106
+ )
107
+
108
+ vine_pipe = VinePipeline(
109
+ model=VineModel(vine_config),
110
+ tokenizer=None,
111
+ trust_remote_code=True,
112
+ # SAM2 configuration
113
+ sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
114
+ sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
115
+ gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
116
+ gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
117
+ device=0,
118
+ )
119
+
120
+
121
+ print("Pipeline created successfully!")
122
+
123
+ # Example usage with video path
124
+ video_path = "path/to/your/video.mp4" # Replace with actual video path
125
+
126
+ # For demonstration, we'll show the expected usage format
127
+ print(f"\nExample pipeline call (replace with actual video path):")
128
+ print(f"results = vine_pipeline(")
129
+ print(f" '{video_path}',")
130
+ print(f" categorical_keywords=['human', 'dog', 'frisbee'],")
131
+ print(f" unary_keywords=['running', 'jumping', 'sitting'],")
132
+ print(f" binary_keywords=['behind', 'in front of', 'next to'],")
133
+ print(f" object_pairs=[(1, 2)],")
134
+ print(f" segmentation_method='grounding_dino_sam2',")
135
+ print(f" return_top_k=3,")
136
+ print(f" return_flattened_segments=True,")
137
+ print(f" return_valid_pairs=True,")
138
+ print(f" include_visualizations=True,")
139
+ print(f" debug_visualizations=True")
140
+ print(f")")
141
+
142
+ # Note: Actual execution would require proper video file and segmentation models
143
+
144
+
145
+ def example_huggingface_hub_usage():
146
+ """Example of how to push and load from HuggingFace Hub."""
147
+ print("\n=== HuggingFace Hub Usage ===")
148
+
149
+ # Example of preparing model for Hub
150
+ config = VineConfig()
151
+ model = VineModel(config)
152
+
153
+ # Register for auto classes
154
+ config.register_for_auto_class()
155
+ model.register_for_auto_class("AutoModel")
156
+
157
+ print("Model registered for auto classes")
158
+
159
+ # Example push to hub (commented out - requires actual model weights and credentials)
160
+ # config.push_to_hub('your-username/vine-model')
161
+ # model.push_to_hub('your-username/vine-model')
162
+
163
+ # Example load from hub (commented out - requires actual model on hub)
164
+ # model = AutoModel.from_pretrained('your-username/vine-model', trust_remote_code=True)
165
+ # pipeline = pipeline('vine-video-understanding', model='your-username/vine-model', trust_remote_code=True)
166
+
167
+ print("To push to Hub:")
168
+ print("1. config.push_to_hub('your-username/vine-model')")
169
+ print("2. model.push_to_hub('your-username/vine-model')")
170
+ print("\nTo load from Hub:")
171
+ print("model = AutoModel.from_pretrained('your-username/vine-model', trust_remote_code=True)")
172
+ print("pipe = pipeline('vine-video-understanding', model='your-username/vine-model', trust_remote_code=True)")
173
+
174
+
175
+ def example_with_real_video():
176
+ """Example showing how to use with a real video file."""
177
+ print("\n=== Real Video Usage Example ===")
178
+
179
+ # Check if demo video exists
180
+ demo_video_path = os.path.join(os.path.dirname(__file__), "../demo/videos/v1.mp4")
181
+
182
+ if os.path.exists(demo_video_path):
183
+ print(f"Found demo video: {demo_video_path}")
184
+
185
+ # Create pipeline with segmentation model paths
186
+ PIPELINE_REGISTRY.register_pipeline(
187
+ "vine-video-understanding",
188
+ pipeline_class=VinePipeline,
189
+ pt_model=VineModel,
190
+ type="multimodal",
191
+ )
192
+
193
+ vine_config = VineConfig(
194
+ model_name="openai/clip-vit-base-patch32",
195
+ use_hf_repo=True,
196
+ model_repo="video-fm/vine_v0", # Your HF Hub model
197
+ segmentation_method="grounding_dino_sam2",
198
+ debug_visualizations=True,
199
+ debug_visualization_path=os.path.join(os.getcwd(), "real_video_debug_masks.png"),
200
+ )
201
+
202
+ vine_pipeline = VinePipeline(
203
+ model=VineModel(vine_config),
204
+ tokenizer=None,
205
+ trust_remote_code=True,
206
+ # SAM2 configuration
207
+ sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
208
+ sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
209
+ gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
210
+ gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
211
+ )
212
+
213
+ # Define keywords based on the demo
214
+ categorical_keywords = ['human', 'dog', 'frisbee']
215
+ unary_keywords = ['running', 'jumping', 'catching', 'throwing']
216
+ binary_keywords = ['behind', 'in front of', 'next to', 'chasing']
217
+ object_pairs = [(0, 1), (0, 2), (1, 2)] # human-dog, dog-frisbee relationships
218
+
219
+ print("\nProcessing video with VINE...")
220
+ print("Keywords:")
221
+ print(f" Categorical: {categorical_keywords}")
222
+ print(f" Unary: {unary_keywords}")
223
+ print(f" Binary: {binary_keywords}")
224
+ print(f" Object pairs: {object_pairs}")
225
+
226
+ # Note: This would require proper segmentation models to be set up
227
+ try:
228
+ results = vine_pipeline(
229
+ demo_video_path,
230
+ categorical_keywords=categorical_keywords,
231
+ unary_keywords=unary_keywords,
232
+ binary_keywords=binary_keywords,
233
+ object_pairs=object_pairs,
234
+ segmentation_method='grounding_dino_sam2',
235
+ return_top_k=3,
236
+ include_visualizations=False,
237
+ debug_visualizations=True,
238
+ )
239
+
240
+ print("\nResults:")
241
+ print(f"Summary: {results['summary']}")
242
+
243
+ except Exception as e:
244
+ print(f"Note: Full execution requires segmentation models to be properly set up.")
245
+ print(f"Error: {e}")
246
+
247
+ else:
248
+ print(f"Demo video not found at: {demo_video_path}")
249
+ print("To use with a real video, provide the path to your video file.")
250
+
251
+
252
+ if __name__ == "__main__":
253
+ print("VINE HuggingFace Interface Examples")
254
+ print("=" * 50)
255
+
256
+ # Run examples
257
+ try:
258
+ example_direct_model_usage()
259
+ except Exception as e:
260
+ print(f"Direct model usage failed: {e}")
261
+
262
+ try:
263
+ example_pipeline_usage()
264
+ except Exception as e:
265
+ print(f"Pipeline usage failed: {e}")
266
+
267
+ try:
268
+ example_huggingface_hub_usage()
269
+ except Exception as e:
270
+ print(f"Hub usage example failed: {e}")
271
+
272
+ try:
273
+ example_with_real_video()
274
+ except Exception as e:
275
+ print(f"Real video example failed: {e}")
276
+
277
+ print("\n" + "=" * 50)
278
+ print("Examples completed!")
279
+ print("\nNext steps:")
280
+ print("1. Set up Grounding DINO and SAM2 models for segmentation")
281
+ print("2. Load your pretrained VINE model weights")
282
+ print("3. Test with your own videos")
283
+ print("4. Push to HuggingFace Hub for sharing")
vine_hf/example_visualization.py ADDED
@@ -0,0 +1,146 @@
1
+ # Example visualization runner for VINE
2
+ # - Loads a video (path, demo, or random)
3
+ # - Runs the VINE pipeline
4
+ # - Saves annotated frames and an MP4 if available
5
+
6
+ import os
7
+ import sys
8
+ import argparse
9
+ import cv2
10
+ import numpy as np
11
+ from collections.abc import Mapping, Sequence
12
+
13
+ from transformers.pipelines import PIPELINE_REGISTRY
14
+ from transformers import pipeline
15
+
16
+ # Set your OpenAI API key here or via environment variable
17
+ os.environ['OPENAI_API_KEY'] = "dummy-key"
18
+
19
+ # Local imports (workspace)
20
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # add the repo root so the vine_hf package is importable
21
+
22
+ from vine_hf.vine_pipeline import VinePipeline  # local workspace import (repo root added to sys.path above)
23
+ from vine_hf.vine_model import VineModel
24
+ from vine_hf.vine_config import VineConfig
25
+ from laser.loading import load_video
26
+
27
+
28
+ def build_pipeline(args) -> VinePipeline:
29
+ # Register pipeline type
30
+ PIPELINE_REGISTRY.register_pipeline(
31
+ "vine-video-understanding",
32
+ pipeline_class=VinePipeline,
33
+ pt_model=VineModel,
34
+ type="multimodal",
35
+ )
36
+
37
+ config = VineConfig(
38
+ segmentation_method="grounding_dino_sam2",
39
+ model_name="openai/clip-vit-base-patch32",
40
+ # Example: load from HF repo
41
+ use_hf_repo=True,
42
+ model_repo="video-fm/vine_v0",
43
+ # Alternatively use a local path by setting use_hf_repo=False and local_dir/local_filename
44
+ box_threshold=args.box_threshold,
45
+ text_threshold=args.text_threshold,
46
+ target_fps=args.fps,
47
+ topk_cate=args.topk_cate,
48
+ visualization_dir=args.out_dir,
49
+ visualize=True,
50
+ debug_visualizations=True,
51
+ device=args.device,
52
+ )
53
+
54
+ model = VineModel(config)
55
+
56
+ # Create pipeline instance with segmentation model paths (if provided)
57
+ vine_pipe = VinePipeline(
58
+ model=model,
59
+ tokenizer=None,
60
+ sam_config_path="//home/kevinx/LASER/video-sam2/sam2/sam2_hiera_t.yaml",
61
+ sam_checkpoint_path="//home/kevinx/LASER/video-sam2/sam2_hiera_tiny.pt",
62
+ gd_config_path="//home/kevinx/LASER/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
63
+ gd_checkpoint_path="//home/kevinx/LASER/GroundingDINO/weights/groundingdino_swint_ogc.pth",
64
+ device=args.device,
65
+ trust_remote_code=True,
66
+ )
67
+ return vine_pipe
68
+
69
+
70
+ def resolve_video(args) -> np.ndarray | str:
71
+ # Priority: user --video -> demo video -> random frames
72
+ if args.video and os.path.exists(args.video):
73
+ return args.video
74
+
75
+ demo_video = "//home/kevinx/LASER/LASER/demo/videos/v1.mp4"
76
+ demo_alt = "//home/kevinx/LASER/LASER/demo/videos/v2.mp4"
77
+ if os.path.exists(demo_video):
78
+ return demo_video
79
+ if os.path.exists(demo_alt):
80
+ return demo_alt
81
+
82
+ # Fallback to random frames (uint8 HxWx3) shaped as T x H x W x 3
83
+ print("No video found; using random frames.")
84
+ rng = np.random.default_rng(0)
85
+ frames = rng.integers(0, 255, size=(args.rand_frames, args.height, args.width, 3), dtype=np.uint8)
86
+ return frames
87
+
88
+
89
+
90
+ def main():
91
+ parser = argparse.ArgumentParser(description="VINE visualization example")
92
+ parser.add_argument("--video", type=str, default=None, help="Path to a video file")
93
+ parser.add_argument("--out_dir", type=str, default="output", help="Output directory")
94
+ parser.add_argument("--method", type=str, default="grounding_dino_sam2", choices=["sam2", "grounding_dino_sam2"], help="Segmentation method")
95
+ parser.add_argument("--fps", type=int, default=5, help="Target FPS for processing")
96
+ parser.add_argument("--box_threshold", type=float, default=0.3, help="GroundingDINO box threshold")
97
+ parser.add_argument("--text_threshold", type=float, default=0.3, help="GroundingDINO text threshold")
98
+ parser.add_argument("--topk_cate", type=int, default=5, help="Top-K categories to display")
99
+ parser.add_argument("--device", type=int, default=0, help="CUDA device index or -1 for CPU")
100
+ parser.add_argument("--debug_visualizations", action="store_true", help="Enable debug visualizations")
101
+
102
+
103
+ args = parser.parse_args()
104
+
105
+ vine_pipe = build_pipeline(args)
106
+ video = resolve_video(args)
107
+
108
+ # Keywords similar to examples/tests
109
+ categorical_keywords = ["dog", "frisbee", "cat"]
110
+ unary_keywords = ["running", "jumping", "sitting", "flying"]
111
+ binary_keywords = ["behind", "next to", "chasing","biting"]
112
+ object_pairs = [(0,1), (0, 2), (1, 2), (1, 3), (2, 3)]
113
+
114
+ print("Running VINE pipeline...")
115
+ call_kwargs = dict(
116
+ categorical_keywords=categorical_keywords,
117
+ unary_keywords=unary_keywords,
118
+ binary_keywords=binary_keywords,
119
+ object_pairs=object_pairs,
120
+ segmentation_method=args.method,
121
+ return_top_k=args.topk_cate,
122
+ include_visualizations=True,
123
+ debug_visualizations=args.debug_visualizations,
124
+ )
125
+
126
+
127
+ results = vine_pipe(
128
+ video,
129
+ **call_kwargs,
130
+ )
131
+
132
+ # Normalize pipeline output to a dict (can be dict or list[dict])
133
+ if isinstance(results, Mapping):
134
+ result = results
135
+ elif isinstance(results, Sequence) and results and isinstance(results[0], Mapping):
136
+ result = results[0]
137
+ else:
138
+ result = {}
139
+
140
+ # Print brief summary
141
+ summary = result.get("summary", {}) if isinstance(result, dict) else {}
142
+ print("Summary:", summary)
143
+
144
+
145
+ if __name__ == "__main__":
146
+ main()
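A minimal sketch of driving this runner programmatically instead of through the CLI. The Namespace fields mirror the argparse defaults above (including the random-frame fallback arguments added to the parser); the import path is an assumption, and a real run still needs the SAM2/GroundingDINO checkpoints hard-coded in build_pipeline.

from argparse import Namespace

from vine_hf.example_visualization import build_pipeline, resolve_video  # assumed import path

args = Namespace(
    video=None, out_dir="output", method="grounding_dino_sam2",
    fps=5, box_threshold=0.3, text_threshold=0.3, topk_cate=5,
    device=0, debug_visualizations=False,
    rand_frames=8, height=240, width=320,  # consumed by the random-frame fallback
)
vine_pipe = build_pipeline(args)  # heavy: loads CLIP and the segmentation models
video = resolve_video(args)       # a file path, or a T x H x W x 3 uint8 array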
vine_hf/example_with_pretrained_vine.py ADDED
@@ -0,0 +1,287 @@
1
+ """
2
+ Example usage of VINE HuggingFace interface with pretrained VINE weights
3
+
4
+ This script demonstrates how to use the VINE model with your pretrained weights
5
+ from the ensemble format or from video-fm/vine_v0.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import torch
11
+ from transformers import pipeline
12
+ from transformers.pipelines import PIPELINE_REGISTRY
13
+
14
+ # Set your OpenAI API key here or via environment variable
15
+ #os.environ['OPENAI_API_KEY'] = "dummy-key"
16
+
17
+ # Add the parent directory to the path to import vine_hf
18
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19
+
20
+ from vine_hf import VineConfig, VineModel, VinePipeline
21
+
22
+
23
+ def example_with_local_pretrained_weights():
24
+ print("=== Using Local Pretrained VINE Weights ===")
25
+
26
+
27
+ # Download https://huggingface.co/video-fm/vine_v0/tree/main/laser_model_v1.pt
28
+ pretrained_vine_file = "/path/to/your/local/laser_model_v1.pt" # Replace with your local path
29
+
30
+
31
+ # Create configuration with your pretrained path (local file)
32
+ config = VineConfig(
33
+ model_name="openai/clip-vit-base-patch32",
34
+ segmentation_method="grounding_dino_sam2",
35
+ target_fps=1,
36
+ visualize=True,
37
+ visualization_dir="path/to/visualization/dir",
38
+ debug_visualizations=True,
39
+ use_hf_repo=False,
40
+ local_dir=os.path.dirname(pretrained_vine_file),
41
+ local_filename=os.path.basename(pretrained_vine_file),
42
+ )
43
+
44
+ # Method 1: Initialize model directly
45
+ print("Method 1: Direct model initialization")
46
+ vine_model = VineModel(config)
47
+ print(f"✓ Model initialized with pretrained weights from: {pretrained_vine_file}")
48
+
49
+ # Method 2: Use the from_pretrained_vine class method
50
+ print("\nMethod 2: Using from_pretrained_vine class method")
51
+ vine_model_2 = VineModel.from_pretrained_vine(
52
+ model_path=pretrained_vine_file,
53
+ config=config,
54
+ epoch=0 # Specify epoch number
55
+ )
56
+ print("✓ Model loaded using from_pretrained_vine method")
57
+
58
+ return vine_model
59
+
60
+
61
+ def example_with_huggingface_hub():
62
+ """Example using VINE weights from HuggingFace Hub."""
63
+ print("\n=== Using HuggingFace Hub Weights ===")
64
+
65
+ # Create configuration to use HuggingFace Hub weights
66
+ config = VineConfig(
67
+ model_name="openai/clip-vit-base-patch32",
68
+ use_hf_repo=True,
69
+ model_repo="video-fm/vine_v0", # Your HF Hub model
70
+ segmentation_method="grounding_dino_sam2",
71
+ visualize=True,
72
+ visualization_dir="path/to/visualization/dir",
73
+ debug_visualizations=True,
74
+ )
75
+
76
+ try:
77
+ # Initialize model (will try to load from HF Hub)
78
+ vine_model = VineModel(config)
79
+ print("✓ Model loaded from HuggingFace Hub: video-fm/vine_v0")
80
+ return vine_model
81
+ except Exception as e:
82
+ print(f"✗ Could not load from HuggingFace Hub: {e}")
83
+ print("Make sure your model is pushed to video-fm/vine_v0")
84
+ return None
85
+
86
+
87
+ def example_pipeline_with_pretrained():
88
+ """Example using pipeline with pretrained VINE weights."""
89
+ print("\n=== Pipeline with Pretrained VINE ===")
90
+
91
+ # Register the pipeline
92
+ PIPELINE_REGISTRY.register_pipeline(
93
+ "vine-video-understanding",
94
+ pipeline_class=VinePipeline,
95
+ pt_model=VineModel,
96
+ type="multimodal",
97
+ )
98
+
99
+ # Create configuration with your weights
100
+ pretrained_vine_file = "/path/to/your/local/laser_model_v1.pt" # Replace with your local path
101
+ config = VineConfig(
102
+ model_name="openai/clip-vit-base-patch32",
103
+ segmentation_method="grounding_dino_sam2",
104
+ visualize=True,
105
+ visualization_dir="path/to/visualization/dir",
106
+ debug_visualizations=True,
107
+ use_hf_repo=False,
108
+ local_dir=os.path.dirname(pretrained_vine_file),
109
+ local_filename=os.path.basename(pretrained_vine_file),
110
+ )
111
+
112
+ # Create model with pretrained weights
113
+ vine_model = VineModel(config)
114
+
115
+ # Create pipeline with segmentation model paths
116
+ vine_pipeline = VinePipeline(
117
+ model=vine_model,
118
+ tokenizer=None,
119
+ sam_config_path="path/to/sam2/configs/sam2.1_hiera_b+.yaml",
120
+ sam_checkpoint_path="path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
121
+ gd_config_path="path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
122
+ gd_checkpoint_path="path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
123
+ device=0
124
+ )
125
+
126
+ print("✓ Pipeline created with pretrained VINE weights")
127
+
128
+ # Example usage (would require actual video file)
129
+ demo_video = os.path.join(os.path.dirname(__file__), "../demo/videos/v1.mp4")
130
+
131
+ if os.path.exists(demo_video):
132
+ print(f"Found demo video: {demo_video}")
133
+ print("Example pipeline call:")
134
+ print(f"results = vine_pipeline(")
135
+ print(f" '{demo_video}',")
136
+ print(f" categorical_keywords=['human', 'dog', 'frisbee'],")
137
+ print(f" unary_keywords=['running', 'jumping', 'sitting'],")
138
+ print(f" binary_keywords=['behind', 'chasing', 'next to']")
139
+ print(f" debug_visualizations=True")
140
+ print(f")")
141
+
142
+ # Uncomment to actually run (requires segmentation models)
143
+ # results = vine_pipeline(
144
+ # demo_video,
145
+ # categorical_keywords=['human', 'dog', 'frisbee'],
146
+ # unary_keywords=['running', 'jumping', 'sitting'],
147
+ # binary_keywords=['behind', 'chasing', 'next to'],
148
+ # debug_visualizations=True,
149
+ # )
150
+ # print("Results:", results['summary'])
151
+
152
+ return vine_pipeline
153
+
154
+
155
+
156
+ def example_manual_weight_loading():
157
+ """Example of manually loading weights after model creation."""
158
+ print("\n=== Manual Weight Loading ===")
159
+
160
+ # Create model with base CLIP weights
161
+ # Base config: disable both the HF repo and local checkpoint so no fine-tuned weights are loaded
162
+ config = VineConfig(use_hf_repo=False, local_dir=None, local_filename=None)
163
+ vine_model = VineModel(config)
164
+ print("✓ Model created with base CLIP weights")
165
+ model_dir = "/path/to/your/local/ensemble/model_dir.pt" # Replace with your model directory
166
+
167
+ if os.path.exists(model_dir):
168
+ success = vine_model.load_pretrained_vine_weights(model_dir, epoch=0)
169
+ if success:
170
+ print("✓ Successfully loaded pretrained VINE weights manually")
171
+ else:
172
+ print("✗ Failed to load pretrained weights")
173
+ else:
174
+ print(f"✗ Model directory not found: {model_dir}")
175
+
176
+ return vine_model
177
+
178
+
179
+ def compare_model_outputs():
180
+ """Compare outputs between base CLIP and pretrained VINE."""
181
+ print("\n=== Comparing Model Outputs ===")
182
+
183
+ # Create dummy data for testing
184
+ video_frames = torch.randn(3, 224, 224, 3) * 255 # 3 frames
185
+ video_frames = video_frames.clamp(0, 255).byte()
186
+
187
+ masks = {
188
+ 0: {1: torch.ones(224, 224, 1)},
189
+ 1: {1: torch.ones(224, 224, 1)},
190
+ 2: {1: torch.ones(224, 224, 1)}
191
+ }
192
+
193
+ bboxes = {
194
+ 0: {1: [50, 50, 150, 150]},
195
+ 1: {1: [52, 52, 152, 152]},
196
+ 2: {1: [54, 54, 154, 154]}
197
+ }
198
+
199
+ keywords = ['human', 'dog', 'frisbee']
200
+
201
+ # Model 1: Base CLIP
202
+ print("Creating model with base CLIP weights...")
203
+ config_base = VineConfig(use_hf_repo=False, local_dir=None, local_filename=None)  # plain CLIP, no fine-tuned weights
204
+ model_base = VineModel(config_base)
205
+
206
+ # Model 2: Pretrained VINE (if available)
207
+ data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data"))
208
+ model_dir = os.path.join(data_dir, "LLaVA-Video-178K-v2/models/ensemble-02-10")
209
+
210
+ if os.path.exists(model_dir):
211
+ print("Creating model with pretrained VINE weights...")
212
+ config_vine = VineConfig(
213
+ use_hf_repo=False,
214
+ local_dir=model_dir,
215
+ local_filename=None,
216
+ )
217
+ model_vine = VineModel(config_vine)
218
+
219
+ print("\nComparing predictions...")
220
+
221
+ # Get predictions from both models
222
+ with torch.no_grad():
223
+ results_base = model_base.predict(
224
+ video_frames=video_frames,
225
+ masks=masks,
226
+ bboxes=bboxes,
227
+ categorical_keywords=keywords,
228
+ return_top_k=3
229
+ )
230
+
231
+ results_vine = model_vine.predict(
232
+ video_frames=video_frames,
233
+ masks=masks,
234
+ bboxes=bboxes,
235
+ categorical_keywords=keywords,
236
+ return_top_k=3
237
+ )
238
+
239
+ print("Base CLIP confidence scores:", results_base['confidence_scores'])
240
+ print("Pretrained VINE confidence scores:", results_vine['confidence_scores'])
241
+
242
+ print("✓ Successfully compared both models")
243
+ else:
244
+ print(f"Pretrained model not found at: {model_dir}")
245
+ print("Skipping comparison")
246
+
247
+
248
+ if __name__ == "__main__":
249
+ print("VINE HuggingFace Interface - Pretrained Weights Examples")
250
+ print("=" * 60)
251
+
252
+ try:
253
+ # Test local pretrained weights
254
+ model1 = example_with_local_pretrained_weights()
255
+ except Exception as e:
256
+ print(f"Local weights example failed: {e}")
257
+
258
+ try:
259
+ # Test HuggingFace Hub weights
260
+ model2 = example_with_huggingface_hub()
261
+ except Exception as e:
262
+ print(f"HuggingFace Hub example failed: {e}")
263
+
264
+ try:
265
+ # Test pipeline with pretrained weights
266
+ pipeline = example_pipeline_with_pretrained()
267
+ except Exception as e:
268
+ print(f"Pipeline example failed: {e}")
269
+
270
+ # try:
271
+ # # Test manual weight loading
272
+ # #model3 = example_manual_weight_loading()
273
+ # except Exception as e:
274
+ # print(f"Manual loading example failed: {e}")
275
+
276
+ # try:
277
+ # # Compare model outputs
278
+ # #compare_model_outputs()
279
+ # except Exception as e:
280
+ # print(f"Comparison example failed: {e}")
281
+
282
+ print("\n" + "=" * 60)
283
+ print("Examples completed!")
284
+ print("\nUsage Summary:")
285
+ print("1. Configure VineConfig with `use_hf_repo` + `model_repo` for Hub models, or `use_hf_repo=False` + `local_dir`/`local_filename` for local weights")
286
+ print("2. Use VineModel.from_pretrained_vine() for direct loading")
287
+
vine_hf/flattening.py ADDED
@@ -0,0 +1,124 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
+ MaskType = Union[np.ndarray, torch.Tensor]
11
+
12
+
13
+ def _to_numpy_mask(mask: MaskType) -> np.ndarray:
14
+ """
15
+ Convert assorted mask formats to a 2D numpy boolean array.
16
+ """
17
+ if isinstance(mask, torch.Tensor):
18
+ mask_np = mask.detach().cpu().numpy()
19
+ else:
20
+ mask_np = np.asarray(mask)
21
+
22
+ # Remove singleton dimensions at the front/back
23
+ while mask_np.ndim > 2 and mask_np.shape[0] == 1:
24
+ mask_np = np.squeeze(mask_np, axis=0)
25
+ if mask_np.ndim > 2 and mask_np.shape[-1] == 1:
26
+ mask_np = np.squeeze(mask_np, axis=-1)
27
+
28
+ if mask_np.ndim != 2:
29
+ raise ValueError(f"Expected mask to be 2D after squeezing, got shape {mask_np.shape}")
30
+
31
+ return mask_np.astype(bool)
32
+
33
+
34
+ def _mask_to_bbox(mask: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
35
+ """
36
+ Compute a bounding box for a 2D boolean mask.
37
+ """
38
+ if not mask.any():
39
+ return None
40
+ rows, cols = np.nonzero(mask)
41
+ y_min, y_max = rows.min(), rows.max()
42
+ x_min, x_max = cols.min(), cols.max()
43
+ return x_min, y_min, x_max, y_max
44
+
45
+
46
+ def flatten_segments_for_batch(
47
+ video_id: int,
48
+ segments: Dict[int, Dict[int, MaskType]],
49
+ bbox_min_dim: int = 5,
50
+ ) -> Dict[str, List]:
51
+ """
52
+ Flatten nested segmentation data into batched lists suitable for predicate
53
+ models or downstream visualizations. Mirrors the notebook helper but is
54
+ robust to differing mask dtypes/shapes.
55
+ """
56
+ batched_object_ids: List[Tuple[int, int, int]] = []
57
+ batched_masks: List[np.ndarray] = []
58
+ batched_bboxes: List[Tuple[int, int, int, int]] = []
59
+ frame_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
60
+
61
+ for frame_id, frame_objects in segments.items():
62
+ valid_objects: List[int] = []
63
+ for object_id, raw_mask in frame_objects.items():
64
+ mask = _to_numpy_mask(raw_mask)
65
+ bbox = _mask_to_bbox(mask)
66
+ if bbox is None:
67
+ continue
68
+
69
+ x_min, y_min, x_max, y_max = bbox
70
+ if abs(y_max - y_min) < bbox_min_dim or abs(x_max - x_min) < bbox_min_dim:
71
+ continue
72
+
73
+ valid_objects.append(object_id)
74
+ batched_object_ids.append((video_id, frame_id, object_id))
75
+ batched_masks.append(mask)
76
+ batched_bboxes.append(bbox)
77
+
78
+ for i in valid_objects:
79
+ for j in valid_objects:
80
+ if i == j:
81
+ continue
82
+ frame_pairs.append((video_id, frame_id, (i, j)))
83
+
84
+ return {
85
+ "object_ids": batched_object_ids,
86
+ "masks": batched_masks,
87
+ "bboxes": batched_bboxes,
88
+ "pairs": frame_pairs,
89
+ }
90
+
91
+
92
+ def extract_valid_object_pairs(
93
+ batched_object_ids: Sequence[Tuple[int, int, int]],
94
+ interested_object_pairs: Optional[Iterable[Tuple[int, int]]] = None,
95
+ ) -> List[Tuple[int, int, Tuple[int, int]]]:
96
+ """
97
+ Filter object pairs per frame. If `interested_object_pairs` is provided, only
98
+ emit those combinations when both objects are present; otherwise emit all
99
+ permutations (i, j) with i != j for each frame.
100
+ """
101
+ frame_to_objects: Dict[Tuple[int, int], set] = defaultdict(set)
102
+ for vid, fid, oid in batched_object_ids:
103
+ frame_to_objects[(vid, fid)].add(oid)
104
+
105
+ interested = (
106
+ list(interested_object_pairs)
107
+ if interested_object_pairs is not None
108
+ else None
109
+ )
110
+
111
+ valid_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
112
+ for (vid, fid), object_ids in frame_to_objects.items():
113
+ if interested:
114
+ for src, dst in interested:
115
+ if src in object_ids and dst in object_ids:
116
+ valid_pairs.append((vid, fid, (src, dst)))
117
+ else:
118
+ for src in object_ids:
119
+ for dst in object_ids:
120
+ if src == dst:
121
+ continue
122
+ valid_pairs.append((vid, fid, (src, dst)))
123
+
124
+ return valid_pairs
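A toy usage sketch for the two helpers above, with synthetic masks (the import path is an assumption; the values shown in comments follow directly from the code):

import numpy as np

from vine_hf.flattening import flatten_segments_for_batch, extract_valid_object_pairs

# Frame 0 has objects 1 and 2, frame 1 only object 1; each mask is an H x W boolean array
mask = np.zeros((64, 64), dtype=bool)
mask[10:40, 10:40] = True
segments = {0: {1: mask, 2: mask}, 1: {1: mask}}

flat = flatten_segments_for_batch(video_id=0, segments=segments)
print(flat["object_ids"])  # [(0, 0, 1), (0, 0, 2), (0, 1, 1)]
print(flat["bboxes"][0])   # (x_min, y_min, x_max, y_max) of the first surviving mask

# Emit the (1, 2) relation only for frames where both objects are present
pairs = extract_valid_object_pairs(flat["object_ids"], interested_object_pairs=[(1, 2)])
print(pairs)               # [(0, 0, (1, 2))]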
vine_hf/push_to_hub.py ADDED
@@ -0,0 +1,232 @@
1
+ """
2
+ Script to push VINE model to HuggingFace Hub
3
+
4
+ This script helps you push your trained VINE model to the HuggingFace Hub
5
+ for easy sharing and distribution.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import torch
11
+ import argparse
12
+ from huggingface_hub import login
13
+ from transformers.pipelines import PIPELINE_REGISTRY
14
+
15
+ # Add the parent directory to the path to import vine_hf
16
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
+
18
+ os.environ['OPENAI_API_KEY'] = "dummy-key"
19
+ from vine_hf import VineConfig, VineModel, VinePipeline
20
+
21
+
22
+ def push_vine_to_hub(
23
+ model_weights_path: str,
24
+ repo_name: str,
25
+ model_name: str = "openai/clip-vit-base-patch32",
26
+ segmentation_method: str = "grounding_dino_sam2",
27
+ commit_message: str = "Upload VINE model",
28
+ private: bool = False
29
+ ):
30
+ """
31
+ Push VINE model to HuggingFace Hub.
32
+
33
+ Args:
34
+ model_weights_path: Path to the trained model weights (.pth file)
35
+ repo_name: Name for the repository (e.g., "username/vine-model")
36
+ model_name: CLIP model backbone name
37
+ segmentation_method: Segmentation method used
38
+ commit_message: Commit message for the push
39
+ private: Whether to create a private repository
40
+ """
41
+
42
+ print("=== Pushing VINE Model to HuggingFace Hub ===")
43
+
44
+ # 1. Create configuration
45
+ print(f"Creating configuration with backbone: {model_name}")
46
+ config = VineConfig(
47
+ model_name=model_name,
48
+ segmentation_method=segmentation_method
49
+ )
50
+
51
+ # 2. Initialize model
52
+ print("Initializing model...")
53
+ model = VineModel(config)
54
+
55
+ # 3. Load trained weights
56
+ if os.path.exists(model_weights_path):
57
+ print(f"Loading weights from: {model_weights_path}")
58
+ try:
59
+ # Try loading with weights_only=False for compatibility
60
+ weights = torch.load(model_weights_path, map_location='cpu', weights_only=False)
61
+
62
+ # Handle different weight formats
63
+ if isinstance(weights, dict):
64
+ if 'state_dict' in weights:
65
+ model.load_state_dict(weights['state_dict'])
66
+ elif 'model' in weights:
67
+ model.load_state_dict(weights['model'])
68
+ else:
69
+ model.load_state_dict(weights)
70
+ else:
71
+ # Assume it's the model directly
72
+ model = weights
73
+
74
+ print("✓ Weights loaded successfully")
75
+ except Exception as e:
76
+ print(f"✗ Error loading weights: {e}")
77
+ print("Please check your weights file format")
78
+ return False
79
+ else:
80
+ print(f"✗ Weights file not found: {model_weights_path}")
81
+ return False
82
+
83
+ # 4. Register for auto classes
84
+ print("Registering for auto classes...")
85
+ config.register_for_auto_class()
86
+ model.register_for_auto_class("AutoModel")
87
+
88
+ # 5. Register pipeline
89
+ print("Registering pipeline...")
90
+ PIPELINE_REGISTRY.register_pipeline(
91
+ "vine-video-understanding",
92
+ pipeline_class=VinePipeline,
93
+ pt_model=VineModel,
94
+ type="multimodal",
95
+ )
96
+
97
+ # 6. Create pipeline instance
98
+ print("Creating pipeline...")
99
+ vine_pipeline = VinePipeline(model=model, tokenizer=None)
100
+
101
+ try:
102
+ # 7. Push configuration to hub
103
+ print(f"Pushing configuration to {repo_name}...")
104
+ config.push_to_hub(
105
+ repo_name,
106
+ commit_message=f"{commit_message} - config",
107
+ private=private
108
+ )
109
+ print("✓ Configuration pushed successfully")
110
+
111
+ # 8. Push model to hub
112
+ print(f"Pushing model to {repo_name}...")
113
+ model.push_to_hub(
114
+ repo_name,
115
+ commit_message=f"{commit_message} - model",
116
+ private=private
117
+ )
118
+ print("✓ Model pushed successfully")
119
+
120
+ # 9. Push pipeline to hub
121
+ print(f"Pushing pipeline to {repo_name}...")
122
+ vine_pipeline.push_to_hub(
123
+ repo_name,
124
+ commit_message=f"{commit_message} - pipeline",
125
+ private=private
126
+ )
127
+ print("✓ Pipeline pushed successfully")
128
+
129
+ print(f"\n🎉 Successfully pushed VINE model to: https://huggingface.co/{repo_name}")
130
+ print(f"\nTo use your model:")
131
+ print(f"```python")
132
+ print(f"from transformers import pipeline")
133
+ print(f"")
134
+ print(f"vine_pipeline = pipeline(")
135
+ print(f" 'vine-video-understanding',")
136
+ print(f" model='{repo_name}',")
137
+ print(f" trust_remote_code=True")
138
+ print(f")")
139
+ print(f"")
140
+ print(f"results = vine_pipeline(")
141
+ print(f" 'path/to/video.mp4',")
142
+ print(f" categorical_keywords=['human', 'dog', 'frisbee'],")
143
+ print(f" unary_keywords=['running', 'jumping'],")
144
+ print(f" binary_keywords=['chasing', 'behind']")
145
+ print(f")")
146
+ print(f"```")
147
+
148
+ return True
149
+
150
+ except Exception as e:
151
+ print(f"✗ Error pushing to hub: {e}")
152
+ print("Please check your HuggingFace credentials and repository permissions")
153
+ return False
154
+
155
+
156
+ def main():
157
+ parser = argparse.ArgumentParser(description="Push VINE model to HuggingFace Hub")
158
+
159
+ parser.add_argument(
160
+ "--weights",
161
+ type=str,
162
+ required=True,
163
+ help="Path to the trained model weights (.pth file)"
164
+ )
165
+
166
+ parser.add_argument(
167
+ "--repo",
168
+ type=str,
169
+ required=True,
170
+ help="Repository name (e.g., 'username/vine-model')"
171
+ )
172
+
173
+ parser.add_argument(
174
+ "--model-name",
175
+ type=str,
176
+ default="openai/clip-vit-base-patch32",
177
+ help="CLIP model backbone name"
178
+ )
179
+
180
+ parser.add_argument(
181
+ "--segmentation",
182
+ type=str,
183
+ default="grounding_dino_sam2",
184
+ choices=["sam2", "grounding_dino_sam2"],
185
+ help="Segmentation method"
186
+ )
187
+
188
+ parser.add_argument(
189
+ "--message",
190
+ type=str,
191
+ default="Upload VINE model",
192
+ help="Commit message"
193
+ )
194
+
195
+ parser.add_argument(
196
+ "--private",
197
+ action="store_true",
198
+ help="Create private repository"
199
+ )
200
+
201
+ parser.add_argument(
202
+ "--login",
203
+ action="store_true",
204
+ help="Login to HuggingFace Hub first"
205
+ )
206
+
207
+ args = parser.parse_args()
208
+
209
+ # Login if requested
210
+ if args.login:
211
+ print("Logging in to HuggingFace Hub...")
212
+ login()
213
+
214
+ # Push model
215
+ success = push_vine_to_hub(
216
+ model_weights_path=args.weights,
217
+ repo_name=args.repo,
218
+ model_name=args.model_name,
219
+ segmentation_method=args.segmentation,
220
+ commit_message=args.message,
221
+ private=args.private
222
+ )
223
+
224
+ if success:
225
+ print("\n✅ Model successfully pushed to HuggingFace Hub!")
226
+ else:
227
+ print("\n❌ Failed to push model to HuggingFace Hub")
228
+ sys.exit(1)
229
+
230
+
231
+ if __name__ == "__main__":
232
+ main()
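A hypothetical programmatic equivalent of the CLI above (the weight path and repo name are placeholders, and valid HuggingFace credentials are required):

from vine_hf.push_to_hub import push_vine_to_hub  # assumed import path

ok = push_vine_to_hub(
    model_weights_path="/path/to/vine_weights.pth",
    repo_name="your-username/vine-model",
    commit_message="Upload VINE model",
    private=True,
)
print("push succeeded" if ok else "push failed")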
vine_hf/push_to_video_fm.py ADDED
@@ -0,0 +1,274 @@
1
+ """
2
+ Script to push VINE model to video-fm organization on HuggingFace Hub
3
+
4
+ This script pushes the VINE architecture (config, model, pipeline) and model weights
5
+ to the video-fm organization for easy sharing and distribution.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import torch
11
+ import argparse
12
+ from pathlib import Path
13
+ from huggingface_hub import HfApi, login
14
+ from transformers.pipelines import PIPELINE_REGISTRY
15
+ from transformers import AutoModel
16
+ from safetensors.torch import save_file
17
+
18
+ # Add the parent directory to path to enable vine_hf imports
19
+ current_dir = Path(__file__).parent
20
+ parent_dir = current_dir.parent
21
+ sys.path.insert(0, str(parent_dir))
22
+
23
+ os.environ['OPENAI_API_KEY'] = "dummy-key"
24
+
25
+ # Import from vine_hf package
26
+ from vine_hf import VineConfig, VineModel, VinePipeline
27
+
28
+
29
+ def push_vine_to_video_fm(
30
+ source_repo_or_path: str = "KevinX-Penn28/testing",
31
+ target_repo: str = "video-fm/vine",
32
+ model_name: str = "openai/clip-vit-base-patch32",
33
+ commit_message: str = "Upload VINE model architecture and weights",
34
+ private: bool = False,
35
+ use_local_weights: bool = False,
36
+ ):
37
+ """
38
+ Push VINE model to video-fm organization on HuggingFace Hub.
39
+
40
+ Args:
41
+ source_repo_or_path: Source HF repo or local path with model weights
42
+ target_repo: Target repository (e.g., "video-fm/vine")
43
+ model_name: CLIP model backbone name
44
+ commit_message: Commit message for the push
45
+ private: Whether to create a private repository
46
+ use_local_weights: If True, source_repo_or_path is a local file path
47
+ """
48
+
49
+ print("=" * 70)
50
+ print("🚀 Pushing VINE Model to HuggingFace Hub - video-fm Organization")
51
+ print("=" * 70)
52
+
53
+ # 1. Create configuration
54
+ print(f"\n📝 Creating configuration with backbone: {model_name}")
55
+ config = VineConfig(
56
+ model_name=model_name,
57
+ segmentation_method="grounding_dino_sam2",
58
+ use_hf_repo=not use_local_weights,
59
+ model_repo=source_repo_or_path if not use_local_weights else None,
60
+ local_dir=str(Path(source_repo_or_path).parent) if use_local_weights else None,
61
+ local_filename=Path(source_repo_or_path).name if use_local_weights else None,
62
+ )
63
+
64
+ # 2. Initialize model (will automatically load weights from source)
65
+ print(f"\n🔧 Initializing model and loading weights from: {source_repo_or_path}")
66
+ model = VineModel(config)
67
+ print("✓ Model initialized with weights loaded")
68
+
69
+ # 3. Register for auto classes
70
+ print("\n📋 Registering for auto classes...")
71
+ config.register_for_auto_class()
72
+ model.register_for_auto_class("AutoModel")
73
+ print("✓ Registered for AutoModel and AutoConfig")
74
+
75
+ # 4. Register pipeline
76
+ print("\n🔌 Registering custom pipeline...")
77
+ try:
78
+ PIPELINE_REGISTRY.register_pipeline(
79
+ "vine-video-understanding",
80
+ pipeline_class=VinePipeline,
81
+ pt_model=VineModel,
82
+ type="multimodal",
83
+ )
84
+ print("✓ Pipeline registered")
85
+ except Exception as e:
86
+ print(f"⚠ Pipeline registration: {e} (may already be registered)")
87
+
88
+ try:
89
+ # 5. Push configuration to hub
90
+ print(f"\n⬆️ Pushing configuration to {target_repo}...")
91
+ config.push_to_hub(
92
+ target_repo,
93
+ commit_message=f"{commit_message} - config",
94
+ private=private
95
+ )
96
+ print("✓ Configuration pushed successfully")
97
+
98
+ # 6. Push model to hub
99
+ print(f"\n⬆️ Pushing model to {target_repo}...")
100
+ model.push_to_hub(
101
+ target_repo,
102
+ commit_message=f"{commit_message} - model and weights",
103
+ private=private
104
+ )
105
+ print("✓ Model and weights pushed successfully")
106
+
107
+ # 7. Copy additional necessary files to the repo
108
+ print(f"\n📦 Uploading additional architecture files...")
109
+ api = HfApi()
110
+
111
+ # Upload flattening.py and vis_utils.py as they're imported by the model
112
+ current_dir = Path(__file__).parent
113
+ additional_files = [
114
+ "flattening.py",
115
+ "vis_utils.py",
116
+ ]
117
+
118
+ for filename in additional_files:
119
+ file_path = current_dir / filename
120
+ if file_path.exists():
121
+ api.upload_file(
122
+ path_or_fileobj=str(file_path),
123
+ path_in_repo=filename,
124
+ repo_id=target_repo,
125
+ commit_message=f"Add {filename}",
126
+ )
127
+ print(f"✓ Uploaded {filename}")
128
+ else:
129
+ print(f"⚠ Warning: {filename} not found at {file_path}")
130
+
131
+ # 8. Upload README if it exists
132
+ readme_path = current_dir / "README.md"
133
+ if readme_path.exists():
134
+ api.upload_file(
135
+ path_or_fileobj=str(readme_path),
136
+ path_in_repo="README.md",
137
+ repo_id=target_repo,
138
+ commit_message="Add README documentation",
139
+ )
140
+ print("✓ Uploaded README.md")
141
+
142
+ print("\n" + "=" * 70)
143
+ print("🎉 Successfully pushed VINE model to HuggingFace Hub!")
144
+ print("=" * 70)
145
+ print(f"\n📍 Model URL: https://huggingface.co/{target_repo}")
146
+ print(f"\n📚 To use your model:")
147
+ print(f"""
148
+ ```python
149
+ from transformers import AutoModel, AutoConfig
150
+ from vine_hf import VineConfig, VineModel, VinePipeline
151
+
152
+ # Option 1: Load with AutoModel
153
+ model = AutoModel.from_pretrained('{target_repo}', trust_remote_code=True)
154
+
155
+ # Option 2: Load with VineModel directly
156
+ config = VineConfig.from_pretrained('{target_repo}')
157
+ model = VineModel.from_pretrained('{target_repo}')
158
+
159
+ # Option 3: Use with pipeline
160
+ from transformers import pipeline
161
+
162
+ vine_pipeline = pipeline(
163
+ 'vine-video-understanding',
164
+ model='{target_repo}',
165
+ trust_remote_code=True
166
+ )
167
+
168
+ results = vine_pipeline(
169
+ 'path/to/video.mp4',
170
+ categorical_keywords=['human', 'dog', 'frisbee'],
171
+ unary_keywords=['running', 'jumping'],
172
+ binary_keywords=['chasing', 'behind']
173
+ )
174
+ ```
175
+ """)
176
+
177
+ return True
178
+
179
+ except Exception as e:
180
+ print(f"\n❌ Error pushing to hub: {e}")
181
+ import traceback
182
+ traceback.print_exc()
183
+ print("\nPlease check:")
184
+ print(" - HuggingFace credentials (run: huggingface-cli login)")
185
+ print(" - Repository permissions for video-fm organization")
186
+ print(" - Network connectivity")
187
+ return False
188
+
189
+
190
+ def main():
191
+ parser = argparse.ArgumentParser(
192
+ description="Push VINE model to video-fm organization on HuggingFace Hub"
193
+ )
194
+
195
+ parser.add_argument(
196
+ "--source",
197
+ type=str,
198
+ default="KevinX-Penn28/testing",
199
+ help="Source HF repo or local path with model weights (default: KevinX-Penn28/testing)"
200
+ )
201
+
202
+ parser.add_argument(
203
+ "--target",
204
+ type=str,
205
+ default="video-fm/vine",
206
+ help="Target repository in video-fm org (default: video-fm/vine)"
207
+ )
208
+
209
+ parser.add_argument(
210
+ "--model-name",
211
+ type=str,
212
+ default="openai/clip-vit-base-patch32",
213
+ help="CLIP model backbone name"
214
+ )
215
+
216
+ parser.add_argument(
217
+ "--message",
218
+ type=str,
219
+ default="Upload VINE model architecture and weights",
220
+ help="Commit message"
221
+ )
222
+
223
+ parser.add_argument(
224
+ "--private",
225
+ action="store_true",
226
+ help="Create private repository"
227
+ )
228
+
229
+ parser.add_argument(
230
+ "--local-weights",
231
+ action="store_true",
232
+ help="Use local weights file instead of HF repo"
233
+ )
234
+
235
+ args = parser.parse_args()
236
+
237
+ # Check login status
238
+ try:
239
+ api = HfApi()
240
+ user_info = api.whoami()
241
+ print(f"✓ Logged in as: {user_info['name']}")
242
+
243
+ # Check if user has access to video-fm org
244
+ orgs = [org['name'] for org in user_info.get('orgs', [])]
245
+ if 'video-fm' in orgs:
246
+ print(f"✓ Confirmed access to video-fm organization")
247
+ else:
248
+ print(f"⚠ Warning: You may not have access to video-fm organization")
249
+ print(f" Your organizations: {orgs}")
250
+ except Exception as e:
251
+ print(f"❌ Not logged in to HuggingFace. Please run: huggingface-cli login")
252
+ print(f" Or use: python -c 'from huggingface_hub import login; login()'")
253
+ sys.exit(1)
254
+
255
+ # Push model
256
+ success = push_vine_to_video_fm(
257
+ source_repo_or_path=args.source,
258
+ target_repo=args.target,
259
+ model_name=args.model_name,
260
+ commit_message=args.message,
261
+ private=args.private,
262
+ use_local_weights=args.local_weights,
263
+ )
264
+
265
+ if success:
266
+ print("\n✅ Successfully completed!")
267
+ sys.exit(0)
268
+ else:
269
+ print("\n❌ Push failed!")
270
+ sys.exit(1)
271
+
272
+
273
+ if __name__ == "__main__":
274
+ main()
vine_hf/setup.py ADDED
@@ -0,0 +1,73 @@
1
+ """
2
+ Setup script for VINE HuggingFace Interface
3
+ """
4
+
5
+ from setuptools import setup
6
+
7
+ with open("README.md", "r", encoding="utf-8") as fh:
8
+ long_description = fh.read()
9
+
10
+ setup(
11
+ name="vine-hf",
12
+ version="1.0.0",
13
+ author="LASER Team",
14
+ author_email="your-email@example.com",
15
+ description="HuggingFace interface for VINE (Video Understanding with Natural Language)",
16
+ long_description=long_description,
17
+ long_description_content_type="text/markdown",
18
+ url="https://github.com/your-username/vine-hf",
19
+ # Modules live in the repo root, so they are exposed both as top-level modules and via the vine_hf package below
20
+ py_modules=[
21
+ "vine_config",
22
+ "vine_model",
23
+ "vine_pipeline",
24
+ "vis_utils",
25
+ "flattening",
26
+ "convert_inference",
27
+ ],
28
+ # Also include __init__.py to make it a package
29
+ packages=["vine_hf"],
30
+ package_dir={"vine_hf": "."},
31
+ classifiers=[
32
+ "Development Status :: 4 - Beta",
33
+ "Intended Audience :: Developers",
34
+ "Intended Audience :: Science/Research",
35
+ "License :: OSI Approved :: MIT License",
36
+ "Operating System :: OS Independent",
37
+ "Programming Language :: Python :: 3",
38
+ "Programming Language :: Python :: 3.7",
39
+ "Programming Language :: Python :: 3.8",
40
+ "Programming Language :: Python :: 3.9",
41
+ "Programming Language :: Python :: 3.10",
42
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
43
+ "Topic :: Multimedia :: Video",
44
+ ],
45
+ python_requires=">=3.7",
46
+ install_requires=[
47
+ "torch>=1.9.0",
48
+ "torchvision>=0.10.0",
49
+ "transformers>=4.20.0",
50
+ "opencv-python>=4.5.0",
51
+ "pillow>=8.0.0",
52
+ "numpy>=1.20.0",
53
+ "huggingface-hub>=0.10.0",
54
+ "tqdm>=4.60.0",
55
+ ],
56
+ extras_require={
57
+ "dev": [
58
+ "pytest>=6.0",
59
+ "black>=22.0",
60
+ "flake8>=4.0",
61
+ "isort>=5.0",
62
+ ],
63
+ "segmentation": [
64
+ # Note: SAM2 and Grounding DINO need to be installed separately
65
+ # as they're not available on PyPI
66
+ ],
67
+ },
68
+ entry_points={
69
+ "console_scripts": [
70
+ "vine-push-to-hub=vine_hf.push_to_hub:main",
71
+ ],
72
+ },
73
+ )
vine_hf/vine_config.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from transformers import PretrainedConfig
3
+ from typing import List, Optional, Dict, Any, Tuple, Union
4
+ from pathlib import Path
5
+
6
+
7
+ class VineConfig(PretrainedConfig):
8
+ """
9
+ Configuration class for VINE (Video Understanding with Natural Language) model.
10
+ """
11
+
12
+ model_type = "vine"
13
+
14
+ def __init__(
15
+ self,
16
+ model_name: str = "openai/clip-vit-base-patch32",
17
+ hidden_dim: int = 768,
18
+ use_hf_repo: bool = True,
19
+ model_repo: Optional[str] = "KevinX-Penn28/testing",
20
+ model_file: Optional[str] = None,
21
+ local_dir: Optional[str] = str(Path(__file__).resolve().parent),
22
+ local_filename: Optional[str] = "laser_model_v1.pkl",
23
+ num_top_pairs: int = 18,
24
+ segmentation_method: str = "grounding_dino_sam2",
25
+ box_threshold: float = 0.35,
26
+ text_threshold: float = 0.25,
27
+ target_fps: int = 1,
28
+ alpha: float = 0.5,
29
+ white_alpha: float = 0.8,
30
+ topk_cate: int = 3,
31
+ multi_class: bool = False,
32
+ output_logit: bool = False,
33
+ use_pretrained_cate_weights: bool = False,
34
+ categorical_pool: str = "mean", # "mean" or "max"
35
+ max_video_length: int = 100,
36
+ bbox_min_dim: int = 1,
37
+ visualize: bool = False,
38
+ visualization_dir: Optional[str] = None,
39
+ return_flattened_segments: bool = False,
40
+ return_valid_pairs: bool = False,
41
+ interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
42
+ debug_visualizations: bool = False,
43
+ device: Optional[Union[str, int]] = None,
44
+ **kwargs: Any,
45
+ ):
46
+ self.model_name = model_name
47
+ self.use_hf_repo = use_hf_repo
48
+ if use_hf_repo:
49
+ self.model_repo = model_repo
50
+ self.model_file = model_file
51
+ self.local_dir = None
52
+ self.local_filename = None
53
+ else:
54
+ self.model_repo = None
55
+ self.model_file = None
56
+ self.local_dir = local_dir
57
+ self.local_filename = local_filename
58
+
59
+ self.hidden_dim = hidden_dim
60
+ self.num_top_pairs = num_top_pairs
61
+ self.segmentation_method = segmentation_method
62
+ self.box_threshold = box_threshold
63
+ self.text_threshold = text_threshold
64
+ self.target_fps = target_fps
65
+ self.alpha = alpha
66
+ self.white_alpha = white_alpha
67
+ self.topk_cate = topk_cate
68
+ self.multi_class = multi_class
69
+ self.output_logit = output_logit
70
+ self.use_pretrained_cate_weights = use_pretrained_cate_weights
71
+ self.categorical_pool = categorical_pool
72
+ self.max_video_length = max_video_length
73
+ self.bbox_min_dim = bbox_min_dim
74
+ self.visualize = visualize
75
+ self.visualization_dir = visualization_dir
76
+ self.return_flattened_segments = return_flattened_segments
77
+ self.return_valid_pairs = return_valid_pairs
78
+ self.interested_object_pairs = interested_object_pairs or []
79
+ self.debug_visualizations = debug_visualizations
80
+
81
+ if isinstance(device, int):
82
+ self._device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
83
+ else:
84
+ self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
85
+
86
+ super().__init__(**kwargs)
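A minimal sketch of the two weight-source modes the constructor above normalizes (the local path is a placeholder; the import path is an assumption):

from vine_hf.vine_config import VineConfig

# Hub mode: the local fields are cleared
hub_cfg = VineConfig(use_hf_repo=True, model_repo="video-fm/vine_v0")
assert hub_cfg.local_dir is None and hub_cfg.local_filename is None

# Local mode: the repo fields are cleared
local_cfg = VineConfig(
    use_hf_repo=False,
    local_dir="/path/to/checkpoints",
    local_filename="laser_model_v1.pt",
)
assert local_cfg.model_repo is None and local_cfg.model_file is None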
vine_hf/vine_model.py ADDED
@@ -0,0 +1,1001 @@
1
+ import os
2
+ import sys
3
+ from typing import Dict, List, Tuple, Optional, Any, Union
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from safetensors.torch import load_file
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as cp
12
+ from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor
13
+ from huggingface_hub import snapshot_download
14
+
15
+ from .vine_config import VineConfig
16
+ from laser.models import llava_clip_model_v3
17
+ sys.modules["llava_clip_model_v3"] = llava_clip_model_v3
18
+ from laser.models.model_utils import (
19
+ extract_single_object,
20
+ extract_object_subject,
21
+ crop_image_contain_bboxes,
22
+ segment_list,
23
+ )
24
+ from .flattening import (
25
+ extract_valid_object_pairs,
26
+ flatten_segments_for_batch,
27
+ )
28
+ from .vis_utils import save_mask_one_image
29
+
30
+
31
+ class VineModel(PreTrainedModel):
32
+ """
33
+ VINE (Video Understanding with Natural Language) Model.
34
+
35
+ Internally, the core CLIP/text/image/pair logic mirrors
36
+ llava_clip_model_v3.PredicateModel as closely as possible for a single video,
37
+ with a small extension to re-normalize categorical probs after pooling.
38
+ """
39
+
40
+ config_class = VineConfig
41
+
42
+ def __init__(self, config: VineConfig):
43
+ super().__init__(config)
44
+ self.config = config
45
+ self.visualize = getattr(config, "visualize", False)
46
+ self.visualization_dir = getattr(config, "visualization_dir", None)
47
+ self.debug_visualizations = getattr(config, "debug_visualizations", False)
48
+ self._device = getattr(config, "_device")
49
+
50
+ # CLIP components
51
+ self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
52
+ if self.clip_tokenizer.pad_token is None:
53
+ self.clip_tokenizer.pad_token = (
54
+ self.clip_tokenizer.unk_token
55
+ if self.clip_tokenizer.unk_token
56
+ else self.clip_tokenizer.eos_token
57
+ )
58
+
59
+ self.clip_processor = AutoProcessor.from_pretrained(config.model_name)
60
+ self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
61
+ self.clip_unary_model = AutoModel.from_pretrained(config.model_name)
62
+ self.clip_binary_model = AutoModel.from_pretrained(config.model_name)
63
+
64
+ # Load fine-tuned weights if available
65
+ if config.use_hf_repo:
66
+ self._load_huggingface_vine_weights(config.model_repo, config.model_file)
67
+ else:
68
+ self._load_local_pretrained_vine_weights(
69
+ config.local_dir, config.local_filename
70
+ )
71
+
72
+ # Optionally reset categorical model to base CLIP (ignore fine-tune)
73
+ if not getattr(config, "use_pretrained_cate_weights", True):
74
+ self.clip_cate_model = AutoModel.from_pretrained(config.model_name)
75
+ self.clip_cate_model.to(self._device)
76
+
77
+ self.to(self._device)
78
+
79
+ # ------------------------------------------------------------------ #
80
+ # Weight loading
81
+ # ------------------------------------------------------------------ #
82
+ def _load_huggingface_vine_weights(
83
+ self, model_repo: str, model_file: Optional[str] = None
84
+ ):
85
+ try:
86
+ print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
87
+ repo_path = snapshot_download(model_repo, revision=model_file or "main")
88
+ weights = load_file(os.path.join(repo_path, "model.safetensors"))
89
+ self.load_state_dict(weights, strict=False)
90
+ print("✓ Successfully loaded VINE weights from HuggingFace Hub")
91
+ return True
92
+ except Exception as e:
93
+ print(f"✗ Error loading VINE weights from HuggingFace Hub: {e}")
94
+ print("Using base CLIP models instead")
95
+ return False
96
+
97
+ def _load_local_pretrained_vine_weights(
98
+ self, local_dir: str, local_filename: Optional[str] = None, epoch: int = 0
99
+ ):
100
+ if local_dir is None and local_filename is None:
101
+ return False
102
+
103
+ full_path = (
104
+ os.path.join(local_dir, local_filename) if local_filename else local_dir
105
+ )
106
+
107
+ # .pkl – usually pickled PredicateModel
108
+ if isinstance(full_path, str) and full_path.endswith(".pkl"):
109
+ print(f"Loading VINE weights from: {full_path}")
110
+ loaded_vine_model = torch.load(
111
+ full_path, map_location=self._device, weights_only=False
112
+ )
113
+ print(f"Loaded state type: {type(loaded_vine_model)}")
114
+
115
+ if not isinstance(loaded_vine_model, dict):
116
+ if hasattr(loaded_vine_model, "clip_tokenizer"):
117
+ self.clip_tokenizer = loaded_vine_model.clip_tokenizer
118
+ if hasattr(loaded_vine_model, "clip_processor"):
119
+ self.clip_processor = loaded_vine_model.clip_processor
120
+
121
+ if hasattr(loaded_vine_model, "clip_cate_model"):
122
+ self.clip_cate_model.load_state_dict(
123
+ loaded_vine_model.clip_cate_model.state_dict()
124
+ )
125
+ if hasattr(loaded_vine_model, "clip_unary_model"):
126
+ self.clip_unary_model.load_state_dict(
127
+ loaded_vine_model.clip_unary_model.state_dict()
128
+ )
129
+ if hasattr(loaded_vine_model, "clip_binary_model"):
130
+ self.clip_binary_model.load_state_dict(
131
+ loaded_vine_model.clip_binary_model.state_dict()
132
+ )
133
+ print("✓ Loaded VINE weights from .pkl PredicateModel checkpoint")
134
+ return True
135
+
136
+ # .pt / .pth – plain state_dict
137
+ elif isinstance(full_path, str) and (
138
+ full_path.endswith(".pt") or full_path.endswith(".pth")
139
+ ):
140
+ print(f"Loading VINE weights from: {full_path}")
141
+ state = torch.load(full_path, map_location=self._device, weights_only=True)
142
+ print(f"Loaded state type: {type(state)}")
143
+ self.load_state_dict(state, strict=False)
144
+ print("✓ Loaded VINE weights from state_dict")
145
+ return True
146
+
147
+ # .model – full PredicateModel object
148
+ elif isinstance(full_path, str) and full_path.endswith(".model"):
149
+ print(f"Loading VINE weights from: {full_path}")
150
+ pretrained_model = torch.load(
151
+ full_path, map_location="cpu", weights_only=False
152
+ )
153
+
154
+ if hasattr(pretrained_model, "clip_tokenizer"):
155
+ self.clip_tokenizer = pretrained_model.clip_tokenizer
156
+ if hasattr(pretrained_model, "clip_processor"):
157
+ self.clip_processor = pretrained_model.clip_processor
158
+
159
+ if hasattr(pretrained_model, "clip_cate_model"):
160
+ self.clip_cate_model.load_state_dict(
161
+ pretrained_model.clip_cate_model.state_dict()
162
+ )
163
+ if hasattr(pretrained_model, "clip_unary_model"):
164
+ self.clip_unary_model.load_state_dict(
165
+ pretrained_model.clip_unary_model.state_dict()
166
+ )
167
+ if hasattr(pretrained_model, "clip_binary_model"):
168
+ self.clip_binary_model.load_state_dict(
169
+ pretrained_model.clip_binary_model.state_dict()
170
+ )
171
+ print("✓ Loaded all sub-model weights from .model file")
172
+ return True
173
+
174
+ # directory of .model files
175
+ if isinstance(full_path, str) and os.path.isdir(full_path):
176
+ model_files = [
177
+ f for f in os.listdir(full_path) if f.endswith(f".{epoch}.model")
178
+ ]
179
+ if model_files:
180
+ model_file = os.path.join(full_path, model_files[0])
181
+ print(f"Loading VINE weights from: {model_file}")
182
+ pretrained_model = torch.load(model_file, map_location="cpu")
183
+
184
+ if hasattr(pretrained_model, "clip_tokenizer"):
185
+ self.clip_tokenizer = pretrained_model.clip_tokenizer
186
+ if hasattr(pretrained_model, "clip_processor"):
187
+ self.clip_processor = pretrained_model.clip_processor
188
+
189
+ if hasattr(pretrained_model, "clip_cate_model"):
190
+ self.clip_cate_model.load_state_dict(
191
+ pretrained_model.clip_cate_model.state_dict()
192
+ )
193
+ if hasattr(pretrained_model, "clip_unary_model"):
194
+ self.clip_unary_model.load_state_dict(
195
+ pretrained_model.clip_unary_model.state_dict()
196
+ )
197
+ if hasattr(pretrained_model, "clip_binary_model"):
198
+ self.clip_binary_model.load_state_dict(
199
+ pretrained_model.clip_binary_model.state_dict()
200
+ )
201
+ print("✓ Loaded all sub-model weights from ensemble format")
202
+ return True
203
+ else:
204
+ print(f"No model file found for epoch {epoch} in {full_path}")
205
+ return False
206
+
207
+ print("Unsupported format for pretrained VINE path:", full_path)
208
+ return False
209
+
210
+ @classmethod
211
+ def from_pretrained_vine(
212
+ cls,
213
+ model_path: str,
214
+ config: Optional[VineConfig] = None,
215
+ epoch: int = 0,
216
+ **kwargs: Any,
217
+ ):
218
+ if config is None:
219
+ if model_path and ("/" in model_path and not os.path.exists(model_path)):
220
+ config = VineConfig(use_hf_repo=True, model_repo=model_path)
221
+ else:
222
+ if os.path.isdir(model_path):
223
+ config = VineConfig(use_hf_repo=False, local_dir=model_path)
224
+ else:
225
+ config = VineConfig(
226
+ use_hf_repo=False,
227
+ local_dir=os.path.dirname(model_path) or None,
228
+ local_filename=os.path.basename(model_path) or None,
229
+ )
230
+ else:
231
+ if model_path and ("/" in model_path and not os.path.exists(model_path)):
232
+ config.use_hf_repo = True
233
+ config.model_repo = model_path
234
+ config.model_file = None
235
+ config.local_dir = None
236
+ config.local_filename = None
237
+ else:
238
+ config.use_hf_repo = False
239
+ if os.path.isdir(model_path):
240
+ config.local_dir = model_path
241
+ config.local_filename = None
242
+ else:
243
+ config.local_dir = os.path.dirname(model_path) or None
244
+ config.local_filename = os.path.basename(model_path) or None
245
+
246
+ model = cls(config, **kwargs)
247
+ return model
248
+
249
+ # ------------------------------------------------------------------ #
250
+ # Gradient checkpoint helpers
251
+ # ------------------------------------------------------------------ #
252
+ def _text_features_checkpoint(self, model, token_dict):
253
+ input_ids = token_dict["input_ids"]
254
+ attention_mask = token_dict["attention_mask"]
255
+ token_type_ids = token_dict.get("token_type_ids", None)
256
+
257
+ if token_type_ids is not None:
258
+
259
+ def forward_pass(input_ids, attention_mask, token_type_ids):
260
+ return model.get_text_features(
261
+ input_ids=input_ids,
262
+ attention_mask=attention_mask,
263
+ token_type_ids=token_type_ids,
264
+ )
265
+
266
+ return cp.checkpoint(
267
+ forward_pass,
268
+ input_ids,
269
+ attention_mask,
270
+ token_type_ids,
271
+ use_reentrant=False,
272
+ )
273
+ else:
274
+
275
+ def forward_pass(input_ids, attention_mask):
276
+ return model.get_text_features(
277
+ input_ids=input_ids,
278
+ attention_mask=attention_mask,
279
+ )
280
+
281
+ return cp.checkpoint(
282
+ forward_pass, input_ids, attention_mask, use_reentrant=False
283
+ )
284
+
285
+ def _image_features_checkpoint(self, model, pixel_values):
286
+ def forward_pass(pixel_values):
287
+ return model.get_image_features(pixel_values=pixel_values)
288
+
289
+ return cp.checkpoint(forward_pass, pixel_values, use_reentrant=False)
290
+
291
+ # ------------------------------------------------------------------ #
292
+ # CLIP similarity
293
+ # ------------------------------------------------------------------ #
294
+ def clip_sim(self, model, nl_feat, img_feat):
295
+ img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
296
+ nl_feat = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
297
+
298
+ logit_scale = getattr(model, "logit_scale", None)
299
+ logits_per_text = torch.matmul(nl_feat, img_feat.t())
300
+ if logit_scale is not None:
301
+ logits_per_text = logits_per_text * logit_scale.exp()
302
+ return logits_per_text
303
+
304
+ # ------------------------------------------------------------------ #
305
+ # Forward: single-video PredicateModel-style logic
306
+ # ------------------------------------------------------------------ #
307
+ def forward(
308
+ self,
309
+ video_frames: torch.Tensor,
310
+ masks: Dict[int, Dict[int, torch.Tensor]],
311
+ bboxes: Dict[int, Dict[int, List]],
312
+ categorical_keywords: List[str],
313
+ unary_keywords: Optional[List[str]] = None,
314
+ binary_keywords: Optional[List[str]] = None,
315
+ object_pairs: Optional[List[Tuple[int, int]]] = None,
316
+ return_flattened_segments: Optional[bool] = None,
317
+ return_valid_pairs: Optional[bool] = None,
318
+ interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
319
+ debug_visualizations: Optional[bool] = None,
320
+ **kwargs: Any,
321
+ ) -> Dict[str, Any]:
322
+ if unary_keywords is None:
323
+ unary_keywords = []
324
+ if binary_keywords is None:
325
+ binary_keywords = []
326
+ if object_pairs is None:
327
+ object_pairs = []
328
+
329
+ if return_flattened_segments is None:
330
+ return_flattened_segments = getattr(
331
+ self.config, "return_flattened_segments", False
332
+ )
333
+ if return_valid_pairs is None:
334
+ return_valid_pairs = getattr(self.config, "return_valid_pairs", False)
335
+ if interested_object_pairs is None or len(interested_object_pairs) == 0:
336
+ interested_object_pairs = (
337
+ getattr(self.config, "interested_object_pairs", []) or []
338
+ )
339
+ if debug_visualizations is None:
340
+ debug_visualizations = self.debug_visualizations
341
+
342
+ alpha = getattr(self.config, "alpha", 0.5)
343
+ white_alpha = getattr(self.config, "white_alpha", 0.8)
344
+ topk_cate = kwargs.pop("topk_cate", getattr(self.config, "topk_cate", 3))
345
+ dummy_str = kwargs.pop("dummy_str", getattr(self.config, "dummy_str", "$$$"))
346
+ multi_class = kwargs.pop("multi_class", getattr(self.config, "multi_class", False))
347
+ output_logit = kwargs.pop("output_logit", getattr(self.config, "output_logit", False))
348
+ output_embeddings = kwargs.pop("output_embeddings", False)
349
+
350
+ batched_video_ids = [0]
351
+
352
+ if torch.is_tensor(video_frames):
353
+ num_frames = video_frames.shape[0]
354
+ batched_videos = [
355
+ self._frame_to_numpy(video_frames[fid]) for fid in range(num_frames)
356
+ ]
357
+ else:
358
+ num_frames = len(video_frames)
359
+ batched_videos = [
360
+ self._frame_to_numpy(video_frames[fid]) for fid in range(num_frames)
361
+ ]
362
+
363
+ batched_masks: List[np.ndarray] = []
364
+ batched_bboxes: List[List[float]] = []
365
+ batched_object_ids: List[Tuple[int, int, int]] = []
366
+
367
+ for frame_id, frame_masks in masks.items():
368
+ if frame_id >= num_frames:
369
+ continue
370
+ frame_boxes = bboxes.get(frame_id, {})
371
+ for obj_id, mask in frame_masks.items():
372
+ if obj_id not in frame_boxes:
373
+ continue
374
+ bbox = frame_boxes[obj_id]
375
+ batched_object_ids.append((0, frame_id, obj_id))
376
+ batched_masks.append(self._mask_to_numpy(mask))
377
+ batched_bboxes.append(bbox)
378
+
379
+ batched_names = [list(categorical_keywords)]
380
+ batched_unary_kws = [list(unary_keywords)]
381
+ batched_binary_kws = [list(binary_keywords)]
382
+
383
+ batched_obj_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
384
+ if object_pairs:
385
+ for frame_id, frame_masks in masks.items():
386
+ if frame_id >= num_frames:
387
+ continue
388
+ present_ids = set(frame_masks.keys())
389
+ for (from_oid, to_oid) in object_pairs:
390
+ if from_oid in present_ids and to_oid in present_ids:
391
+ batched_obj_pairs.append((0, frame_id, (from_oid, to_oid)))
392
+
393
+ batched_video_splits = [0]
394
+ batched_binary_predicates = [None]
395
+
396
+ def fill_empty(batched_kw):
397
+ new_batched = []
398
+ for kw_ls in batched_kw:
399
+ if len(kw_ls) == 0:
400
+ new_batched.append([dummy_str])
401
+ else:
402
+ new_batched.append(list(kw_ls))
403
+ return new_batched
404
+
405
+ batched_names = fill_empty(batched_names)
406
+ batched_unary_kws = fill_empty(batched_unary_kws)
407
+ batched_binary_kws = fill_empty(batched_binary_kws)
408
+
409
+ dummy_prob = torch.tensor(0.0, device=self._device)
410
+
411
+ batched_obj_name_features = []
412
+ batched_unary_nl_features = []
413
+ batched_binary_nl_features = []
414
+
415
+ batched_object_ids_lookup: Dict[int, List[Tuple[int, int]]] = {0: []}
416
+ batch_size = len(batched_video_ids)
417
+
418
+ # Step 1: text features
419
+ for object_names, unary_kws, binary_kws in zip(
420
+ batched_names, batched_unary_kws, batched_binary_kws
421
+ ):
422
+ if len(object_names) == 0:
423
+ batched_obj_name_features.append([])
424
+ else:
425
+ obj_tokens = self.clip_tokenizer(
426
+ object_names,
427
+ return_tensors="pt",
428
+ max_length=75,
429
+ truncation=True,
430
+ padding="max_length",
431
+ ).to(self._device)
432
+ obj_feats = self._text_features_checkpoint(
433
+ self.clip_cate_model, obj_tokens
434
+ )
435
+ batched_obj_name_features.append(obj_feats)
436
+
437
+ if len(unary_kws) == 0:
438
+ batched_unary_nl_features.append([])
439
+ else:
440
+ unary_tokens = self.clip_tokenizer(
441
+ list(unary_kws),
442
+ return_tensors="pt",
443
+ max_length=75,
444
+ truncation=True,
445
+ padding="max_length",
446
+ ).to(self._device)
447
+ unary_feats = self._text_features_checkpoint(
448
+ self.clip_unary_model, unary_tokens
449
+ )
450
+ batched_unary_nl_features.append(unary_feats)
451
+
452
+ if len(binary_kws) == 0:
453
+ batched_binary_nl_features.append([])
454
+ else:
455
+ binary_tokens = self.clip_tokenizer(
456
+ list(binary_kws),
457
+ return_tensors="pt",
458
+ max_length=75,
459
+ truncation=True,
460
+ padding="max_length",
461
+ ).to(self._device)
462
+ binary_feats = self._text_features_checkpoint(
463
+ self.clip_binary_model, binary_tokens
464
+ )
465
+ batched_binary_nl_features.append(binary_feats)
466
+
467
+ # Step 2: crop objects
468
+ batched_frame_masks: Dict[Tuple[int, int, int], np.ndarray] = {}
469
+ batched_frame_bboxes: Dict[Tuple[int, int, int], List[float]] = {}
470
+ batched_cropped_objs: Dict[int, List[np.ndarray]] = {
471
+ vid: [] for vid in range(batch_size)
472
+ }
473
+
474
+ assert len(batched_object_ids) > 0, f"No object bbox: {batched_video_ids}"
475
+
476
+ batched_video_splits = [0] + batched_video_splits
477
+
478
+ for (video_id, frame_id, obj_id), mask, bbox in zip(
479
+ batched_object_ids, batched_masks, batched_bboxes
480
+ ):
481
+ overall_frame_id = batched_video_splits[video_id] + frame_id
482
+ object_img = extract_single_object(
483
+ batched_videos[overall_frame_id], mask, white_alpha
484
+ )
485
+ cropped_object_img = crop_image_contain_bboxes(
486
+ object_img, [bbox], batched_video_ids
487
+ )
488
+
489
+ if self.visualization_dir:
490
+ debug_crop_dir = os.path.join(self.visualization_dir, "debug_crops")
491
+ os.makedirs(debug_crop_dir, exist_ok=True)
492
+ cv2.imwrite(
493
+ os.path.join(debug_crop_dir, f"frame_{frame_id}_obj_{obj_id}.jpg"),
494
+ cv2.cvtColor(cropped_object_img, cv2.COLOR_RGB2BGR),
495
+ )
496
+
497
+ batched_frame_masks[(video_id, frame_id, obj_id)] = mask
498
+ batched_frame_bboxes[(video_id, frame_id, obj_id)] = bbox
499
+ batched_object_ids_lookup[video_id].append((frame_id, obj_id))
500
+ batched_cropped_objs[video_id].append(cropped_object_img)
501
+
502
+ # Step 3: categorical + unary
503
+ batched_image_unary_probs: Dict[int, Dict] = {}
504
+ batched_image_cate_probs: Dict[int, Dict] = {}
505
+ batched_obj_cate_features: Dict[int, Any] = {}
506
+ batched_obj_unary_features: Dict[int, Any] = {}
507
+ batched_obj_per_cate: Dict[int, Dict[str, List[Tuple[torch.Tensor, int]]]] = {}
508
+
509
+ for vid in range(batch_size):
510
+ batched_image_unary_probs[vid] = {}
511
+ batched_image_cate_probs[vid] = {}
512
+ batched_obj_cate_features[vid] = {}
513
+ batched_obj_unary_features[vid] = {}
514
+ batched_obj_per_cate[vid] = {}
515
+
516
+ for vid_id, (
517
+ unary_nl_feats,
518
+ object_name_feats,
519
+ cate,
520
+ unary_pred,
521
+ binary_predicates,
522
+ ) in enumerate(
523
+ zip(
524
+ batched_unary_nl_features,
525
+ batched_obj_name_features,
526
+ batched_names,
527
+ batched_unary_kws,
528
+ batched_binary_predicates,
529
+ )
530
+ ):
531
+ cropped_objs = batched_cropped_objs[vid_id]
532
+
533
+ if len(cropped_objs) != 0:
534
+ inputs = self.clip_processor(
535
+ images=cropped_objs, return_tensors="pt"
536
+ ).to(self._device)
537
+ cate_obj_clip_features = self._image_features_checkpoint(
538
+ self.clip_cate_model, inputs["pixel_values"]
539
+ )
540
+ unary_obj_clip_features = self._image_features_checkpoint(
541
+ self.clip_unary_model, inputs["pixel_values"]
542
+ )
543
+ batched_obj_unary_features[vid_id] = unary_obj_clip_features
544
+ batched_obj_cate_features[vid_id] = cate_obj_clip_features
545
+ else:
546
+ batched_obj_cate_features[vid_id] = torch.tensor([])
547
+ batched_obj_unary_features[vid_id] = torch.tensor([])
548
+
549
+ object_ids = batched_object_ids_lookup[vid_id]
550
+
551
+ # Categorical logits
552
+ if (
553
+ len(object_name_feats) == 0
554
+ or len(object_ids) == 0
555
+ or len(cropped_objs) == 0
556
+ ):
557
+ cate_logits_per_text = torch.tensor([])
558
+ else:
559
+ cate_logits_per_text = self.clip_sim(
560
+ self.clip_cate_model, object_name_feats, cate_obj_clip_features
561
+ )
562
+ if not output_logit:
563
+ cate_logits_per_text = cate_logits_per_text.softmax(dim=0)
564
+
565
+ if not (
566
+ len(object_ids) == 0
567
+ or (
568
+ cate_logits_per_text.ndim == 2
569
+ and cate_logits_per_text.shape[1] == len(object_ids)
570
+ )
571
+ or len(object_name_feats) == 0
572
+ ):
573
+ print("Object cate shape mismatch here")
574
+
575
+ assert (
576
+ len(object_name_feats) == 0
577
+ or len(object_ids) == 0
578
+ or (
579
+ cate_logits_per_text.ndim == 2
580
+ and cate_logits_per_text.shape[1] == len(object_ids)
581
+ )
582
+ ), f"Mismatched object id and cate logic: {batched_video_ids}"
583
+
584
+ # Aggregate per object id across frames
585
+ cate_prob_per_obj: Dict[int, Dict[str, List[torch.Tensor]]] = {}
586
+ for cate_name, probs in zip(cate, cate_logits_per_text):
587
+ if cate_name == dummy_str:
588
+ dummy_prob += probs.sum()
589
+ else:
590
+ for prob, (fid, oid) in zip(probs, object_ids):
591
+ cate_prob_per_obj.setdefault(oid, {})
592
+ cate_prob_per_obj[oid].setdefault(cate_name, []).append(prob)
593
+
594
+ new_cate_prob_per_obj: Dict[Tuple[int, str], torch.Tensor] = {}
595
+ obj_per_cate: Dict[str, List[Tuple[torch.Tensor, int]]] = {}
596
+
597
+ for oid, object_cate_info in cate_prob_per_obj.items():
598
+ # Pool across frames per category
599
+ pooled: Dict[str, torch.Tensor] = {}
600
+ for cate_name, prob_list in object_cate_info.items():
601
+ stacked = torch.stack(prob_list)
602
+ if getattr(self.config, "categorical_pool", "mean") == "mean":
603
+ pooled_prob = stacked.mean()
604
+ else:
605
+ pooled_prob = stacked.max()
606
+ pooled[cate_name] = pooled_prob
607
+
608
+ if not pooled:
609
+ continue
610
+
611
+ # Renormalize across categories so they sum to 1 per object
612
+ probs_tensor = torch.stack(list(pooled.values()))
613
+ denom = probs_tensor.sum()
614
+ if denom.item() <= 0:
615
+ norm_tensor = torch.ones_like(probs_tensor) / len(pooled)
616
+ else:
617
+ norm_tensor = probs_tensor / denom
618
+
619
+ for (cate_name, _), norm_prob in zip(pooled.items(), norm_tensor):
620
+ obj_per_cate.setdefault(cate_name, []).append((norm_prob, oid))
621
+ new_cate_prob_per_obj[(oid, cate_name)] = norm_prob
622
+
623
+ for cate_name in obj_per_cate:
624
+ obj_per_cate[cate_name] = sorted(
625
+ obj_per_cate[cate_name], key=lambda x: x[0], reverse=True
626
+ )
627
+
628
+ # Unary
629
+ if len(unary_nl_feats) == 0 or len(cropped_objs) == 0:
630
+ unary_logits_per_text = torch.tensor([])
631
+ else:
632
+ unary_logits_per_text = self.clip_sim(
633
+ self.clip_unary_model, unary_nl_feats, unary_obj_clip_features
634
+ )
635
+ if not output_logit:
636
+ unary_logits_per_text = unary_logits_per_text.softmax(dim=0)
637
+
638
+ unary_prob_per_obj: Dict[Tuple[int, int, str], torch.Tensor] = {}
639
+ for unary_name, probs in zip(unary_pred, unary_logits_per_text):
640
+ if unary_name == dummy_str:
641
+ dummy_prob += probs.sum()
642
+ else:
643
+ for prob, (fid, oid) in zip(probs, object_ids):
644
+ unary_prob_per_obj[(fid, oid, unary_name)] = prob
645
+
646
+ batched_image_cate_probs[vid_id] = new_cate_prob_per_obj
647
+ batched_image_unary_probs[vid_id] = unary_prob_per_obj
648
+ batched_obj_per_cate[vid_id] = obj_per_cate
649
+
650
+ # Step 4: binary pairs
651
+ batched_cropped_obj_pairs: Dict[int, List[np.ndarray]] = {}
652
+ frame_splits: Dict[Tuple[int, int], Dict[str, int]] = {}
653
+ current_info = (0, 0)
654
+ frame_splits[current_info] = {"start": 0}
655
+
656
+ batched_topk_cate_candidates: Dict[int, Dict[str, List[int]]] = {
657
+ video_id: {} for video_id in range(batch_size)
658
+ }
659
+ for video_id, obj_per_cate in batched_obj_per_cate.items():
660
+ topk_cate_candidates: Dict[str, List[int]] = {}
661
+ for cate_name, pred_oid_ls in obj_per_cate.items():
662
+ for _, oid in pred_oid_ls[:topk_cate]:
663
+ topk_cate_candidates.setdefault(cate_name, []).append(oid)
664
+ batched_topk_cate_candidates[video_id] = topk_cate_candidates
665
+
666
+ obj_pair_lookup: Dict[int, Dict[Tuple[int, int], List[int]]] = {
667
+ video_id: {} for video_id in range(len(batched_video_ids))
668
+ }
669
+ for (vid, fid, (from_oid, to_oid)) in batched_obj_pairs:
670
+ if (from_oid, to_oid) not in obj_pair_lookup[vid]:
671
+ obj_pair_lookup[vid][(from_oid, to_oid)] = []
672
+ obj_pair_lookup[vid][(from_oid, to_oid)].append(fid)
673
+
674
+ selected_pairs = set()
675
+ if batched_binary_predicates[0] is None:
676
+ selected_pairs = set(batched_obj_pairs)
677
+ else:
678
+ for bp_vid, binary_predicates in enumerate(batched_binary_predicates):
679
+ topk_cate_candidates = batched_topk_cate_candidates[bp_vid]
680
+ for (rel_name, from_obj_name, to_obj_name) in binary_predicates:
681
+ if (
682
+ from_obj_name in topk_cate_candidates
683
+ and to_obj_name in topk_cate_candidates
684
+ ):
685
+ from_oids = topk_cate_candidates[from_obj_name]
686
+ to_oids = topk_cate_candidates[to_obj_name]
687
+ for from_oid in from_oids:
688
+ for to_oid in to_oids:
689
+ if (
690
+ bp_vid in obj_pair_lookup
691
+ and (from_oid, to_oid) in obj_pair_lookup[bp_vid]
692
+ ):
693
+ for fid in obj_pair_lookup[bp_vid][
694
+ (from_oid, to_oid)
695
+ ]:
696
+ selected_pairs.add(
697
+ (bp_vid, fid, (from_oid, to_oid))
698
+ )
699
+
700
+ selected_pairs = list(selected_pairs)
701
+
702
+ new_select_pairs: Dict[int, List[Tuple[int, int, Tuple[int, int]]]] = {
703
+ video_id: [] for video_id in range(len(batched_video_ids))
704
+ }
705
+ for (vid, fid, (from_oid, to_oid)) in selected_pairs:
706
+ new_select_pairs[vid].append((vid, fid, (from_oid, to_oid)))
707
+
708
+ for vid in range(len(batched_video_ids)):
709
+ batched_cropped_obj_pairs[vid] = []
710
+
711
+ for (vid, fid, (from_id, to_id)) in selected_pairs:
712
+ if (vid, fid, from_id) not in batched_frame_masks or (
713
+ vid,
714
+ fid,
715
+ to_id,
716
+ ) not in batched_frame_masks:
717
+ continue
718
+ if (vid, fid, from_id) not in batched_frame_bboxes or (
719
+ vid,
720
+ fid,
721
+ to_id,
722
+ ) not in batched_frame_bboxes:
723
+ continue
724
+
725
+ overall_frame_id = batched_video_splits[vid] + fid
726
+ mask1 = batched_frame_masks[(vid, fid, from_id)]
727
+ mask2 = batched_frame_masks[(vid, fid, to_id)]
728
+ bbox1 = batched_frame_bboxes[(vid, fid, from_id)]
729
+ bbox2 = batched_frame_bboxes[(vid, fid, to_id)]
730
+ bb_pop_image = extract_object_subject(
731
+ batched_videos[overall_frame_id],
732
+ mask1,
733
+ mask2,
734
+ alpha=alpha,
735
+ white_alpha=white_alpha,
736
+ )
737
+ cropped_bb_pop_image = crop_image_contain_bboxes(
738
+ img=bb_pop_image,
739
+ bbox_ls=[bbox1, bbox2],
740
+ data_id=batched_video_ids,
741
+ )
742
+ batched_cropped_obj_pairs[vid].append(cropped_bb_pop_image)
743
+
744
+ if len(selected_pairs) == 0:
745
+ selected_pairs.append((0, -1, (-1, -1)))
746
+ new_select_pairs[0] = [(0, -1, (-1, -1))]
747
+ dummy_img = batched_videos[0]
748
+ batched_cropped_obj_pairs[0] = [dummy_img]
749
+
750
+ batched_image_binary_probs: List[
751
+ Dict[Tuple[int, Tuple[int, int], str], torch.Tensor]
752
+ ] = []
753
+ batched_obj_pair_features: Dict[int, torch.Tensor] = {
754
+ vid: torch.tensor([]) for vid in range(batch_size)
755
+ }
756
+
757
+ if len(batched_cropped_obj_pairs) == 0:
758
+ batched_image_binary_probs.append({})
759
+ else:
760
+ for vid, binary_nl_features in enumerate(batched_binary_nl_features):
761
+ if len(binary_nl_features) == 0:
762
+ batched_image_binary_probs.append({})
763
+ continue
764
+
765
+ binary_kws = batched_binary_kws[vid]
766
+ cropped_obj_pairs = batched_cropped_obj_pairs[vid]
767
+ if len(cropped_obj_pairs) == 0:
768
+ batched_image_binary_probs.append({})
769
+ continue
770
+
771
+ inputs = self.clip_processor(
772
+ images=cropped_obj_pairs, return_tensors="pt"
773
+ ).to(self._device)
774
+ obj_features = self._image_features_checkpoint(
775
+ self.clip_binary_model, inputs["pixel_values"]
776
+ )
777
+ batched_obj_pair_features[vid] = obj_features
778
+
779
+ obj_clip_features = obj_features / obj_features.norm(
780
+ p=2, dim=-1, keepdim=True
781
+ )
782
+ binary_nl_features = binary_nl_features / binary_nl_features.norm(
783
+ p=2, dim=-1, keepdim=True
784
+ )
785
+
786
+ logit_scale = self.clip_binary_model.logit_scale
787
+ binary_logits_per_text = torch.matmul(
788
+ binary_nl_features, obj_clip_features.t()
789
+ ) * logit_scale.exp()
790
+
791
+ if not output_logit:
792
+ if not multi_class:
793
+ binary_logits_per_text = binary_logits_per_text.softmax(dim=0)
794
+ else:
795
+ binary_logits_per_text = binary_logits_per_text.sigmoid()
796
+
797
+ binary_prob_per_obj: Dict[
798
+ Tuple[int, Tuple[int, int], str], torch.Tensor
799
+ ] = {}
800
+ for binary_name, probs in zip(binary_kws, binary_logits_per_text):
801
+ if binary_name == dummy_str:
802
+ dummy_prob += probs.sum()
803
+ else:
804
+ for prob, (vid_, fid, obj_pair) in zip(
805
+ probs, new_select_pairs[vid]
806
+ ):
807
+ if fid == -1:
808
+ dummy_prob += prob
809
+ else:
810
+ binary_prob_per_obj[(fid, obj_pair, binary_name)] = prob
811
+ batched_image_binary_probs.append(binary_prob_per_obj)
812
+
813
+ result: Dict[str, Any] = {
814
+ "categorical_probs": batched_image_cate_probs,
815
+ "unary_probs": batched_image_unary_probs,
816
+ "binary_probs": batched_image_binary_probs,
817
+ "dummy_prob": dummy_prob,
818
+ }
819
+
820
+ if output_embeddings:
821
+ embeddings_dict = {
822
+ "cate_obj_clip_features": batched_obj_cate_features,
823
+ "cate_object_ids": batched_object_ids_lookup,
824
+ "unary_obj_clip_features": batched_obj_unary_features,
825
+ "unary_object_ids": batched_object_ids_lookup,
826
+ "binary_obj_pair_features": batched_obj_pair_features,
827
+ "binary_object_pairs": new_select_pairs,
828
+ }
829
+ result["embeddings"] = embeddings_dict
830
+
831
+ if return_flattened_segments or return_valid_pairs:
832
+ flattened = flatten_segments_for_batch(
833
+ video_id=0,
834
+ segments=masks,
835
+ bbox_min_dim=self.config.bbox_min_dim,
836
+ )
837
+ if return_flattened_segments:
838
+ result["flattened_segments"] = flattened
839
+ if return_valid_pairs:
840
+ interested_pairs = (
841
+ interested_object_pairs if interested_object_pairs else None
842
+ )
843
+ result["valid_pairs"] = extract_valid_object_pairs(
844
+ flattened["object_ids"],
845
+ interested_pairs,
846
+ )
847
+ if interested_pairs is None:
848
+ result["valid_pairs_metadata"] = {"pair_source": "all_pairs"}
849
+ else:
850
+ result["valid_pairs_metadata"] = {
851
+ "pair_source": "filtered",
852
+ "requested_pairs": interested_object_pairs,
853
+ }
854
+
855
+ return result
856
+
857
+ # ------------------------------------------------------------------ #
858
+ # Helpers
859
+ # ------------------------------------------------------------------ #
860
+ def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
861
+ if torch.is_tensor(frame):
862
+ frame_np = frame.detach().cpu().numpy()
863
+ else:
864
+ frame_np = np.asarray(frame)
865
+ return np.ascontiguousarray(frame_np)
866
+
867
+ def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
868
+ if torch.is_tensor(mask):
869
+ mask_np = mask.detach().cpu().numpy()
870
+ else:
871
+ mask_np = np.asarray(mask)
872
+
873
+ if mask_np.ndim == 3:
874
+ if mask_np.shape[0] == 1:
875
+ mask_np = mask_np.squeeze(0)
876
+ elif mask_np.shape[2] == 1:
877
+ mask_np = mask_np.squeeze(2)
878
+
879
+ if mask_np.ndim != 2:
880
+ raise ValueError(f"Mask must be 2D after squeezing, got shape {mask_np.shape}")
881
+
882
+ return mask_np.astype(bool, copy=False)
883
+
884
+ def _extract_text_features(self, model, keywords: List[str]):
885
+ tokens = self.clip_tokenizer(
886
+ keywords,
887
+ return_tensors="pt",
888
+ max_length=75,
889
+ truncation=True,
890
+ padding="max_length",
891
+ ).to(self._device)
892
+ return self._text_features_checkpoint(model, tokens)
893
+
894
+ def _extract_image_features(self, model, image):
895
+ if torch.is_tensor(image):
896
+ image = image.detach().cpu().numpy()
897
+ # numpy arrays and PIL images are passed to the CLIP processor unchanged
899
+
899
+
900
+ inputs = self.clip_processor(images=image, return_tensors="pt").to(self._device)
901
+ return self._image_features_checkpoint(model, inputs["pixel_values"])
902
+
903
+ # ------------------------------------------------------------------ #
904
+ # High-level predict API
905
+ # ------------------------------------------------------------------ #
906
+ def predict(
907
+ self,
908
+ video_frames: torch.Tensor,
909
+ masks: Dict[int, Dict[int, torch.Tensor]],
910
+ bboxes: Dict[int, Dict[int, List]],
911
+ categorical_keywords: List[str],
912
+ unary_keywords: Optional[List[str]] = None,
913
+ binary_keywords: Optional[List[str]] = None,
914
+ object_pairs: Optional[List[Tuple[int, int]]] = None,
915
+ return_top_k: int = 3,
916
+ return_flattened_segments: Optional[bool] = None,
917
+ return_valid_pairs: Optional[bool] = None,
918
+ interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
919
+ debug_visualizations: Optional[bool] = None,
920
+ ) -> Dict[str, Any]:
921
+ with torch.no_grad():
922
+ outputs = self.forward(
923
+ video_frames=video_frames,
924
+ masks=masks,
925
+ bboxes=bboxes,
926
+ categorical_keywords=categorical_keywords,
927
+ unary_keywords=unary_keywords,
928
+ binary_keywords=binary_keywords,
929
+ object_pairs=object_pairs,
930
+ return_flattened_segments=return_flattened_segments,
931
+ return_valid_pairs=return_valid_pairs,
932
+ interested_object_pairs=interested_object_pairs,
933
+ debug_visualizations=debug_visualizations,
934
+ )
935
+
936
+ formatted_categorical: Dict[int, List[Tuple[float, str]]] = {}
937
+ for (obj_id, category), prob in outputs["categorical_probs"][0].items():
938
+ if obj_id not in formatted_categorical:
939
+ formatted_categorical[obj_id] = []
940
+ prob_val = float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
941
+ formatted_categorical[obj_id].append((prob_val, category))
942
+
943
+ for obj_id in formatted_categorical:
944
+ formatted_categorical[obj_id] = sorted(
945
+ formatted_categorical[obj_id], reverse=True
946
+ )[:return_top_k]
947
+
948
+ formatted_unary: Dict[Tuple[int, int], List[Tuple[float, str]]] = {}
949
+ for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items():
950
+ key = (frame_id, obj_id)
951
+ if key not in formatted_unary:
952
+ formatted_unary[key] = []
953
+ prob_val = float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
954
+ formatted_unary[key].append((prob_val, predicate))
955
+
956
+ for key in formatted_unary:
957
+ formatted_unary[key] = sorted(
958
+ formatted_unary[key], reverse=True
959
+ )[:return_top_k]
960
+
961
+ formatted_binary: Dict[Tuple[int, Tuple[int, int]], List[Tuple[float, str]]] = {}
962
+ if len(outputs["binary_probs"]) > 0:
963
+ for (frame_id, obj_pair, predicate), prob in outputs["binary_probs"][0].items():
964
+ key = (frame_id, obj_pair)
965
+ if key not in formatted_binary:
966
+ formatted_binary[key] = []
967
+ prob_val = float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
968
+ formatted_binary[key].append((prob_val, predicate))
969
+
970
+ for key in formatted_binary:
971
+ formatted_binary[key] = sorted(
972
+ formatted_binary[key], reverse=True
973
+ )[:return_top_k]
974
+
975
+ def max_conf(d: Dict[Any, List[Tuple[float, str]]]) -> float:
976
+ if not d:
977
+ return 0.0
978
+ return max(
979
+ (max((p for p, _ in preds), default=0.0) for preds in d.values()),
980
+ default=0.0,
981
+ )
982
+
983
+ result: Dict[str, Any] = {
984
+ "categorical_predictions": formatted_categorical,
985
+ "unary_predictions": formatted_unary,
986
+ "binary_predictions": formatted_binary,
987
+ "confidence_scores": {
988
+ "categorical": max_conf(formatted_categorical),
989
+ "unary": max_conf(formatted_unary),
990
+ "binary": max_conf(formatted_binary),
991
+ },
992
+ }
993
+
994
+ if "flattened_segments" in outputs:
995
+ result["flattened_segments"] = outputs["flattened_segments"]
996
+ if "valid_pairs" in outputs:
997
+ result["valid_pairs"] = outputs["valid_pairs"]
998
+ if "valid_pairs_metadata" in outputs:
999
+ result["valid_pairs_metadata"] = outputs["valid_pairs_metadata"]
1000
+
1001
+ return result
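
The predict API above runs forward under torch.no_grad and reformats the raw probability dictionaries into top-k lists per object, per (frame, object), and per (frame, object-pair). A minimal usage sketch follows; it is an assumption-laden illustration only — the import path, checkpoint location, frame size, and keyword lists are hypothetical, while the call signature and output keys are taken from the code above.

# Sketch only: the checkpoint path and keyword lists are placeholders, not part of the repository.
import torch
from vine_hf.vine_model import VineModel  # assumed import path for the file above

model = VineModel.from_pretrained_vine("path/to/vine_checkpoint")  # hypothetical checkpoint

# A tiny 4-frame RGB clip with one segmented object in frame 0.
video = torch.randint(0, 256, (4, 240, 320, 3), dtype=torch.uint8)
masks = {0: {1: torch.ones(240, 320, dtype=torch.bool)}}   # {frame_id: {obj_id: mask}}
bboxes = {0: {1: [50, 60, 180, 200]}}                      # {frame_id: {obj_id: [x1, y1, x2, y2]}}

out = model.predict(
    video_frames=video,
    masks=masks,
    bboxes=bboxes,
    categorical_keywords=["person", "dog"],
    unary_keywords=["running", "sitting"],
    return_top_k=3,
)
print(out["categorical_predictions"])  # {obj_id: [(prob, label), ...]}
print(out["unary_predictions"])        # {(frame_id, obj_id): [(prob, predicate), ...]}
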
vine_hf/vine_pipeline.py ADDED
@@ -0,0 +1,923 @@
1
+ import os
2
+ import uuid
3
+ import hashlib
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Dict, List, Tuple, Optional, Any, Union
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ from transformers import Pipeline
12
+
13
+ from .vine_config import VineConfig
14
+ from .vine_model import VineModel
15
+ from .vis_utils import render_dino_frames, render_sam_frames, render_vine_frame_sets
16
+ from laser.loading import load_video
17
+ from laser.preprocess.mask_generation_grounding_dino import generate_masks_grounding_dino
18
+
19
+
20
+ class VinePipeline(Pipeline):
21
+ """
22
+ Pipeline for VINE model that handles end-to-end video understanding.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ sam_config_path: Optional[str] = None,
28
+ sam_checkpoint_path: Optional[str] = None,
29
+ gd_config_path: Optional[str] = None,
30
+ gd_checkpoint_path: Optional[str] = None,
31
+ **kwargs: Any,
32
+ ):
33
+ self.grounding_model = None
34
+ self.sam_predictor = None
35
+ self.mask_generator = None
36
+
37
+ self.sam_config_path = sam_config_path
38
+ self.sam_checkpoint_path = sam_checkpoint_path
39
+ self.gd_config_path = gd_config_path
40
+ self.gd_checkpoint_path = gd_checkpoint_path
41
+
42
+ super().__init__(**kwargs)
43
+
44
+ self.segmentation_method = getattr(
45
+ self.model.config, "segmentation_method", "grounding_dino_sam2"
46
+ )
47
+ self.box_threshold = getattr(self.model.config, "box_threshold", 0.35)
48
+ self.text_threshold = getattr(self.model.config, "text_threshold", 0.25)
49
+ self.target_fps = getattr(self.model.config, "target_fps", 1)
50
+ self.visualize = getattr(self.model.config, "visualize", False)
51
+ self.visualization_dir = getattr(self.model.config, "visualization_dir", None)
52
+ self.debug_visualizations = getattr(
53
+ self.model.config, "debug_visualizations", False
54
+ )
55
+ self._device = getattr(self.model.config, "_device")
56
+ if kwargs.get("device") is not None:
57
+ self._device = kwargs.get("device")
58
+
59
+ # ------------------------------------------------------------------ #
60
+ # Segmentation model injection
61
+ # ------------------------------------------------------------------ #
62
+ def set_segmentation_models(
63
+ self,
64
+ *,
65
+ sam_predictor=None,
66
+ mask_generator=None,
67
+ grounding_model=None,
68
+ ):
69
+ if sam_predictor is not None:
70
+ self.sam_predictor = sam_predictor
71
+ if mask_generator is not None:
72
+ self.mask_generator = mask_generator
73
+ if grounding_model is not None:
74
+ self.grounding_model = grounding_model
75
+
76
+ # ------------------------------------------------------------------ #
77
+ # Pipeline protocol
78
+ # ------------------------------------------------------------------ #
79
+ def _sanitize_parameters(self, **kwargs: Any):
80
+ preprocess_kwargs: Dict[str, Any] = {}
81
+ forward_kwargs: Dict[str, Any] = {}
82
+ postprocess_kwargs: Dict[str, Any] = {}
83
+
84
+ if "segmentation_method" in kwargs:
85
+ preprocess_kwargs["segmentation_method"] = kwargs["segmentation_method"]
86
+ if "target_fps" in kwargs:
87
+ preprocess_kwargs["target_fps"] = kwargs["target_fps"]
88
+ if "box_threshold" in kwargs:
89
+ preprocess_kwargs["box_threshold"] = kwargs["box_threshold"]
90
+ if "text_threshold" in kwargs:
91
+ preprocess_kwargs["text_threshold"] = kwargs["text_threshold"]
92
+ if "categorical_keywords" in kwargs:
93
+ preprocess_kwargs["categorical_keywords"] = kwargs["categorical_keywords"]
94
+
95
+ if "categorical_keywords" in kwargs:
96
+ forward_kwargs["categorical_keywords"] = kwargs["categorical_keywords"]
97
+ if "unary_keywords" in kwargs:
98
+ forward_kwargs["unary_keywords"] = kwargs["unary_keywords"]
99
+ if "binary_keywords" in kwargs:
100
+ forward_kwargs["binary_keywords"] = kwargs["binary_keywords"]
101
+ if "object_pairs" in kwargs:
102
+ forward_kwargs["object_pairs"] = kwargs["object_pairs"]
103
+ if "return_flattened_segments" in kwargs:
104
+ forward_kwargs["return_flattened_segments"] = kwargs[
105
+ "return_flattened_segments"
106
+ ]
107
+ if "return_valid_pairs" in kwargs:
108
+ forward_kwargs["return_valid_pairs"] = kwargs["return_valid_pairs"]
109
+ if "interested_object_pairs" in kwargs:
110
+ forward_kwargs["interested_object_pairs"] = kwargs[
111
+ "interested_object_pairs"
112
+ ]
113
+ if "debug_visualizations" in kwargs:
114
+ forward_kwargs["debug_visualizations"] = kwargs["debug_visualizations"]
115
+ postprocess_kwargs["debug_visualizations"] = kwargs["debug_visualizations"]
116
+
117
+ if "return_top_k" in kwargs:
118
+ postprocess_kwargs["return_top_k"] = kwargs["return_top_k"]
119
+ if "visualize" in kwargs:
120
+ postprocess_kwargs["visualize"] = kwargs["visualize"]
121
+
122
+ return preprocess_kwargs, forward_kwargs, postprocess_kwargs
123
+
124
+ # ------------------------------------------------------------------ #
125
+ # Preprocess: video + segmentation
126
+ # ------------------------------------------------------------------ #
127
+ def preprocess(
128
+ self,
129
+ video_input: Union[str, np.ndarray, torch.Tensor],
130
+ segmentation_method: Optional[str] = None,
131
+ target_fps: Optional[int] = None,
132
+ box_threshold: Optional[float] = None,
133
+ text_threshold: Optional[float] = None,
134
+ categorical_keywords: Optional[List[str]] = None,
135
+ **kwargs: Any,
136
+ ) -> Dict[str, Any]:
137
+ if segmentation_method is None:
138
+ segmentation_method = self.segmentation_method
139
+ if target_fps is None:
140
+ target_fps = self.target_fps
141
+ else:
142
+ self.target_fps = target_fps
143
+ if box_threshold is None:
144
+ box_threshold = self.box_threshold
145
+ else:
146
+ self.box_threshold = box_threshold
147
+ if text_threshold is None:
148
+ text_threshold = self.text_threshold
149
+ else:
150
+ self.text_threshold = text_threshold
151
+ if categorical_keywords is None:
152
+ categorical_keywords = ["object"]
153
+
154
+ if isinstance(video_input, str):
155
+ video_tensor = load_video(video_input, target_fps=target_fps)
156
+ if isinstance(video_tensor, list):
157
+ video_tensor = np.array(video_tensor)
158
+ elif isinstance(video_tensor, torch.Tensor):
159
+ video_tensor = video_tensor.cpu().numpy()
160
+ elif isinstance(video_input, (np.ndarray, torch.Tensor)):
161
+ if isinstance(video_input, torch.Tensor):
162
+ video_tensor = video_input.detach().cpu().numpy()
163
+ else:
164
+ video_tensor = video_input
165
+ else:
166
+ raise ValueError(f"Unsupported video input type: {type(video_input)}")
167
+
168
+ if not isinstance(video_tensor, np.ndarray):
169
+ video_tensor = np.array(video_tensor)
170
+
171
+ if len(video_tensor.shape) != 4:
172
+ raise ValueError(
173
+ f"Expected video tensor shape (frames, height, width, channels), got {video_tensor.shape}"
174
+ )
175
+
176
+ visualization_data: Dict[str, Any] = {}
177
+ print(f"Segmentation method: {segmentation_method}")
178
+ if segmentation_method == "sam2":
179
+ masks, bboxes, vis_data = self._generate_sam2_masks(video_tensor)
180
+ elif segmentation_method == "grounding_dino_sam2":
181
+ masks, bboxes, vis_data = self._generate_grounding_dino_sam2_masks(
182
+ video_tensor,
183
+ categorical_keywords,
184
+ box_threshold,
185
+ text_threshold,
186
+ video_input,
187
+ )
188
+ else:
189
+ raise ValueError(f"Unsupported segmentation method: {segmentation_method}")
190
+ if vis_data:
191
+ visualization_data.update(vis_data)
192
+ visualization_data.setdefault("sam_masks", masks)
193
+
194
+ return {
195
+ "video_frames": torch.tensor(video_tensor),
196
+ "masks": masks,
197
+ "bboxes": bboxes,
198
+ "num_frames": len(video_tensor),
199
+ "visualization_data": visualization_data,
200
+ }
201
+
202
+ # ------------------------------------------------------------------ #
203
+ # Segmentation helpers
204
+ # ------------------------------------------------------------------ #
205
+ def _generate_sam2_masks(
206
+ self, video_tensor: np.ndarray
207
+ ) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]], Dict[str, Any]]:
208
+ print("Generating SAM2 masks...")
209
+ if self.mask_generator is None:
210
+ self._initialize_segmentation_models()
211
+ if self.mask_generator is None:
212
+ raise ValueError("SAM2 mask generator not available")
213
+
214
+ masks: Dict[int, Dict[int, torch.Tensor]] = {}
215
+ bboxes: Dict[int, Dict[int, List[int]]] = {}
216
+
217
+ for frame_id, frame in enumerate(video_tensor):
218
+ if isinstance(frame, np.ndarray) and frame.dtype != np.uint8:
219
+ frame = (
220
+ (frame * 255).astype(np.uint8)
221
+ if frame.max() <= 1
222
+ else frame.astype(np.uint8)
223
+ )
224
+
225
+ frame_masks = self.mask_generator.generate(frame)
226
+
227
+ masks[frame_id] = {}
228
+ bboxes[frame_id] = {}
229
+
230
+ for obj_id, mask_data in enumerate(frame_masks):
231
+ mask = mask_data["segmentation"]
232
+ if isinstance(mask, np.ndarray):
233
+ mask = torch.from_numpy(mask)
234
+
235
+ if len(mask.shape) == 2:
236
+ mask = mask.unsqueeze(-1)
237
+ elif len(mask.shape) == 3 and mask.shape[0] == 1:
238
+ mask = mask.permute(1, 2, 0)
239
+
240
+ wrapped_id = obj_id + 1
241
+ masks[frame_id][wrapped_id] = mask
242
+
243
+ mask_np = (
244
+ mask.squeeze().numpy()
245
+ if isinstance(mask, torch.Tensor)
246
+ else mask.squeeze()
247
+ )
248
+
249
+ coords = np.where(mask_np > 0)
250
+ if len(coords[0]) > 0:
251
+ y1, y2 = coords[0].min(), coords[0].max()
252
+ x1, x2 = coords[1].min(), coords[1].max()
253
+ bboxes[frame_id][wrapped_id] = [x1, y1, x2, y2]
254
+
255
+ tracked_masks, tracked_bboxes = self._track_ids_across_frames(masks, bboxes)
256
+ return tracked_masks, tracked_bboxes, {"sam_masks": tracked_masks}
257
+
258
+ def _generate_grounding_dino_sam2_masks(
259
+ self,
260
+ video_tensor: np.ndarray,
261
+ categorical_keywords: List[str],
262
+ box_threshold: float,
263
+ text_threshold: float,
264
+ video_path: Union[str, None],
265
+ ) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]], Dict[str, Any]]:
266
+ print("Generating Grounding DINO + SAM2 masks...")
267
+ if self.grounding_model is None or self.sam_predictor is None:
268
+ self._initialize_segmentation_models()
269
+ if self.grounding_model is None or self.sam_predictor is None:
270
+ raise ValueError("GroundingDINO or SAM2 models not available")
271
+
272
+ temp_video_path = None
273
+ if video_path is None or not isinstance(video_path, str):
274
+ temp_video_path = self._create_temp_video(video_tensor)
275
+ video_path = temp_video_path
276
+
277
+ CHUNK = 5
278
+ classes_ls = [
279
+ categorical_keywords[i : i + CHUNK]
280
+ for i in range(0, len(categorical_keywords), CHUNK)
281
+ ]
282
+
283
+ base_name = Path(video_path).stem
284
+ fps_tag = f"fps{int(self.target_fps)}"
285
+ path_hash = hashlib.md5(video_path.encode("utf-8")).hexdigest()[:8]
286
+ video_cache_name = f"{base_name}_{fps_tag}_{path_hash}"
287
+
288
+ video_segments, oid_class_pred, _ = generate_masks_grounding_dino(
289
+ self.grounding_model,
290
+ box_threshold,
291
+ text_threshold,
292
+ self.sam_predictor,
293
+ self.mask_generator,
294
+ video_tensor,
295
+ video_path,
296
+ video_cache_name,
297
+ out_dir=tempfile.gettempdir(),
298
+ classes_ls=classes_ls,
299
+ target_fps=self.target_fps,
300
+ visualize=self.debug_visualizations,
301
+ frames=None,
302
+ max_prop_time=2,
303
+ )
304
+
305
+ masks: Dict[int, Dict[int, torch.Tensor]] = {}
306
+ bboxes: Dict[int, Dict[int, List[int]]] = {}
307
+
308
+ for frame_id, frame_masks in video_segments.items():
309
+ masks[frame_id] = {}
310
+ bboxes[frame_id] = {}
311
+
312
+ for obj_id, mask in frame_masks.items():
313
+ if not isinstance(mask, torch.Tensor):
314
+ mask = torch.tensor(mask)
315
+ masks[frame_id][obj_id] = mask
316
+ mask_np = mask.detach().cpu().numpy()
317
+ if mask_np.ndim == 3 and mask_np.shape[0] == 1:
318
+ mask_np = np.squeeze(mask_np, axis=0)
319
+
320
+ coords = np.where(mask_np > 0)
321
+ if len(coords[0]) > 0:
322
+ y1, y2 = coords[0].min(), coords[0].max()
323
+ x1, x2 = coords[1].min(), coords[1].max()
324
+ bboxes[frame_id][obj_id] = [x1, y1, x2, y2]
325
+
326
+ if temp_video_path and os.path.exists(temp_video_path):
327
+ os.remove(temp_video_path)
328
+
329
+ tracked_masks, tracked_bboxes = self._track_ids_across_frames(masks, bboxes)
330
+
331
+ vis_data: Dict[str, Any] = {
332
+ "sam_masks": tracked_masks,
333
+ "dino_labels": oid_class_pred,
334
+ }
335
+ return tracked_masks, tracked_bboxes, vis_data
336
+
337
+ # ------------------------------------------------------------------ #
338
+ # ID tracking across frames
339
+ # ------------------------------------------------------------------ #
340
+ def _bbox_iou(self, box1: List[int], box2: List[int]) -> float:
341
+ x1, y1, x2, y2 = box1
342
+ x1b, y1b, x2b, y2b = box2
343
+ ix1 = max(x1, x1b)
344
+ iy1 = max(y1, y1b)
345
+ ix2 = min(x2, x2b)
346
+ iy2 = min(y2, y2b)
347
+ iw = max(0, ix2 - ix1)
348
+ ih = max(0, iy2 - iy1)
349
+ inter = iw * ih
350
+ if inter <= 0:
351
+ return 0.0
352
+ area1 = max(0, x2 - x1) * max(0, y2 - y1)
353
+ area2 = max(0, x2b - x1b) * max(0, y2b - y1b)
354
+ union = area1 + area2 - inter
355
+ if union <= 0:
356
+ return 0.0
357
+ return inter / union
358
+
359
+ def _track_ids_across_frames(
360
+ self,
361
+ masks: Dict[int, Dict[int, torch.Tensor]],
362
+ bboxes: Dict[int, Dict[int, List[int]]],
363
+ iou_threshold: float = 0.3,
364
+ ) -> Tuple[Dict[int, Dict[int, torch.Tensor]], Dict[int, Dict[int, List[int]]]]:
365
+ frame_ids = sorted(masks.keys())
366
+ tracked_masks: Dict[int, Dict[int, torch.Tensor]] = {}
367
+ tracked_bboxes: Dict[int, Dict[int, List[int]]] = {}
368
+ next_track_id = 0
369
+ prev_tracks: Dict[int, List[int]] = {}
370
+
371
+ for frame_id in frame_ids:
372
+ frame_masks = masks.get(frame_id, {})
373
+ frame_boxes = bboxes.get(frame_id, {})
374
+ tracked_masks[frame_id] = {}
375
+ tracked_bboxes[frame_id] = {}
376
+
377
+ if not frame_boxes:
378
+ prev_tracks = {}
379
+ continue
380
+
381
+ det_ids = list(frame_boxes.keys())
382
+ prev_ids = list(prev_tracks.keys())
383
+
384
+ candidates: List[Tuple[float, int, int]] = []
385
+ for tid in prev_ids:
386
+ prev_box = prev_tracks[tid]
387
+ for det_id in det_ids:
388
+ iou = self._bbox_iou(prev_box, frame_boxes[det_id])
389
+ if iou > iou_threshold:
390
+ candidates.append((iou, tid, det_id))
391
+ candidates.sort(reverse=True)
392
+
393
+ matched_prev = set()
394
+ matched_det = set()
395
+
396
+ for iou, tid, det_id in candidates:
397
+ if tid in matched_prev or det_id in matched_det:
398
+ continue
399
+ matched_prev.add(tid)
400
+ matched_det.add(det_id)
401
+ tracked_masks[frame_id][tid] = frame_masks[det_id]
402
+ tracked_bboxes[frame_id][tid] = frame_boxes[det_id]
403
+
404
+ for det_id in det_ids:
405
+ if det_id in matched_det:
406
+ continue
407
+ tid = next_track_id
408
+ next_track_id += 1
409
+ tracked_masks[frame_id][tid] = frame_masks[det_id]
410
+ tracked_bboxes[frame_id][tid] = frame_boxes[det_id]
411
+
412
+ prev_tracks = {
413
+ tid: tracked_bboxes[frame_id][tid]
414
+ for tid in tracked_bboxes[frame_id].keys()
415
+ }
416
+
417
+ return tracked_masks, tracked_bboxes
418
+
419
+ # ------------------------------------------------------------------ #
420
+ # Segmentation model initialization
421
+ # ------------------------------------------------------------------ #
422
+ def _initialize_segmentation_models(self):
423
+ if self.sam_predictor is None or self.mask_generator is None:
424
+ self._initialize_sam2_models()
425
+ if self.grounding_model is None:
426
+ self._initialize_grounding_dino_model()
427
+
428
+ def _initialize_sam2_models(self):
429
+ try:
430
+ from sam2.build_sam import build_sam2_video_predictor, build_sam2
431
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
432
+ except ImportError as e:
433
+ print(f"Warning: Could not import SAM2: {e}")
434
+ return
435
+
436
+ config_path, checkpoint_path = self._resolve_sam2_paths()
437
+
438
+ if self.sam_config_path is not None and not os.path.exists(config_path):
439
+ raise ValueError(f"SAM2 config path not found: {config_path}")
440
+ if self.sam_checkpoint_path is not None and not os.path.exists(checkpoint_path):
441
+ raise ValueError(f"SAM2 checkpoint path not found: {checkpoint_path}")
442
+
443
+ if not (config_path and checkpoint_path and os.path.exists(checkpoint_path)):
444
+ print(f"Warning: SAM2 checkpoint not found at {checkpoint_path}")
445
+ print("SAM2 functionality will be unavailable")
446
+ return
447
+
448
+ try:
449
+ device = self._device
450
+ self.sam_predictor = build_sam2_video_predictor(
451
+ config_path, checkpoint_path, device=device
452
+ )
453
+
454
+ sam2_model = build_sam2(
455
+ config_path,
456
+ checkpoint_path,
457
+ device=device,
458
+ apply_postprocessing=False,
459
+ )
460
+ self.mask_generator = SAM2AutomaticMaskGenerator(
461
+ model=sam2_model,
462
+ points_per_side=32,
463
+ points_per_batch=32,
464
+ pred_iou_thresh=0.7,
465
+ stability_score_thresh=0.8,
466
+ crop_n_layers=2,
467
+ box_nms_thresh=0.6,
468
+ crop_n_points_downscale_factor=2,
469
+ min_mask_region_area=100,
470
+ use_m2m=True,
471
+ )
472
+ print("✓ SAM2 models initialized successfully")
473
+
474
+ except Exception as e:
475
+ raise ValueError(f"Failed to initialize SAM2 with custom paths: {e}")
476
+
477
+ def _initialize_grounding_dino_model(self):
478
+ try:
479
+ from groundingdino.util.inference import Model as gd_Model
480
+ except ImportError as e:
481
+ print(f"Warning: Could not import GroundingDINO: {e}")
482
+ return
483
+
484
+ config_path, checkpoint_path = self._resolve_grounding_dino_paths()
485
+
486
+ if self.gd_config_path is not None and not os.path.exists(config_path):
487
+ raise ValueError(f"GroundingDINO config path not found: {config_path}")
488
+ if self.gd_checkpoint_path is not None and not os.path.exists(checkpoint_path):
489
+ raise ValueError(
490
+ f"GroundingDINO checkpoint path not found: {checkpoint_path}"
491
+ )
492
+
493
+ if not (config_path and checkpoint_path and os.path.exists(config_path) and os.path.exists(checkpoint_path)):
494
+ print(
495
+ f"Warning: GroundingDINO models not found at {config_path} / {checkpoint_path}"
496
+ )
497
+ print("GroundingDINO functionality will be unavailable")
498
+ return
499
+
500
+ try:
501
+ device = self._device
502
+ self.grounding_model = gd_Model(
503
+ model_config_path=config_path,
504
+ model_checkpoint_path=checkpoint_path,
505
+ device=device,
506
+ )
507
+ print("✓ GroundingDINO model initialized successfully")
508
+
509
+ except Exception as e:
510
+ raise ValueError(f"Failed to initialize GroundingDINO with custom paths: {e}")
511
+
512
+ def _resolve_sam2_paths(self):
513
+ # May return None values when no custom SAM2 paths were provided; callers check before use.
514
+ return self.sam_config_path, self.sam_checkpoint_path
515
+
516
+ def _resolve_grounding_dino_paths(self):
517
+ # May return None values when no custom GroundingDINO paths were provided; callers check before use.
518
+ return self.gd_config_path, self.gd_checkpoint_path
519
+
520
+ # ------------------------------------------------------------------ #
521
+ # Video writing helpers
522
+ # ------------------------------------------------------------------ #
523
+ def _prepare_visualization_dir(self, name: str, enabled: bool) -> Optional[str]:
524
+ if not enabled:
525
+ return None
526
+
527
+ if self.visualization_dir:
528
+ target_dir = (
529
+ os.path.join(self.visualization_dir, name)
530
+ if name
531
+ else self.visualization_dir
532
+ )
533
+ os.makedirs(target_dir, exist_ok=True)
534
+ return target_dir
535
+
536
+ return tempfile.mkdtemp(prefix=f"vine_{name}_")
537
+
538
+ def _create_temp_video(
539
+ self,
540
+ video_tensor: np.ndarray,
541
+ base_dir: Optional[str] = None,
542
+ prefix: str = "temp_video",
543
+ ) -> str:
544
+ import subprocess
545
+
546
+ if base_dir is None:
547
+ base_dir = tempfile.mkdtemp(prefix=f"vine_{prefix}_")
548
+ else:
549
+ os.makedirs(base_dir, exist_ok=True)
550
+ file_name = f"{prefix}_{uuid.uuid4().hex}.mp4"
551
+ temp_path = os.path.join(base_dir, file_name)
552
+
553
+ height, width = video_tensor.shape[1:3]
554
+ processing_fps = max(1, self.target_fps)
555
+ output_fps = processing_fps
556
+ video_tensor_for_output = video_tensor
557
+
558
+ ffmpeg_success = False
559
+ try:
560
+ ffmpeg_success = self._create_video_with_ffmpeg(
561
+ video_tensor_for_output, temp_path, output_fps, width, height
562
+ )
563
+ except Exception as e:
564
+ print(f"FFmpeg method failed: {e}")
565
+
566
+ if not ffmpeg_success:
567
+ print("Using OpenCV fallback")
568
+ self._create_temp_video_opencv(
569
+ video_tensor_for_output, temp_path, output_fps, width, height
570
+ )
571
+
572
+ return temp_path
573
+
574
+ def _create_video_with_ffmpeg(
575
+ self, video_tensor: np.ndarray, output_path: str, fps: int, width: int, height: int
576
+ ) -> bool:
577
+ import subprocess
578
+
579
+ try:
580
+ ffmpeg_cmd = [
581
+ "ffmpeg",
582
+ "-y",
583
+ "-f",
584
+ "rawvideo",
585
+ "-vcodec",
586
+ "rawvideo",
587
+ "-s",
588
+ f"{width}x{height}",
589
+ "-pix_fmt",
590
+ "rgb24",
591
+ "-r",
592
+ str(fps),
593
+ "-i",
594
+ "pipe:0",
595
+ "-c:v",
596
+ "libx264",
597
+ "-preset",
598
+ "fast",
599
+ "-crf",
600
+ "23",
601
+ "-pix_fmt",
602
+ "yuv420p",
603
+ "-movflags",
604
+ "+faststart",
605
+ "-loglevel",
606
+ "error",
607
+ output_path,
608
+ ]
609
+
610
+ process = subprocess.Popen(
611
+ ffmpeg_cmd,
612
+ stdin=subprocess.PIPE,
613
+ stdout=subprocess.PIPE,
614
+ stderr=subprocess.PIPE,
615
+ )
616
+
617
+ frame_data = b""
618
+ for frame in video_tensor:
619
+ if frame.dtype != np.uint8:
620
+ frame = (
621
+ (frame * 255).astype(np.uint8)
622
+ if frame.max() <= 1
623
+ else frame.astype(np.uint8)
624
+ )
625
+ frame_data += frame.tobytes()
626
+
627
+ stdout, stderr = process.communicate(input=frame_data, timeout=60)
628
+
629
+ if process.returncode == 0:
630
+ print(f"Video created with FFmpeg (H.264) at {fps} FPS")
631
+ return True
632
+ else:
633
+ error_msg = stderr.decode() if stderr else "Unknown error"
634
+ print(f"FFmpeg error: {error_msg}")
635
+ return False
636
+
637
+ except FileNotFoundError:
638
+ print("FFmpeg not found in PATH")
639
+ return False
640
+ except Exception as e:
641
+ print(f"FFmpeg exception: {e}")
642
+ return False
643
+
644
+ def _create_temp_video_opencv(
645
+ self, video_tensor: np.ndarray, temp_path: str, fps: int, width: int, height: int
646
+ ) -> str:
647
+ codecs_to_try = ["avc1", "X264", "mp4v"]
648
+ out = None
649
+ used_codec = None
650
+
651
+ for codec in codecs_to_try:
652
+ try:
653
+ fourcc = cv2.VideoWriter_fourcc(*codec)
654
+ temp_out = cv2.VideoWriter(temp_path, fourcc, fps, (width, height))
655
+
656
+ if temp_out.isOpened():
657
+ out = temp_out
658
+ used_codec = codec
659
+ break
660
+ else:
661
+ temp_out.release()
662
+ except Exception as e:
663
+ print(f"Warning: Codec {codec} not available: {e}")
664
+ continue
665
+
666
+ if out is None or not out.isOpened():
667
+ raise RuntimeError(
668
+ f"Failed to initialize VideoWriter with any codec. Tried: {codecs_to_try}"
669
+ )
670
+
671
+ print(f"Using OpenCV with codec: {used_codec}")
672
+
673
+ for frame in video_tensor:
674
+ if len(frame.shape) == 3 and frame.shape[2] == 3:
675
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
676
+ else:
677
+ frame_bgr = frame
678
+ if frame_bgr.dtype != np.uint8:
679
+ frame_bgr = (
680
+ (frame_bgr * 255).astype(np.uint8)
681
+ if frame_bgr.max() <= 1
682
+ else frame_bgr.astype(np.uint8)
683
+ )
684
+ out.write(frame_bgr)
685
+
686
+ out.release()
687
+ return temp_path
688
+
689
+ # ------------------------------------------------------------------ #
690
+ # Forward + postprocess
691
+ # ------------------------------------------------------------------ #
692
+ def _forward(self, model_inputs: Dict[str, Any], **forward_kwargs: Any) -> Dict[str, Any]:
693
+ outputs = self.model.predict(
694
+ video_frames=model_inputs["video_frames"],
695
+ masks=model_inputs["masks"],
696
+ bboxes=model_inputs["bboxes"],
697
+ **forward_kwargs,
698
+ )
699
+ outputs.setdefault("video_frames", model_inputs.get("video_frames"))
700
+ outputs.setdefault("bboxes", model_inputs.get("bboxes"))
701
+ outputs.setdefault("masks", model_inputs.get("masks"))
702
+ outputs.setdefault("visualization_data", model_inputs.get("visualization_data"))
703
+ return outputs
704
+
705
+ def postprocess(
706
+ self,
707
+ model_outputs: Dict[str, Any],
708
+ return_top_k: int = 3,
709
+ visualize: Optional[bool] = None,
710
+ **kwargs: Any,
711
+ ) -> Dict[str, Any]:
712
+ results: Dict[str, Any] = {
713
+ "categorical_predictions": model_outputs.get("categorical_predictions", {}),
714
+ "unary_predictions": model_outputs.get("unary_predictions", {}),
715
+ "binary_predictions": model_outputs.get("binary_predictions", {}),
716
+ "confidence_scores": model_outputs.get("confidence_scores", {}),
717
+ "summary": self._generate_summary(model_outputs),
718
+ }
719
+
720
+ print("\n" + "=" * 50)
721
+ print("DEBUG: Raw Model Outputs - Categorical Predictions")
722
+ cat_preds = model_outputs.get("categorical_predictions", {})
723
+ for obj_id, preds in cat_preds.items():
724
+ print(f"Object {obj_id}: {preds}")
725
+ print("=" * 50 + "\n")
726
+
727
+ if "flattened_segments" in model_outputs:
728
+ results["flattened_segments"] = model_outputs["flattened_segments"]
729
+ if "valid_pairs" in model_outputs:
730
+ results["valid_pairs"] = model_outputs["valid_pairs"]
731
+ if "valid_pairs_metadata" in model_outputs:
732
+ results["valid_pairs_metadata"] = model_outputs["valid_pairs_metadata"]
733
+ if "visualization_data" in model_outputs:
734
+ results["visualization_data"] = model_outputs["visualization_data"]
735
+
736
+ if (self.visualize if visualize is None else visualize) and "video_frames" in model_outputs and "bboxes" in model_outputs:
737
+ frames_tensor = model_outputs["video_frames"]
738
+ if isinstance(frames_tensor, torch.Tensor):
739
+ frames_np = frames_tensor.detach().cpu().numpy()
740
+ else:
741
+ frames_np = np.asarray(frames_tensor)
742
+ if frames_np.dtype != np.uint8:
743
+ if np.issubdtype(frames_np.dtype, np.floating):
744
+ max_val = frames_np.max() if frames_np.size else 0.0
745
+ scale = 255.0 if max_val <= 1.0 else 1.0
746
+ frames_np = (frames_np * scale).clip(0, 255).astype(np.uint8)
747
+ else:
748
+ frames_np = frames_np.clip(0, 255).astype(np.uint8)
749
+
750
+ cat_label_lookup: Dict[int, Tuple[str, float]] = {}
751
+ for obj_id, preds in model_outputs.get("categorical_predictions", {}).items():
752
+ if preds:
753
+ prob, label = preds[0]
754
+ cat_label_lookup[obj_id] = (label, prob)
755
+
756
+ unary_preds = model_outputs.get("unary_predictions", {})
757
+ unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]] = {}
758
+ for (frame_id, obj_id), preds in unary_preds.items():
759
+ if preds:
760
+ unary_lookup.setdefault(frame_id, {})[obj_id] = preds[:1]
761
+
762
+ binary_preds = model_outputs.get("binary_predictions", {})
763
+ binary_lookup: Dict[
764
+ int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]
765
+ ] = {}
766
+ for (frame_id, obj_pair), preds in binary_preds.items():
767
+ if preds:
768
+ binary_lookup.setdefault(frame_id, []).append((obj_pair, preds[:1]))
769
+
770
+ bboxes = model_outputs["bboxes"]
771
+ visualization_data = model_outputs.get("visualization_data", {})
772
+ visualizations: Dict[str, Dict[str, Any]] = {}
773
+ debug_visualizations = kwargs.get("debug_visualizations")
774
+ if debug_visualizations is None:
775
+ debug_visualizations = self.debug_visualizations
776
+
777
+ vine_frame_sets = render_vine_frame_sets(
778
+ frames_np,
779
+ bboxes,
780
+ cat_label_lookup,
781
+ unary_lookup,
782
+ binary_lookup,
783
+ visualization_data.get("sam_masks"),
784
+ )
785
+
786
+ vine_visuals: Dict[str, Dict[str, Any]] = {}
787
+ final_frames = vine_frame_sets.get("all", [])
788
+ if final_frames:
789
+ final_entry: Dict[str, Any] = {"frames": final_frames, "video_path": None}
790
+ final_dir = self._prepare_visualization_dir(
791
+ "all", enabled=self.visualize
792
+ )
793
+ final_entry["video_path"] = self._create_temp_video(
794
+ np.stack(final_frames, axis=0),
795
+ base_dir=final_dir,
796
+ prefix="all_visualization",
797
+ )
798
+ vine_visuals["all"] = final_entry
799
+
800
+ if debug_visualizations:
801
+ sam_masks = visualization_data.get("sam_masks")
802
+ if sam_masks:
803
+ sam_frames = render_sam_frames(
804
+ frames_np, sam_masks, visualization_data.get("dino_labels")
805
+ )
806
+ sam_entry = {"frames": sam_frames, "video_path": None}
807
+ if sam_frames:
808
+ sam_dir = self._prepare_visualization_dir(
809
+ "sam", enabled=self.visualize
810
+ )
811
+ sam_entry["video_path"] = self._create_temp_video(
812
+ np.stack(sam_frames, axis=0),
813
+ base_dir=sam_dir,
814
+ prefix="sam_visualization",
815
+ )
816
+ visualizations["sam"] = sam_entry
817
+
818
+ dino_labels = visualization_data.get("dino_labels")
819
+ if dino_labels:
820
+ dino_frames = render_dino_frames(frames_np, bboxes, dino_labels)
821
+ dino_entry = {"frames": dino_frames, "video_path": None}
822
+ if dino_frames:
823
+ dino_dir = self._prepare_visualization_dir(
824
+ "dino", enabled=self.visualize
825
+ )
826
+ dino_entry["video_path"] = self._create_temp_video(
827
+ np.stack(dino_frames, axis=0),
828
+ base_dir=dino_dir,
829
+ prefix="dino_visualization",
830
+ )
831
+ visualizations["dino"] = dino_entry
832
+
833
+ for name in ("object", "unary", "binary"):
834
+ frames_list = vine_frame_sets.get(name, [])
835
+ entry: Dict[str, Any] = {"frames": frames_list, "video_path": None}
836
+ if frames_list:
837
+ vine_dir = self._prepare_visualization_dir(
838
+ name, enabled=self.visualize
839
+ )
840
+ entry["video_path"] = self._create_temp_video(
841
+ np.stack(frames_list, axis=0),
842
+ base_dir=vine_dir,
843
+ prefix=f"{name}_visualization",
844
+ )
845
+ vine_visuals[name] = entry
846
+
847
+ if vine_visuals:
848
+ visualizations["vine"] = vine_visuals
849
+
850
+ if visualizations:
851
+ results["visualizations"] = visualizations
852
+
853
+ return results
854
+
855
+ # ------------------------------------------------------------------ #
856
+ # Summary JSON
857
+ # ------------------------------------------------------------------ #
858
+ def _generate_summary(self, model_outputs: Dict[str, Any]) -> Dict[str, Any]:
859
+ """
860
+ Per-object summary:
861
+ {
862
+ "num_objects_detected": N,
863
+ "objects": {
864
+ "<obj_id>": {
865
+ "top_categories": [{"label": str, "probability": float}, ...],
866
+ "top_unary": [{"frame_id": int, "predicate": str, "probability": float}, ...],
867
+ }
868
+ }
869
+ }
870
+ """
871
+ categorical_preds = model_outputs.get("categorical_predictions", {})
872
+ unary_preds = model_outputs.get("unary_predictions", {})
873
+
874
+ unary_by_obj: Dict[int, List[Tuple[float, str, int]]] = {}
875
+ for (frame_id, obj_id), preds in unary_preds.items():
876
+ for prob, predicate in preds:
877
+ prob_val = (
878
+ float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
879
+ )
880
+ unary_by_obj.setdefault(obj_id, []).append((prob_val, predicate, frame_id))
881
+
882
+ objects_summary: Dict[str, Dict[str, Any]] = {}
883
+ all_obj_ids = set(categorical_preds.keys()) | set(unary_by_obj.keys())
884
+
885
+ for obj_id in sorted(all_obj_ids):
886
+ cat_list = categorical_preds.get(obj_id, [])
887
+ cat_sorted = sorted(
888
+ [
889
+ (
890
+ float(p.detach().cpu()) if torch.is_tensor(p) else float(p),
891
+ label,
892
+ )
893
+ for p, label in cat_list
894
+ ],
895
+ key=lambda x: x[0],
896
+ reverse=True,
897
+ )[:3]
898
+
899
+ top_categories = [
900
+ {"label": label, "probability": prob} for prob, label in cat_sorted
901
+ ]
902
+
903
+ unary_list = unary_by_obj.get(obj_id, [])
904
+ unary_sorted = sorted(unary_list, key=lambda x: x[0], reverse=True)[:3]
905
+ top_unary = [
906
+ {
907
+ "frame_id": int(frame_id),
908
+ "predicate": predicate,
909
+ "probability": prob,
910
+ }
911
+ for (prob, predicate, frame_id) in unary_sorted
912
+ ]
913
+
914
+ objects_summary[str(obj_id)] = {
915
+ "top_categories": top_categories,
916
+ "top_unary": top_unary,
917
+ }
918
+
919
+ summary = {
920
+ "num_objects_detected": len(objects_summary),
921
+ "objects": objects_summary,
922
+ }
923
+ return summary
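For reference, here is a minimal sketch of how the summary returned above might be consumed downstream; the payload is a hypothetical example that only mirrors the schema documented in the docstring, not real model output.

# Hypothetical example mirroring the documented summary schema.
summary = {
    "num_objects_detected": 1,
    "objects": {
        "0": {
            "top_categories": [{"label": "person", "probability": 0.92}],
            "top_unary": [{"frame_id": 3, "predicate": "standing", "probability": 0.81}],
        }
    },
}
for obj_id, info in summary["objects"].items():
    if info["top_categories"]:
        best = info["top_categories"][0]
        print(f"object {obj_id}: {best['label']} ({best['probability']:.2f})")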
vine_hf/vis_utils.py ADDED
@@ -0,0 +1,941 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ import random
7
+ import math
8
+ from matplotlib.patches import Rectangle
9
+ import itertools
10
+ from typing import Any, Dict, List, Tuple, Optional, Union
11
+
12
+ from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox
13
+
14
+ ########################################################################################
15
+ ########## Visualization Library ########
16
+ ########################################################################################
17
+ # This module renders SAM masks, GroundingDINO boxes, and VINE predictions.
18
+ #
19
+ # Conventions (RGB frames, pixel coords):
20
+ # - Frames: list[np.ndarray] with shape (H, W, 3) in RGB, or np.ndarray with shape (T, H, W, 3).
21
+ # - Masks: 2D boolean arrays (H, W) or tensors convertible to that; (H, W, 1) is also accepted.
22
+ # - BBoxes: (x1, y1, x2, y2) integer pixel coordinates with x2 > x1 and y2 > y1.
23
+ #
24
+ # Per-frame stores use one of:
25
+ # - Dict[int(frame_id) -> Dict[int(obj_id) -> value]]
26
+ # - List indexed by frame_id (each item may be a dict of obj_id->value or a list in order)
27
+ #
28
+ # Renderer inputs/outputs:
29
+ # 1) render_sam_frames(frames, sam_masks, dino_labels=None) -> List[np.ndarray]
30
+ # - sam_masks: Dict[frame_id, Dict[obj_id, Mask]] or a list; Mask can be np.ndarray or torch.Tensor.
31
+ # - dino_labels: Optional Dict[obj_id, str] to annotate boxes derived from masks.
32
+ #
33
+ # 2) render_dino_frames(frames, bboxes, dino_labels=None) -> List[np.ndarray]
34
+ # - bboxes: Dict[frame_id, Dict[obj_id, Sequence[float]]] or a list; each bbox as [x1, y1, x2, y2].
35
+ #
36
+ # 3) render_vine_frames(frames, bboxes, cat_label_lookup, unary_lookup, binary_lookup, masks=None)
37
+ # -> List[np.ndarray] (the "all" view)
38
+ # - cat_label_lookup: Dict[obj_id, (label: str, prob: float)]
39
+ # - unary_lookup: Dict[frame_id, Dict[obj_id, List[(prob: float, label: str)]]]
40
+ # - binary_lookup: Dict[frame_id, List[((sub_id: int, obj_id: int), List[(prob: float, relation: str)])]]
41
+ # - masks: Optional; same structure as sam_masks, used for translucent overlays when unary labels exist.
42
+ #
43
+ # Ground-truth helpers used by plotting utilities:
44
+ # - For a single frame, gt_relations is represented as List[(subject_label, object_label, relation_label)].
45
+ #
46
+ # All rendered frames returned by functions are RGB np.ndarray images suitable for saving or video writing.
47
+ ########################################################################################
48
+
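As a reading aid, a minimal usage sketch of the renderers documented above; the frames, boxes, and prediction lookups are tiny synthetic placeholders rather than real pipeline outputs.

import numpy as np

frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(2)]   # RGB placeholders
bboxes = {0: {0: [8, 8, 40, 40]}, 1: {0: [10, 10, 42, 42]}}          # frame -> obj -> (x1, y1, x2, y2)
cat_label_lookup = {0: ("cat", 0.90)}
unary_lookup = {0: {0: [(0.80, "sitting")]}}
binary_lookup = {}                                                   # no relations in this toy example

frame_sets = render_vine_frame_sets(frames, bboxes, cat_label_lookup, unary_lookup, binary_lookup)
dino_frames = render_dino_frames(frames, bboxes, dino_labels={0: "cat"})
# Every rendered frame is an RGB np.ndarray with the same shape as its input frame.
assert frame_sets["all"][0].shape == frames[0].shape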
49
+ def clean_label(label):
50
+ """Replace underscores and slashes with spaces for uniformity."""
51
+ return label.replace("_", " ").replace("/", " ")
52
+
53
+ # NOTE: this formatting step arguably belongs in the model's post-processing rather than in the visualization utilities.
54
+ def format_cate_preds(cate_preds):
55
+ # Group object predictions from the model output.
56
+ obj_pred_dict = {}
57
+ for (oid, label), prob in cate_preds.items():
58
+ # Clean the predicted label as well.
59
+ clean_pred = clean_label(label)
60
+ if oid not in obj_pred_dict:
61
+ obj_pred_dict[oid] = []
62
+ obj_pred_dict[oid].append((clean_pred, prob))
63
+ for oid in obj_pred_dict:
64
+ obj_pred_dict[oid].sort(key=lambda x: x[1], reverse=True)
65
+ return obj_pred_dict
66
+
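A small illustration of the grouping this helper performs, using made-up object ids and labels:

cate_preds = {(0, "dog"): 0.7, (0, "cat"): 0.2, (1, "person"): 0.9}   # (obj_id, label) -> prob
grouped = format_cate_preds(cate_preds)
# -> {0: [("dog", 0.7), ("cat", 0.2)], 1: [("person", 0.9)]}, each list sorted by probability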
67
+ def format_binary_cate_preds(binary_preds):
68
+ frame_binary_preds = []
69
+ for key, score in binary_preds.items():
70
+ # Expect key format: (frame_id, (subject, object), predicted_relation)
71
+ try:
72
+ f_id, (subj, obj), pred_rel = key
73
+ frame_binary_preds.append((f_id, subj, obj, pred_rel, score))
74
+ except Exception:  # key did not match the expected shape
75
+ print("Skipping key with unexpected format:", key)
76
+ continue
77
+ frame_binary_preds.sort(key=lambda x: x[4], reverse=True)  # sort by confidence score, highest first
78
+ return frame_binary_preds
79
+
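And the key format this helper expects, per the comment above, with illustrative ids, relations, and scores:

binary_preds = {          # (frame_id, (subject, object), relation) -> score
    (0, (1, 2), "holding"): 0.74,
    (0, (2, 3), "next_to"): 0.31,
}
rows = format_binary_cate_preds(binary_preds)
# -> [(0, 1, 2, "holding", 0.74), (0, 2, 3, "next_to", 0.31)], highest score first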
80
+ _FONT = cv2.FONT_HERSHEY_SIMPLEX
81
+
82
+
83
+ def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.ndarray]:
84
+ if mask is None:
85
+ return None
86
+ if isinstance(mask, torch.Tensor):
87
+ mask_np = mask.detach().cpu().numpy()
88
+ else:
89
+ mask_np = np.asarray(mask)
90
+ if mask_np.ndim == 0:
91
+ return None
92
+ if mask_np.ndim == 3:
93
+ mask_np = np.squeeze(mask_np)
94
+ if mask_np.ndim != 2:
95
+ return None
96
+ if mask_np.dtype == bool:
97
+ return mask_np
98
+ return mask_np > 0
99
+
100
+
101
+ def _sanitize_bbox(bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int) -> Optional[Tuple[int, int, int, int]]:
102
+ if bbox is None:
103
+ return None
104
+ if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
105
+ x1, y1, x2, y2 = [float(b) for b in bbox[:4]]
106
+ elif isinstance(bbox, np.ndarray) and bbox.size >= 4:
107
+ x1, y1, x2, y2 = [float(b) for b in bbox.flat[:4]]
108
+ else:
109
+ return None
110
+ x1 = int(np.clip(round(x1), 0, width - 1))
111
+ y1 = int(np.clip(round(y1), 0, height - 1))
112
+ x2 = int(np.clip(round(x2), 0, width - 1))
113
+ y2 = int(np.clip(round(y2), 0, height - 1))
114
+ if x2 <= x1 or y2 <= y1:
115
+ return None
116
+ return (x1, y1, x2, y2)
117
+
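For clarity, the clipping and rejection behaviour of the helper above (illustrative values):

# Out-of-frame coordinates are clipped; degenerate boxes are rejected.
assert _sanitize_bbox([-5, 2, 700, 50], width=640, height=480) == (0, 2, 639, 50)
assert _sanitize_bbox([10, 10, 10, 40], width=640, height=480) is None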
118
+
119
+ def _object_color_bgr(obj_id: int) -> Tuple[int, int, int]:
120
+ color = get_color(obj_id)
121
+ rgb = [int(np.clip(c, 0.0, 1.0) * 255) for c in color[:3]]
122
+ return (rgb[2], rgb[1], rgb[0])
123
+
124
+
125
+ def _background_color(color: Tuple[int, int, int]) -> Tuple[int, int, int]:
126
+ return tuple(int(0.25 * 255 + 0.75 * channel) for channel in color)
127
+
128
+
129
+ def _draw_label_block(
130
+ image: np.ndarray,
131
+ lines: List[str],
132
+ anchor: Tuple[int, int],
133
+ color: Tuple[int, int, int],
134
+ font_scale: float = 0.5,
135
+ thickness: int = 1,
136
+ direction: str = "up",
137
+ ) -> None:
138
+ if not lines:
139
+ return
140
+ img_h, img_w = image.shape[:2]
141
+ x, y = anchor
142
+ x = int(np.clip(x, 0, img_w - 1))
143
+ y_cursor = int(np.clip(y, 0, img_h - 1))
144
+ bg_color = _background_color(color)
145
+
146
+ if direction == "down":
147
+ for text in lines:
148
+ text = str(text)
149
+ (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
150
+ left_x = x
151
+ right_x = min(left_x + tw + 8, img_w - 1)
152
+ top_y = int(np.clip(y_cursor + 6, 0, img_h - 1))
153
+ bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
154
+ if bottom_y <= top_y:
155
+ break
156
+ cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
157
+ text_x = left_x + 4
158
+ text_y = min(bottom_y - baseline - 2, img_h - 1)
159
+ cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
160
+ y_cursor = bottom_y
161
+ else:
162
+ for text in lines:
163
+ text = str(text)
164
+ (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
165
+ top_y = max(y_cursor - th - baseline - 6, 0)
166
+ left_x = x
167
+ right_x = min(left_x + tw + 8, img_w - 1)
168
+ bottom_y = min(top_y + th + baseline + 6, img_h - 1)
169
+ cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
170
+ text_x = left_x + 4
171
+ text_y = min(bottom_y - baseline - 2, img_h - 1)
172
+ cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
173
+ y_cursor = top_y
174
+
175
+
176
+ def _draw_centered_label(
177
+ image: np.ndarray,
178
+ text: str,
179
+ center: Tuple[int, int],
180
+ color: Tuple[int, int, int],
181
+ font_scale: float = 0.5,
182
+ thickness: int = 1,
183
+ ) -> None:
184
+ text = str(text)
185
+ img_h, img_w = image.shape[:2]
186
+ (tw, th), baseline = cv2.getTextSize(text, _FONT, font_scale, thickness)
187
+ cx = int(np.clip(center[0], 0, img_w - 1))
188
+ cy = int(np.clip(center[1], 0, img_h - 1))
189
+ left_x = int(np.clip(cx - tw // 2 - 4, 0, img_w - 1))
190
+ top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
191
+ right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
192
+ bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
193
+ cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1)
194
+ text_x = left_x + 4
195
+ text_y = min(bottom_y - baseline - 2, img_h - 1)
196
+ cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
197
+
198
+
199
+ def _extract_frame_entities(store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int) -> Dict[int, Any]:
200
+ if isinstance(store, dict):
201
+ frame_entry = store.get(frame_idx, {})
202
+ elif isinstance(store, list) and 0 <= frame_idx < len(store):
203
+ frame_entry = store[frame_idx]
204
+ else:
205
+ frame_entry = {}
206
+ if isinstance(frame_entry, dict):
207
+ return frame_entry
208
+ if isinstance(frame_entry, list):
209
+ return {i: value for i, value in enumerate(frame_entry)}
210
+ return {}
211
+
212
+
213
+ def _label_anchor_and_direction(
214
+ bbox: Tuple[int, int, int, int],
215
+ position: str,
216
+ ) -> Tuple[Tuple[int, int], str]:
217
+ x1, y1, x2, y2 = bbox
218
+ if position == "bottom":
219
+ return (x1, y2), "down"
220
+ return (x1, y1), "up"
221
+
222
+
223
+ def _draw_bbox_with_label(
224
+ image: np.ndarray,
225
+ bbox: Tuple[int, int, int, int],
226
+ obj_id: int,
227
+ title: Optional[str] = None,
228
+ sub_lines: Optional[List[str]] = None,
229
+ label_position: str = "top",
230
+ ) -> None:
231
+ color = _object_color_bgr(obj_id)
232
+ cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
233
+ head = title if title else f"#{obj_id}"
234
+ if not head.startswith("#"):
235
+ head = f"#{obj_id} {head}"
236
+ lines = [head]
237
+ if sub_lines:
238
+ lines.extend(sub_lines)
239
+ anchor, direction = _label_anchor_and_direction(bbox, label_position)
240
+ _draw_label_block(image, lines, anchor, color, direction=direction)
241
+
242
+
243
+ def render_sam_frames(
244
+ frames: Union[np.ndarray, List[np.ndarray]],
245
+ sam_masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None],
246
+ dino_labels: Optional[Dict[int, str]] = None,
247
+ ) -> List[np.ndarray]:
248
+ results: List[np.ndarray] = []
249
+ frames_iterable = frames if isinstance(frames, list) else list(frames)
250
+ dino_labels = dino_labels or {}
251
+
252
+ for frame_idx, frame in enumerate(frames_iterable):
253
+ if frame is None:
254
+ continue
255
+ frame_rgb = np.asarray(frame)
256
+ frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
257
+ overlay = frame_bgr.astype(np.float32)
258
+ masks_for_frame = _extract_frame_entities(sam_masks, frame_idx)
259
+
260
+ for obj_id, mask in masks_for_frame.items():
261
+ mask_np = _to_numpy_mask(mask)
262
+ if mask_np is None or not np.any(mask_np):
263
+ continue
264
+ color = _object_color_bgr(obj_id)
265
+ alpha = 0.45
266
+ overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(color, dtype=np.float32)
267
+
268
+ annotated = np.clip(overlay, 0, 255).astype(np.uint8)
269
+ frame_h, frame_w = annotated.shape[:2]
270
+
271
+ for obj_id, mask in masks_for_frame.items():
272
+ mask_np = _to_numpy_mask(mask)
273
+ if mask_np is None or not np.any(mask_np):
274
+ continue
275
+ bbox = mask_to_bbox(mask_np)
276
+ bbox = _sanitize_bbox(bbox, frame_w, frame_h)
277
+ if not bbox:
278
+ continue
279
+ label = dino_labels.get(obj_id)
280
+ title = f"{label}" if label else None
281
+ _draw_bbox_with_label(annotated, bbox, obj_id, title=title)
282
+
283
+ results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
284
+
285
+ return results
286
+
287
+
288
+ def render_dino_frames(
289
+ frames: Union[np.ndarray, List[np.ndarray]],
290
+ bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
291
+ dino_labels: Optional[Dict[int, str]] = None,
292
+ ) -> List[np.ndarray]:
293
+ results: List[np.ndarray] = []
294
+ frames_iterable = frames if isinstance(frames, list) else list(frames)
295
+ dino_labels = dino_labels or {}
296
+
297
+ for frame_idx, frame in enumerate(frames_iterable):
298
+ if frame is None:
299
+ continue
300
+ frame_rgb = np.asarray(frame)
301
+ annotated = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
302
+ frame_h, frame_w = annotated.shape[:2]
303
+ frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
304
+
305
+ for obj_id, bbox_values in frame_bboxes.items():
306
+ bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
307
+ if not bbox:
308
+ continue
309
+ label = dino_labels.get(obj_id)
310
+ title = f"{label}" if label else None
311
+ _draw_bbox_with_label(annotated, bbox, obj_id, title=title)
312
+
313
+ results.append(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
314
+
315
+ return results
316
+
317
+
318
+ def render_vine_frame_sets(
319
+ frames: Union[np.ndarray, List[np.ndarray]],
320
+ bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
321
+ cat_label_lookup: Dict[int, Tuple[str, float]],
322
+ unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
323
+ binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
324
+ masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
325
+ ) -> Dict[str, List[np.ndarray]]:
326
+ frame_groups: Dict[str, List[np.ndarray]] = {
327
+ "object": [],
328
+ "unary": [],
329
+ "binary": [],
330
+ "all": [],
331
+ }
332
+ frames_iterable = frames if isinstance(frames, list) else list(frames)
333
+
334
+ for frame_idx, frame in enumerate(frames_iterable):
335
+ if frame is None:
336
+ continue
337
+ frame_rgb = np.asarray(frame)
338
+ base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
339
+ frame_h, frame_w = base_bgr.shape[:2]
340
+ frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
341
+ frame_masks = _extract_frame_entities(masks, frame_idx) if masks is not None else {}
342
+
343
+ objects_bgr = base_bgr.copy()
344
+ unary_bgr = base_bgr.copy()
345
+ binary_bgr = base_bgr.copy()
346
+ all_bgr = base_bgr.copy()
347
+
348
+ bbox_lookup: Dict[int, Tuple[int, int, int, int]] = {}
349
+ unary_lines_lookup: Dict[int, List[str]] = {}
350
+ titles_lookup: Dict[int, Optional[str]] = {}
351
+
352
+ for obj_id, bbox_values in frame_bboxes.items():
353
+ bbox = _sanitize_bbox(bbox_values, frame_w, frame_h)
354
+ if not bbox:
355
+ continue
356
+ bbox_lookup[obj_id] = bbox
357
+ cat_label, cat_prob = cat_label_lookup.get(obj_id, (None, None))
358
+ title_parts = []
359
+ if cat_label:
360
+ if cat_prob is not None:
361
+ title_parts.append(f"{cat_label} {cat_prob:.2f}")
362
+ else:
363
+ title_parts.append(cat_label)
364
+ titles_lookup[obj_id] = " ".join(title_parts) if title_parts else None
365
+ unary_preds = unary_lookup.get(frame_idx, {}).get(obj_id, [])
366
+ unary_lines = [f"{label} {prob:.2f}" for prob, label in unary_preds]
367
+ unary_lines_lookup[obj_id] = unary_lines
368
+
369
+ for obj_id, bbox in bbox_lookup.items():
370
+ unary_lines = unary_lines_lookup.get(obj_id, [])
371
+ if not unary_lines:
372
+ continue
373
+ mask_raw = frame_masks.get(obj_id)
374
+ mask_np = _to_numpy_mask(mask_raw)
375
+ if mask_np is None or not np.any(mask_np):
376
+ continue
377
+ color = np.array(_object_color_bgr(obj_id), dtype=np.float32)
378
+ alpha = 0.45
379
+ for target in (unary_bgr, all_bgr):
380
+ target_vals = target[mask_np].astype(np.float32)
381
+ blended = (1.0 - alpha) * target_vals + alpha * color
382
+ target[mask_np] = np.clip(blended, 0, 255).astype(np.uint8)
383
+
384
+ for obj_id, bbox in bbox_lookup.items():
385
+ title = titles_lookup.get(obj_id)
386
+ unary_lines = unary_lines_lookup.get(obj_id, [])
387
+ _draw_bbox_with_label(objects_bgr, bbox, obj_id, title=title, label_position="top")
388
+ _draw_bbox_with_label(unary_bgr, bbox, obj_id, title=title, label_position="top")
389
+ if unary_lines:
390
+ anchor, direction = _label_anchor_and_direction(bbox, "bottom")
391
+ _draw_label_block(unary_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
392
+ _draw_bbox_with_label(binary_bgr, bbox, obj_id, title=title, label_position="top")
393
+ _draw_bbox_with_label(all_bgr, bbox, obj_id, title=title, label_position="top")
394
+ if unary_lines:
395
+ anchor, direction = _label_anchor_and_direction(bbox, "bottom")
396
+ _draw_label_block(all_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
397
+
398
+ for obj_pair, relation_preds in binary_lookup.get(frame_idx, []):
399
+ if len(obj_pair) != 2 or not relation_preds:
400
+ continue
401
+ subj_id, obj_id = obj_pair
402
+ subj_bbox = bbox_lookup.get(subj_id)
403
+ obj_bbox = bbox_lookup.get(obj_id)
404
+ if not subj_bbox or not obj_bbox:
405
+ continue
406
+ start, end = relation_line(subj_bbox, obj_bbox)
407
+ color = tuple(int(c) for c in np.clip(
408
+ (np.array(_object_color_bgr(subj_id), dtype=np.float32) +
409
+ np.array(_object_color_bgr(obj_id), dtype=np.float32)) / 2.0,
410
+ 0, 255
411
+ ))
412
+ prob, relation = relation_preds[0]
413
+ label_text = f"{relation} {prob:.2f}"
414
+ mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
415
+ cv2.line(binary_bgr, start, end, color, 6, cv2.LINE_AA)
416
+ cv2.line(all_bgr, start, end, color, 6, cv2.LINE_AA)
417
+ _draw_centered_label(binary_bgr, label_text, mid_point, color)
418
+ _draw_centered_label(all_bgr, label_text, mid_point, color)
419
+
420
+ frame_groups["object"].append(cv2.cvtColor(objects_bgr, cv2.COLOR_BGR2RGB))
421
+ frame_groups["unary"].append(cv2.cvtColor(unary_bgr, cv2.COLOR_BGR2RGB))
422
+ frame_groups["binary"].append(cv2.cvtColor(binary_bgr, cv2.COLOR_BGR2RGB))
423
+ frame_groups["all"].append(cv2.cvtColor(all_bgr, cv2.COLOR_BGR2RGB))
424
+
425
+ return frame_groups
426
+
427
+
428
+ def render_vine_frames(
429
+ frames: Union[np.ndarray, List[np.ndarray]],
430
+ bboxes: Union[Dict[int, Dict[int, Union[List[float], np.ndarray]]], List, None],
431
+ cat_label_lookup: Dict[int, Tuple[str, float]],
432
+ unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
433
+ binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
434
+ masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
435
+ ) -> List[np.ndarray]:
436
+ return render_vine_frame_sets(
437
+ frames,
438
+ bboxes,
439
+ cat_label_lookup,
440
+ unary_lookup,
441
+ binary_lookup,
442
+ masks,
443
+ ).get("all", [])
444
+
445
+ def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
446
+ all_colors = []
447
+ all_texts = []
448
+ for (obj_id, bbox, gt_label) in gt_labels:
449
+ preds = obj_pred_dict.get(obj_id, [])
450
+ if len(preds) == 0:
451
+ top1 = "N/A"
452
+ box_color = (0, 0, 255) # bright red if no prediction
453
+ else:
454
+ top1, prob1 = preds[0]
455
+ topk_labels = [p[0] for p in preds[:topk_object]]
456
+ # Compare cleaned labels.
457
+ if top1.lower() == gt_label.lower():
458
+ box_color = (0, 255, 0) # bright green for correct
459
+ elif gt_label.lower() in [p.lower() for p in topk_labels]:
460
+ box_color = (0, 165, 255) # bright orange for partial match
461
+ else:
462
+ box_color = (0, 0, 255) # bright red for incorrect
463
+
464
+ label_text = f"ID:{obj_id}/P:{top1}/GT:{gt_label}"
465
+ all_colors.append(box_color)
466
+ all_texts.append(label_text)
467
+ return all_colors, all_texts
468
+
469
+ def plot_unary(frame_img, gt_labels, all_colors, all_texts):
470
+
471
+ for (obj_id, bbox, gt_label), box_color, label_text in zip(gt_labels, all_colors, all_texts):
472
+ x1, y1, x2, y2 = map(int, bbox)
473
+ cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
474
+ (tw, th), baseline = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
475
+ cv2.rectangle(frame_img, (x1, y1 - th - baseline - 4), (x1 + tw, y1), box_color, -1)
476
+ cv2.putText(frame_img, label_text, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX,
477
+ 0.5, (0, 0, 0), 1, cv2.LINE_AA)
478
+
479
+ return frame_img
480
+
481
+ def get_white_pane(pane_height,
482
+ pane_width=600,
483
+ header_height = 50,
484
+ header_font = cv2.FONT_HERSHEY_SIMPLEX,
485
+ header_font_scale = 0.7,
486
+ header_thickness = 2,
487
+ header_color = (0, 0, 0)):
488
+ # Create an expanded white pane to display text info.
489
+ white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)
490
+
491
+ # --- Adjust pane split: make predictions column wider (60% vs. 40%) ---
492
+ left_width = int(pane_width * 0.6)
493
+ right_width = pane_width - left_width
494
+ # Draw the headers directly onto white_pane so they survive the return value
495
+ # (drawing on .copy() slices would discard the text).
496
+
497
+ cv2.putText(white_pane, "Binary Predictions", (10, header_height - 30),
498
+ header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
499
+ cv2.putText(white_pane, "Ground Truth", (left_width + 10, header_height - 30),
500
+ header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
501
+
502
+ return white_pane
503
+
504
+ # This is for plotting binary prediction results with frame-based scene graphs
505
+ def plot_binary_sg(frame_img,
506
+ white_pane,
507
+ bin_preds,
508
+ gt_relations,
509
+ topk_binary,
510
+ header_height=50,
511
+ indicator_size=20,
512
+ pane_width=600):
513
+ # Leave vertical space for the headers.
514
+ line_height = 30 # vertical spacing per line
515
+ x_text = 10 # left margin for text
516
+ y_text_left = header_height + 10 # starting y for left pane text
517
+ y_text_right = header_height + 10 # starting y for right pane text
518
+
519
+ # Left section: top-k binary predictions.
520
+ left_width = int(pane_width * 0.6)
521
+ right_width = pane_width - left_width
522
+ left_pane = white_pane[:, :left_width, :].copy()
523
+ right_pane = white_pane[:, left_width:, :].copy()
524
+
525
+ for (_frame_id, subj, obj, pred_rel, score) in bin_preds[:topk_binary]:  # tuples from format_binary_cate_preds
526
+ correct = any((subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
527
+ for gt in gt_relations)
528
+ indicator_color = (0, 255, 0) if correct else (0, 0, 255)
529
+ cv2.rectangle(left_pane, (x_text, y_text_left - indicator_size + 5),
530
+ (x_text + indicator_size, y_text_left + 5), indicator_color, -1)
531
+ text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
532
+ cv2.putText(left_pane, text, (x_text + indicator_size + 5, y_text_left + 5),
533
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
534
+ y_text_left += line_height
535
+
536
+ # Right section: ground truth binary relations.
537
+ for gt in gt_relations:
538
+ if len(gt) != 3:
539
+ continue
540
+ text = f"{gt[0]} - {gt[2]} - {gt[1]}"
541
+ cv2.putText(right_pane, text, (x_text, y_text_right + 5),
542
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
543
+ y_text_right += line_height
544
+
545
+ # Combine the two text panes and then with the frame image.
546
+ combined_pane = np.hstack((left_pane, right_pane))
547
+ combined_image = np.hstack((frame_img, combined_pane))
548
+ return combined_image
549
+
550
+ def visualized_frame(frame_img,
551
+ bboxes,
552
+ object_ids,
553
+ gt_labels,
554
+ cate_preds,
555
+ binary_preds,
556
+ gt_relations,
557
+ topk_object,
558
+ topk_binary,
559
+ phase="unary"):
560
+
561
+ """Return the combined annotated frame for frame index i as an image (in BGR)."""
562
+ # Get the frame image (assuming batched_data['batched_reshaped_raw_videos'] is a list of frames)
563
+
564
+ # --- Process Object Predictions (for overlaying bboxes) ---
565
+ if phase == "unary":
566
+ objs = []
567
+ for ((_, f_id, obj_id), bbox, gt_label) in zip(object_ids, bboxes, gt_labels):
568
+ gt_label = clean_label(gt_label)
569
+ objs.append((obj_id, bbox, gt_label))
570
+
571
+ formatted_cate_preds = format_cate_preds(cate_preds)
572
+ all_colors, all_texts = color_for_cate_correctness(formatted_cate_preds, objs, topk_object)
573
+ updated_frame_img = plot_unary(frame_img, objs, all_colors, all_texts)
574
+ return updated_frame_img
575
+
576
+ else:
577
+ # --- Process Binary Predictions & Ground Truth for the Text Pane ---
578
+ formatted_binary_preds = format_binary_cate_preds(binary_preds)
579
+
580
+ # Ground truth binary relations for the frame.
581
+ # Clean ground truth relations.
582
+ gt_relations = [(clean_label(str(s)), clean_label(str(o)), clean_label(rel)) for s, o, rel in gt_relations]
583
+
584
+ pane_width = 600 # increased pane width for more horizontal space
585
+ pane_height = frame_img.shape[0]
586
+
587
+ # --- Add header labels to each text pane with extra space ---
588
+ header_height = 50 # increased header space
589
+ white_pane = get_white_pane(pane_height, pane_width, header_height=header_height)
590
+
591
+ combined_image = plot_binary_sg(frame_img, white_pane, formatted_binary_preds, gt_relations, topk_binary)
592
+
593
+ return combined_image
594
+
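A hedged sketch of calling visualized_frame for a single frame in the unary phase; the object ids, ground-truth labels, and predictions below are placeholders chosen only to satisfy the expected shapes.

frame = np.zeros((240, 320, 3), dtype=np.uint8)      # BGR frame placeholder
object_ids = [("video_0", 0, 7)]                     # (video_id, frame_id, obj_id)
frame_bboxes = [[20, 20, 120, 180]]
gt_labels = ["person"]
cate_preds = {(7, "person"): 0.88, (7, "dog"): 0.07}

annotated = visualized_frame(
    frame, frame_bboxes, object_ids, gt_labels, cate_preds,
    binary_preds={}, gt_relations=[], topk_object=3, topk_binary=5, phase="unary",
)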
595
+ def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
596
+ # Ensure mask is a numpy array
597
+ mask = np.array(mask)
598
+ # Handle different mask shapes
599
+ if mask.ndim == 3:
600
+ # (1, H, W) -> (H, W)
601
+ if mask.shape[0] == 1:
602
+ mask = mask.squeeze(0)
603
+ # (H, W, 1) -> (H, W)
604
+ elif mask.shape[2] == 1:
605
+ mask = mask.squeeze(2)
606
+ # Now mask should be (H, W)
607
+ assert mask.ndim == 2, f"Mask must be 2D after squeezing, got shape {mask.shape}"
608
+
609
+ if random_color:
610
+ color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
611
+ else:
612
+ cmap = plt.get_cmap("gist_rainbow")
613
+ cmap_idx = 0 if obj_id is None else obj_id
614
+ color = list(cmap((cmap_idx * 47) % 256))
615
+ color[3] = 0.5
616
+ color = np.array(color)
617
+
618
+ # Expand mask to (H, W, 1) for broadcasting
619
+ mask_expanded = mask[..., None]
620
+ mask_image = mask_expanded * color.reshape(1, 1, -1)
621
+
622
+ # draw a box around the mask with the det_class as the label
623
+ if det_class is not None:
624
+ # Find the bounding box coordinates
625
+ y_indices, x_indices = np.where(mask > 0)
626
+ if y_indices.size > 0 and x_indices.size > 0:
627
+ x_min, x_max = x_indices.min(), x_indices.max()
628
+ y_min, y_max = y_indices.min(), y_indices.max()
629
+ rect = Rectangle(
630
+ (x_min, y_min),
631
+ x_max - x_min,
632
+ y_max - y_min,
633
+ linewidth=1.5,
634
+ edgecolor=color[:3],
635
+ facecolor="none",
636
+ alpha=color[3]
637
+ )
638
+ ax.add_patch(rect)
639
+ ax.text(
640
+ x_min,
641
+ y_min - 5,
642
+ f"{det_class}",
643
+ color="white",
644
+ fontsize=6,
645
+ backgroundcolor=np.array(color),
646
+ alpha=1
647
+ )
648
+ ax.imshow(mask_image)
649
+
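A small matplotlib usage sketch for show_mask; the mask, object id, and class name are invented for illustration, and the output path is hypothetical.

fig, ax = plt.subplots(figsize=(4, 4))
frame = np.zeros((64, 64, 3), dtype=np.uint8)
mask = np.zeros((64, 64), dtype=bool)
mask[16:48, 16:48] = True

ax.imshow(frame)
ax.axis("off")
show_mask(mask, ax, obj_id=3, det_class="cup")
fig.savefig("/tmp/mask_demo.png", bbox_inches="tight", pad_inches=0)
plt.close(fig)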
650
+ def save_mask_one_image(frame_image, masks, save_path):
651
+ """Render masks on top of a frame and store the visualization on disk."""
652
+ fig, ax = plt.subplots(1, figsize=(6, 6))
653
+
654
+ frame_np = (
655
+ frame_image.detach().cpu().numpy()
656
+ if torch.is_tensor(frame_image)
657
+ else np.asarray(frame_image)
658
+ )
659
+ frame_np = np.ascontiguousarray(frame_np)
660
+
661
+ if isinstance(masks, dict):
662
+ mask_iter = masks.items()
663
+ else:
664
+ mask_iter = enumerate(masks)
665
+
666
+ prepared_masks = {
667
+ obj_id: (
668
+ mask.detach().cpu().numpy()
669
+ if torch.is_tensor(mask)
670
+ else np.asarray(mask)
671
+ )
672
+ for obj_id, mask in mask_iter
673
+ }
674
+
675
+ ax.imshow(frame_np)
676
+ ax.axis("off")
677
+
678
+ for obj_id, mask_np in prepared_masks.items():
679
+ show_mask(mask_np, ax, obj_id=obj_id, det_class=None, random_color=False)
680
+
681
+ fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
682
+ plt.close(fig)
683
+ return save_path
684
+
685
+ def get_video_masks_visualization(video_tensor,
686
+ video_masks,
687
+ video_id,
688
+ video_save_base_dir,
689
+ oid_class_pred=None,
690
+ sample_rate = 1):
691
+
692
+ video_save_dir = os.path.join(video_save_base_dir, video_id)
693
+ if not os.path.exists(video_save_dir):
694
+ os.makedirs(video_save_dir, exist_ok=True)
695
+
696
+ for frame_id, image in enumerate(video_tensor):
697
+ if frame_id not in video_masks:
698
+ print("No mask for Frame", frame_id)
699
+ continue
700
+
701
+ masks = video_masks[frame_id]
702
+ save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
703
+ fig, _ = get_mask_one_image(image, masks, oid_class_pred)
+ fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
+ plt.close(fig)
704
+
705
+ def get_mask_one_image(frame_image, masks, oid_class_pred=None):
706
+ # Create a figure and axis
707
+ fig, ax = plt.subplots(1, figsize=(6, 6))
708
+
709
+ # Display the frame image
710
+ ax.imshow(frame_image)
711
+ ax.axis('off')
712
+
713
+ if isinstance(masks, list):
714
+ masks = {i: m for i, m in enumerate(masks)}
715
+
716
+ # Add the masks
717
+ for obj_id, mask in masks.items():
718
+ det_class = f"{obj_id}. {oid_class_pred[obj_id]}" if oid_class_pred is not None else None
719
+ show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)
720
+
721
+ # Show the plot
722
+ return fig, ax
723
+
724
+ def save_video(frames, output_filename, output_fps):
725
+
726
+ # --- Create a video from all frames ---
727
+ num_frames = len(frames)
728
+ frame_h, frame_w = frames[0].shape[:2]  # per-frame height/width (frames.shape[:2] would give (T, H))
729
+
730
+ # Use a codec supported by VS Code (H.264 via 'avc1').
731
+ fourcc = cv2.VideoWriter_fourcc(*'avc1')
732
+ out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))
733
+
734
+ print(f"Processing {num_frames} frames...")
735
+ for i in range(num_frames):
736
+ vis_frame = cv2.cvtColor(np.asarray(frames[i]), cv2.COLOR_RGB2BGR)  # renderers in this module return RGB
737
+ out.write(vis_frame)
738
+ if i % 10 == 0:
739
+ print(f"Processed frame {i+1}/{num_frames}")
740
+
741
+ out.release()
742
+ print(f"Video saved as {output_filename}")
743
+
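A short usage note: with the fix above, save_video expects RGB frames such as those returned by the renderers in this module; the path and fps below are placeholders.

toy_frames = [np.full((64, 64, 3), 40, dtype=np.uint8) for _ in range(4)]
toy_bboxes = {i: {0: [8, 8, 40, 40]} for i in range(4)}
rendered = render_dino_frames(toy_frames, toy_bboxes, dino_labels={0: "cat"})
save_video(rendered, "/tmp/dino_preview.mp4", output_fps=2)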
744
+
745
+ def list_depth(lst):
746
+ """Calculates the depth of a nested list."""
747
+ if not (isinstance(lst, list) or isinstance(lst, torch.Tensor)):
748
+ return 0
749
+ elif (isinstance(lst, torch.Tensor) and lst.shape == torch.Size([])) or (isinstance(lst, list) and len(lst) == 0):
750
+ return 1
751
+ else:
752
+ return 1 + max(list_depth(item) for item in lst)
753
+
754
+ def normalize_prompt(points, labels):
755
+ if list_depth(points) == 3:
756
+ points = torch.stack([p.unsqueeze(0) for p in points])
757
+ labels = torch.stack([l.unsqueeze(0) for l in labels])
758
+ return points, labels
759
+
760
+
761
+ def show_box(box, ax, object_id):
762
+ if len(box) == 0:
763
+ return
764
+
765
+ cmap = plt.get_cmap("gist_rainbow")
766
+ cmap_idx = 0 if object_id is None else object_id
767
+ color = list(cmap((cmap_idx * 47) % 256))
768
+
769
+ x0, y0 = box[0], box[1]
770
+ w, h = box[2] - box[0], box[3] - box[1]
771
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0,0,0,0), lw=2))
772
+
773
+ def show_points(coords, labels, ax, object_id=None, marker_size=375):
774
+ if len(labels) == 0:
775
+ return
776
+
777
+ pos_points = coords[labels==1]
778
+ neg_points = coords[labels==0]
779
+
780
+ cmap = plt.get_cmap("gist_rainbow")
781
+ cmap_idx = 0 if object_id is None else object_id
782
+ color = list(cmap((cmap_idx * 47) % 256))
783
+
784
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='P', s=marker_size, edgecolor=color, linewidth=1.25)
785
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='s', s=marker_size, edgecolor=color, linewidth=1.25)
786
+
787
+ def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
788
+ # Create a figure and axis
789
+ fig, ax = plt.subplots(1, figsize=(6, 6))
790
+
791
+ # Display the frame image
792
+ ax.imshow(frame_image)
793
+ ax.axis('off')
794
+
795
+ points, labels = normalize_prompt(points, labels)
796
+ if isinstance(boxes, torch.Tensor):
797
+ for object_id, box in enumerate(boxes):
798
+ # Add the bounding boxes
799
+ if not box is None:
800
+ show_box(box.cpu(), ax, object_id=object_id)
801
+ elif isinstance(boxes, dict):
802
+ for object_id, box in boxes.items():
803
+ # Add the bounding boxes
804
+ if not box is None:
805
+ show_box(box.cpu(), ax, object_id=object_id)
806
+ elif isinstance(boxes, list) and len(boxes) == 0:
807
+ pass
808
+ else:
809
+ raise TypeError(f"Unsupported boxes type: {type(boxes)}")
810
+
811
+ for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
812
+ if not len(point_ls) == 0:
813
+ show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)
814
+
815
+ # Show the plot
816
+ plt.savefig(save_path)
817
+ plt.close()
818
+
819
+ def save_video_prompts_visualization(video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir):
820
+ video_save_dir = os.path.join(video_save_base_dir, video_id)
821
+ if not os.path.exists(video_save_dir):
822
+ os.makedirs(video_save_dir, exist_ok=True)
823
+
824
+ for frame_id, image in enumerate(video_tensor):
825
+ boxes, points, labels = [], [], []
826
+
827
+ if frame_id in video_boxes:
828
+ boxes = video_boxes[frame_id]
829
+
830
+ if frame_id in video_points:
831
+ points = video_points[frame_id]
832
+ if frame_id in video_labels:
833
+ labels = video_labels[frame_id]
834
+
835
+ save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
836
+ save_prompts_one_image(image, boxes, points, labels, save_path)
837
+
838
+
839
+ def save_video_masks_visualization(video_tensor, video_masks, video_id, video_save_base_dir, oid_class_pred=None, sample_rate = 1):
840
+ video_save_dir = os.path.join(video_save_base_dir, video_id)
841
+ if not os.path.exists(video_save_dir):
842
+ os.makedirs(video_save_dir, exist_ok=True)
843
+
844
+ for frame_id, image in enumerate(video_tensor):
845
+ if random.random() > sample_rate:
846
+ continue
847
+ if frame_id not in video_masks:
848
+ print("No mask for Frame", frame_id)
849
+ continue
850
+ masks = video_masks[frame_id]
851
+ save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
852
+ save_mask_one_image(image, masks, save_path)
853
+
854
+
855
+
856
+ def get_color(obj_id, cmap_name="gist_rainbow",alpha=0.5):
857
+ cmap = plt.get_cmap(cmap_name)
858
+ cmap_idx = 0 if obj_id is None else obj_id
859
+ color = list(cmap((cmap_idx * 47) % 256))
860
+ color[3] = alpha
861
+ color = np.array(color)
862
+ return color
863
+
864
+
865
+ def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
866
+ return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0)
867
+
868
+
869
+ def relation_line(
870
+ bbox1: Tuple[int, int, int, int],
871
+ bbox2: Tuple[int, int, int, int],
872
+ ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
873
+ """
874
+ Returns integer pixel centers suitable for drawing a relation line. For
875
+ coincident boxes, nudges the target center to ensure the segment has span.
876
+ """
877
+ center1 = _bbox_center(bbox1)
878
+ center2 = _bbox_center(bbox2)
879
+ if math.isclose(center1[0], center2[0], abs_tol=1e-3) and math.isclose(center1[1], center2[1], abs_tol=1e-3):
880
+ offset = max(1.0, (bbox2[2] - bbox2[0]) * 0.05)
881
+ center2 = (center2[0] + offset, center2[1])
882
+ start = (int(round(center1[0])), int(round(center1[1])))
883
+ end = (int(round(center2[0])), int(round(center2[1])))
884
+ if start == end:
885
+ end = (end[0] + 1, end[1])
886
+ return start, end
887
+
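A quick check of the coincident-box behaviour described in the docstring above (values are illustrative):

# Identical boxes: the end point is nudged so the segment has nonzero length.
start, end = relation_line((10, 10, 30, 30), (10, 10, 30, 30))
assert start == (20, 20) and end != start
# Distinct boxes: a plain center-to-center segment.
assert relation_line((0, 0, 10, 10), (20, 20, 40, 40)) == ((5, 5), (30, 30))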
888
+ def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
889
+ # Create a figure and axis
890
+ fig, ax = plt.subplots(1, figsize=(6, 6))
891
+
892
+ # Display the frame image
893
+ ax.imshow(frame_image)
894
+ ax.axis('off')
895
+
896
+ all_objs_to_show = set()
897
+ all_lines_to_show = []
898
+
899
+ # print(rel_pred_ls[0])
900
+ for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
901
+ all_objs_to_show.add(from_obj_id)
902
+ all_objs_to_show.add(to_obj_id)
903
+
904
+ from_mask = masks[from_obj_id]
905
+ bbox1 = mask_to_bbox(from_mask)
906
+ to_mask = masks[to_obj_id]
907
+ bbox2 = mask_to_bbox(to_mask)
908
+
909
+ c1, c2 = relation_line(bbox1, bbox2)  # draw between bbox centers (shortest_line_between_bboxes is not defined in this module)
910
+
911
+ line_color = get_color(from_obj_id)
912
+ face_color = get_color(to_obj_id)
913
+ line = c1, c2, face_color, line_color, rel_text
914
+ all_lines_to_show.append(line)
915
+
916
+ masks_to_show = {}
917
+ for oid in all_objs_to_show:
918
+ masks_to_show[oid] = masks[oid]
919
+
920
+ # Add the masks
921
+ for obj_id, mask in masks_to_show.items():
922
+ show_mask(mask, ax, obj_id=obj_id, random_color=False)
923
+
924
+ for (from_pt_x, from_pt_y), (to_pt_x, to_pt_y), face_color, line_color, rel_text in all_lines_to_show:
925
+
926
+ plt.plot([from_pt_x, to_pt_x], [from_pt_y, to_pt_y], color=line_color, linestyle='-', linewidth=3)
927
+ mid_pt_x = (from_pt_x + to_pt_x) / 2
928
+ mid_pt_y = (from_pt_y + to_pt_y) / 2
929
+ ax.text(
930
+ mid_pt_x - 5,
931
+ mid_pt_y,
932
+ rel_text,
933
+ color="white",
934
+ fontsize=6,
935
+ backgroundcolor=np.array(line_color),
936
+ bbox=dict(facecolor=face_color, edgecolor=line_color, boxstyle='round,pad=1'),
937
+ alpha=1
938
+ )
939
+
940
+ # Show the plot
941
+ return fig, ax