jiani-huang committed
Commit d3c563b · Parent: 21f4849

copying in saved code

app.py CHANGED
@@ -60,12 +60,155 @@ print(
 )
 
 
+def _split_top_level_commas(s: str):
+    """
+    Split a string on commas that are NOT inside parentheses.
+
+    Example:
+        "behind(person, dog), bite(dog, frisbee)"
+        -> ["behind(person, dog)", "bite(dog, frisbee)"]
+    """
+    parts = []
+    buf = []
+    depth = 0
+    for ch in s:
+        if ch == "(":
+            depth += 1
+            buf.append(ch)
+        elif ch == ")":
+            if depth > 0:
+                depth -= 1
+            buf.append(ch)
+        elif ch == "," and depth == 0:
+            part = "".join(buf).strip()
+            if part:
+                parts.append(part)
+            buf = []
+        else:
+            buf.append(ch)
+    if buf:
+        part = "".join(buf).strip()
+        if part:
+            parts.append(part)
+    return parts
+
+
+def _extract_categories_from_binary(binary_keywords_str: str) -> list[str]:
+    """
+    Pull candidate category tokens from binary keyword strings, e.g. relation(a, b).
+    Only returns tokens when parentheses and two comma-separated entries exist.
+    """
+    categories: list[str] = []
+    for kw in _split_top_level_commas(binary_keywords_str or ""):
+        lpar = kw.find("(")
+        rpar = kw.rfind(")")
+        if lpar == -1 or rpar <= lpar:
+            continue
+        inside = kw[lpar + 1 : rpar]
+        parts = [p.strip() for p in inside.split(",") if p.strip()]
+        if len(parts) == 2:
+            categories.extend(parts)
+    return categories
+
+
+def _parse_binary_keywords(binary_keywords_str: str, categorical_keywords: list[str]):
+    """
+    Parse a binary keyword string like:
+        "behind(person, dog), bite(dog, frisbee)"
+    into:
+    - binary_keywords_list: list of raw strings (used as CLIP text)
+    - batched_binary_predicates: {0: [(rel_text, from_cat, to_cat), ...]} or None
+    - warnings: list of warning strings about invalid/mismatched categories
+    """
+    if not binary_keywords_str:
+        return [], None, []
+
+    cat_map = {
+        kw.strip().lower(): kw.strip()
+        for kw in categorical_keywords
+        if isinstance(kw, str) and kw.strip()
+    }
+
+    entries = _split_top_level_commas(binary_keywords_str)
+    binary_keywords_list: list[str] = []
+    predicates: list[tuple[str, str, str]] = []
+    warnings: list[str] = []
+
+    for raw in entries:
+        kw = raw.strip()
+        if not kw:
+            continue
+        # Always use the full raw keyword as the CLIP text string
+        binary_keywords_list.append(kw)
+
+        lpar = kw.find("(")
+        rpar = kw.rfind(")")
+        if (lpar == -1 and rpar != -1) or (lpar != -1 and rpar == -1) or rpar < lpar:
+            msg = (
+                f"Binary keyword '{kw}' has mismatched parentheses; expected "
+                "relation(from_category, to_category)."
+            )
+            print(msg)
+            warnings.append(msg)
+            continue
+
+        if lpar == -1 or rpar <= lpar:
+            # No explicit (from,to) part; treat as plain relation (no category filter)
+            continue
+
+        inside = kw[lpar + 1 : rpar]
+        parts = inside.split(",")
+        if len(parts) != 2:
+            msg = (
+                f"Ignoring '(from,to)' part in binary keyword '{kw}': "
+                f"expected exactly two comma-separated items."
+            )
+            print(msg)
+            warnings.append(msg)
+            continue
+
+        from_raw = parts[0].strip()
+        to_raw = parts[1].strip()
+        if not from_raw or not to_raw:
+            msg = f"Ignoring binary keyword '{kw}': empty from/to category."
+            print(msg)
+            warnings.append(msg)
+            continue
+
+        canonical_from = cat_map.get(from_raw.lower())
+        canonical_to = cat_map.get(to_raw.lower())
+
+        if canonical_from is None:
+            msg = (
+                f"Binary keyword '{kw}': from-category '{from_raw}' does not "
+                f"match any categorical keyword {categorical_keywords}."
+            )
+            print(msg)
+            warnings.append(msg)
+        if canonical_to is None:
+            msg = (
+                f"Binary keyword '{kw}': to-category '{to_raw}' does not "
+                f"match any categorical keyword {categorical_keywords}."
+            )
+            print(msg)
+            warnings.append(msg)
+
+        if canonical_from is None or canonical_to is None:
+            continue
+
+        # Store (relation_text, from_category, to_category)
+        predicates.append((kw, canonical_from, canonical_to))
+
+    if not predicates:
+        return binary_keywords_list, None, warnings
+
+    return binary_keywords_list, {0: predicates}, warnings
 
 
 @lru_cache(maxsize=1)
 def _load_vine_pipeline():
     """
-    Lazy-load and cache the Vine pipeline so we don't re-download/rebuild it on every request.
+    Lazy-load and cache the LASER (VINE HF) pipeline so we don't re-download/rebuild it on every request.
     """
     from vine_hf import VineConfig, VineModel, VinePipeline
 
@@ -84,6 +227,7 @@ def _load_vine_pipeline():
         debug_visualizations=False,
         device="cuda",
         categorical_pool="max",
+        auto_add_not_unary=False,  # UI will control this per-call
     )
     model = VineModel(config)
     return VinePipeline(
@@ -104,6 +248,7 @@ def process_video(
     categorical_keywords,
     unary_keywords,
     binary_keywords,
+    auto_add_not_unary,
     output_fps,
     box_threshold,
     text_threshold,
@@ -121,34 +266,86 @@ def process_video(
     if not isinstance(video_file, (str, Path)):
         raise ValueError(f"Unsupported video input type: {type(video_file)}")
 
+    video_path = Path(video_file)
+    if video_path.suffix.lower() != ".mp4":
+        msg = (
+            "Please upload an MP4 file. LASER currently supports MP4 inputs for "
+            "scene-graph generation."
+        )
+        print(msg)
+        return None, {"error": msg}
+    video_file = str(video_path)
+
+    # Keep original strings for parsing
+    categorical_keywords_str = categorical_keywords
+    unary_keywords_str = unary_keywords
+    binary_keywords_str = binary_keywords
+
     categorical_keywords = (
-        [kw.strip() for kw in categorical_keywords.split(",")]
-        if categorical_keywords
+        [kw.strip() for kw in categorical_keywords_str.split(",")]
+        if categorical_keywords_str
         else []
     )
     unary_keywords = (
-        [kw.strip() for kw in unary_keywords.split(",")] if unary_keywords else []
-    )
-    binary_keywords = (
-        [kw.strip() for kw in binary_keywords.split(",")] if binary_keywords else []
+        [kw.strip() for kw in unary_keywords_str.split(",")]
+        if unary_keywords_str
+        else []
     )
 
+    # Preprocess: pull categories referenced in binary keywords and add any missing ones
+    added_categories: list[str] = []
+    extra_cats = _extract_categories_from_binary(binary_keywords_str or "")
+    if extra_cats:
+        existing_lower = {kw.lower() for kw in categorical_keywords}
+        for cat in extra_cats:
+            if cat and cat.lower() not in existing_lower:
+                categorical_keywords.append(cat)
+                existing_lower.add(cat.lower())
+                added_categories.append(cat)
+
+    # Parse binary keywords with category info (if provided)
+    (
+        binary_keywords_list,
+        batched_binary_predicates,
+        binary_input_warnings,
+    ) = _parse_binary_keywords(binary_keywords_str or "", categorical_keywords)
+    if added_categories:
+        binary_input_warnings.append(
+            "Auto-added categorical keywords from binary relations: "
+            + ", ".join(added_categories)
+        )
+
+    skip_binary = len(binary_keywords_list) == 0
+
     # Debug: Print what we're sending to the pipeline
     print("\n" + "=" * 80)
-    print("INPUT TO VINE PIPELINE:")
+    print("INPUT TO LASER PIPELINE:")
     print(f" categorical_keywords: {categorical_keywords}")
     print(f" unary_keywords: {unary_keywords}")
-    print(f" binary_keywords: {binary_keywords}")
+    print(f" binary_keywords (raw parsed): {binary_keywords_list}")
+    print(f" batched_binary_predicates: {batched_binary_predicates}")
+    print(f" auto_add_not_unary: {auto_add_not_unary}")
+    print(f" skip_binary: {skip_binary}")
     print("=" * 80 + "\n")
 
     # Object pairs is now optional - empty list will auto-generate all pairs in vine_model.py
-    object_pairs = []
+    object_pairs: list[tuple[int, int]] = []
+
+    extra_forward_kwargs = {}
+    if batched_binary_predicates is not None and not skip_binary:
+        # Use category-based filtering of binary pairs
+        extra_forward_kwargs["batched_binary_predicates"] = batched_binary_predicates
+        extra_forward_kwargs["topk_cate"] = 1  # as requested
+
+    extra_forward_kwargs["auto_add_not_unary"] = bool(auto_add_not_unary)
+    if skip_binary:
+        extra_forward_kwargs["disable_binary"] = True
 
     results = vine_pipe(
         inputs=video_file,
         categorical_keywords=categorical_keywords,
         unary_keywords=unary_keywords,
-        binary_keywords=binary_keywords,
+        binary_keywords=binary_keywords_list,
        object_pairs=object_pairs,
         segmentation_method="grounding_dino_sam2",
         return_top_k=5,
@@ -159,6 +356,7 @@ def process_video(
         text_threshold=text_threshold,
         target_fps=output_fps,
         binary_confidence_threshold=binary_confidence_threshold,
+        **extra_forward_kwargs,
     )
 
     # Debug: Print what the pipeline returned
@@ -193,6 +391,13 @@ def process_video(
     result_video_path = str(candidates[0]) if candidates else None
     summary = results_dict.get("summary") or {}
 
+    # Attach any binary category parsing warnings into the summary JSON
+    if binary_input_warnings:
+        if "binary_input_warnings" in summary:
+            summary["binary_input_warnings"].extend(binary_input_warnings)
+        else:
+            summary["binary_input_warnings"] = binary_input_warnings
+
     if result_video_path and os.path.exists(result_video_path):
         gradio_tmp = (
             Path(os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir()))
@@ -246,7 +451,7 @@ def _create_blocks():
     """
     Build a Blocks context that works across Gradio versions.
     """
-    blocks_kwargs = {"title": "VINE Demo"}
+    blocks_kwargs = {"title": "LASER Scene Graph Demo"}
     soft_theme = None
 
     if hasattr(gr, "themes") and hasattr(gr.themes, "Soft"):
@@ -265,21 +470,23 @@ def _create_blocks():
 with _create_blocks() as demo:
     gr.Markdown(
         """
-        # 🎬 VINE: Video-based Interaction and Event Detection
+        # 🎬 LASER: Spatio-temporal Scene Graphs for Video
 
-        Upload an MP4 video and specify keywords to detect objects, actions, and interactions in your video.
+        Turn any MP4 into a spatio-temporal scene graph with LASER - our 100-million-parameter foundation model for scene-graph generation. LASER trains on 87K+ open-domain videos using a neurosymbolic caption-to-scene alignment pipeline, so it learns fine-grained video semantics without human labels.
+
+        Upload an MP4 and sketch the scene graph you care about: specify the objects, actions, and interactions you want, and LASER will assemble a spatio-temporal scene graph plus an annotated video.
         """
     )
 
     with gr.Row():
         # Left column: Inputs
         with gr.Column(scale=1):
-            gr.Markdown("### Input Configuration")
+            gr.Markdown("### Scene Graph Inputs")
 
             video_input = _video_component("Upload Video (MP4 only)", is_output=False)
             gr.Markdown("*Note: Only MP4 format is currently supported*")
 
-            gr.Markdown("#### Detection Keywords")
+            gr.Markdown("#### Scene Graph Queries")
             categorical_input = gr.Textbox(
                 label="Categorical Keywords",
                 placeholder="e.g., person, car, dog",
@@ -294,8 +501,20 @@ with _create_blocks() as demo:
             )
             binary_input = gr.Textbox(
                 label="Binary Keywords",
-                placeholder="e.g., chasing, carrying",
-                info="Object-to-object interactions to detect (comma-separated)",
+                placeholder="e.g., behind(person, dog), bite(dog, frisbee)",
+                info=(
+                    "Object-to-object interactions to detect. "
+                    "Use format: relation(from_category, to_category). "
+                    "Example: 'behind(person, dog), bite(dog, frisbee)'. "
+                    "If you omit '(from,to)', the relation will be applied to all object pairs (default behavior). "
+                    "Leave blank to skip binary relation search entirely."
+                ),
+            )
+
+            add_not_unary_checkbox = gr.Checkbox(
+                label="Also query 'not <unary>' predicates",
+                value=False,
+                info="If enabled, for each unary keyword X, also query 'not X'.",
             )
 
             gr.Markdown("#### Processing Settings")
@@ -326,7 +545,7 @@ with _create_blocks() as demo:
                 label="Binary Relation Confidence Threshold",
                 minimum=0.0,
                 maximum=1.0,
-                value=0.8,
+                value=0.5,
                 step=0.05,
                 info="Minimum confidence to show binary relations and object pairs",
             )
@@ -335,24 +554,31 @@ with _create_blocks() as demo:
 
        # Right column: Outputs
        with gr.Column(scale=1):
-            gr.Markdown("### Results")
+            gr.Markdown("### Scene Graph Results")
 
            video_output = _video_component("Annotated Video Output", is_output=True)
 
-            gr.Markdown("### Detection Summary")
-            summary_output = gr.JSON(label="Summary of Detected Events")
+            gr.Markdown("### Scene Graph Summary")
+            summary_output = gr.JSON(label="Scene Graph / Detected Events")
 
     gr.Markdown(
         """
         ---
-        ### How to Use
-        1. Upload an MP4 video file
-        2. Specify the objects, actions, and interactions you want to detect
-        3. Adjust processing settings if needed (including binary relation confidence threshold)
-        4. Click "Process Video" to analyze
-
-        The system will automatically detect all binary relations between detected objects
-        and show only those with confidence above the threshold (default: 0.8).
+        ### How to Use LASER
+        1. Upload an MP4 (we validate the format for you).
+        2. Describe the **nodes** of your spatio-temporal scene graph with categorical keywords (objects) and unary keywords (single-object actions).
+        3. Wire up **binary** relations:
+           - Use the structured form `relation(from_category, to_category)` (e.g., `behind(person, dog), bite(dog, frisbee)`) to limit relations to those category pairs.
+           - Or list relation names (`chasing, carrying`) to evaluate all object pairs.
+           - Leave the field blank to skip binary relations entirely (no pair search or binary predicates).
+           - Categories referenced inside binary relations are auto-added to the categorical list for you.
+        4. Optionally enable automatic `'not <unary>'` predicates.
+        5. Adjust processing settings if needed and click **Process Video** to receive an annotated video plus the serialized scene graph.
+
+        More to explore:
+        - LASER paper (ICLR'25): https://arxiv.org/abs/2304.07647 | Demo: https://huggingface.co/spaces/jiani-huang/LASER | Code: https://github.com/video-fm/LASER
+        - ESCA paper: https://arxiv.org/abs/2510.15963 | Code: https://github.com/video-fm/ESCA | Model: https://huggingface.co/video-fm/vine_v0 | Dataset: https://huggingface.co/datasets/video-fm/ESCA-video-87K
+        - Meet us at **NeurIPS 2025** (San Diego, Exhibit Hall C/D/E, Booth #4908 - Wed, Dec 3 - 11:00 a.m.-2:00 p.m. PST) for the foundation model demo, code, and full paper.
         """
     )
 
@@ -363,6 +589,7 @@ with _create_blocks() as demo:
         categorical_input,
         unary_input,
        binary_input,
+        add_not_unary_checkbox,
         fps_input,
         box_threshold_input,
         text_threshold_input,
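
For reference, a minimal sketch (not part of the commit) of how the new keyword-parsing helpers behave; the helper names come from the diff above, and the expected values follow directly from their implementations.

# Sketch: exercising app.py's binary-keyword parsing helpers.
from app import _split_top_level_commas, _parse_binary_keywords  # assumes app.py is importable

cats = ["person", "dog", "frisbee"]

# Commas inside parentheses do not split at the top level:
assert _split_top_level_commas("behind(person, dog), bite(dog, frisbee)") == [
    "behind(person, dog)",
    "bite(dog, frisbee)",
]

# Structured entries yield (relation_text, from_category, to_category) triples
# keyed by video id 0; every raw entry is also kept as CLIP text:
kws, preds, warns = _parse_binary_keywords("behind(person, dog), chasing", cats)
assert kws == ["behind(person, dog)", "chasing"]
assert preds == {0: [("behind(person, dog)", "person", "dog")]}
assert warns == []

# A category that matches no categorical keyword produces a warning and drops
# the predicate, but the raw keyword is still used as plain CLIP text:
kws, preds, warns = _parse_binary_keywords("bite(dog, ball)", cats)
assert kws == ["bite(dog, ball)"] and preds is None and len(warns) == 1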
vine_hf/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/__init__.cpython-310.pyc and b/vine_hf/__pycache__/__init__.cpython-310.pyc differ
 
vine_hf/__pycache__/vine_config.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_config.cpython-310.pyc and b/vine_hf/__pycache__/vine_config.cpython-310.pyc differ
 
vine_hf/__pycache__/vine_model.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_model.cpython-310.pyc and b/vine_hf/__pycache__/vine_model.cpython-310.pyc differ
 
vine_hf/vine_config.py CHANGED
@@ -41,6 +41,7 @@ class VineConfig(PretrainedConfig):
         interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
         debug_visualizations: bool = False,
         device: Optional[Union[str, int]] = None,
+        auto_add_not_unary: bool = False,
         **kwargs: Any,
     ):
         self.model_name = model_name
@@ -77,6 +78,7 @@ class VineConfig(PretrainedConfig):
         self.return_valid_pairs = return_valid_pairs
         self.interested_object_pairs = interested_object_pairs or []
         self.debug_visualizations = debug_visualizations
+        self.auto_add_not_unary = auto_add_not_unary
 
         if isinstance(device, int):
             self._device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
vine_hf/vine_model.py CHANGED
@@ -326,6 +326,7 @@ class VineModel(PreTrainedModel):
         debug_visualizations: Optional[bool] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
+        disable_binary = kwargs.pop("disable_binary", False)
         if unary_keywords is None:
             unary_keywords = []
         if binary_keywords is None:
@@ -353,6 +354,8 @@
         multi_class = kwargs.pop("multi_class", getattr(self.config, "multi_class", False))
         output_logit = kwargs.pop("output_logit", getattr(self.config, "output_logit", False))
         output_embeddings = kwargs.pop("output_embeddings", False)
+        batched_binary_predicates_arg = kwargs.pop("batched_binary_predicates", None)
+        skip_binary = disable_binary or len(binary_keywords) == 0
 
         batched_video_ids = [0]
 
@@ -385,12 +388,12 @@
 
         batched_names = [list(categorical_keywords)]
         batched_unary_kws = [list(unary_keywords)]
-        batched_binary_kws = [list(binary_keywords)]
+        batched_binary_kws = [list(binary_keywords)] if not skip_binary else [[]]
 
         batched_obj_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
 
         # Auto-generate all object pairs if binary_keywords provided but object_pairs is empty
-        if not object_pairs and binary_keywords:
+        if not object_pairs and binary_keywords and not skip_binary:
             # Get all unique object IDs across all frames
             all_object_ids = set()
             for frame_masks in masks.values():
@@ -404,7 +407,10 @@
                     if from_oid != to_oid:
                         object_pairs.append((from_oid, to_oid))
 
-            print(f"Auto-generated {len(object_pairs)} bidirectional object pairs for binary relation detection: {object_pairs}")
+            print(
+                f"Auto-generated {len(object_pairs)} bidirectional object "
+                f"pairs for binary relation detection: {object_pairs}"
+            )
 
         if object_pairs:
             for frame_id, frame_masks in masks.items():
@@ -416,12 +422,34 @@
                         batched_obj_pairs.append((0, frame_id, (from_oid, to_oid)))
 
         batched_video_splits = [0]
-        batched_binary_predicates = [None]
 
-        def fill_empty(batched_kw):
+        # Prepare binary predicates per video (single-video setup)
+        if batched_binary_predicates_arg is None:
+            batched_binary_predicates = [None]
+        elif skip_binary:
+            batched_binary_predicates = [None]
+        else:
+            if isinstance(batched_binary_predicates_arg, dict):
+                preds_for_vid0 = batched_binary_predicates_arg.get(0, [])
+                if preds_for_vid0:
+                    batched_binary_predicates = [preds_for_vid0]
+                else:
+                    batched_binary_predicates = [None]
+            else:
+                if isinstance(batched_binary_predicates_arg, (list, tuple)) and len(
+                    batched_binary_predicates_arg
+                ) > 0:
+                    batched_binary_predicates = [list(batched_binary_predicates_arg)]
+                else:
+                    batched_binary_predicates = [None]
+
+        def fill_empty(batched_kw, *, allow_dummy: bool = True):
             new_batched = []
             for kw_ls in batched_kw:
                 if len(kw_ls) == 0:
+                    if not allow_dummy:
+                        new_batched.append([])
+                        continue
                     new_batched.append([dummy_str])
                 else:
                     new_batched.append(list(kw_ls))
@@ -429,7 +457,7 @@
 
         batched_names = fill_empty(batched_names)
         batched_unary_kws = fill_empty(batched_unary_kws)
-        batched_binary_kws = fill_empty(batched_binary_kws)
+        batched_binary_kws = fill_empty(batched_binary_kws, allow_dummy=not skip_binary)
 
         dummy_prob = torch.tensor(0.0, device=self._device)
 
@@ -673,6 +701,31 @@
             batched_obj_per_cate[vid_id] = obj_per_cate
 
         # Step 4: binary pairs
+        if skip_binary:
+            batched_image_binary_probs = [{} for _ in range(batch_size)]
+            batched_obj_pair_features: Dict[int, torch.Tensor] = {
+                vid: torch.tensor([]) for vid in range(batch_size)
+            }
+            result: Dict[str, Any] = {
+                "categorical_probs": batched_image_cate_probs,
+                "unary_probs": batched_image_unary_probs,
+                "binary_probs": batched_image_binary_probs,
+                "dummy_prob": dummy_prob,
+            }
+
+            if output_embeddings:
+                embeddings_dict = {
+                    "cate_obj_clip_features": batched_obj_cate_features,
+                    "cate_obj_name_features": batched_obj_name_features,
+                    "unary_obj_features": batched_obj_unary_features,
+                    "unary_nl_features": batched_unary_nl_features,
+                    "binary_obj_pair_features": batched_obj_pair_features,
+                    "binary_nl_features": batched_binary_nl_features,
+                }
+                result.update(embeddings_dict)
+
+            return result
+
         batched_cropped_obj_pairs: Dict[int, List[np.ndarray]] = {}
         frame_splits: Dict[Tuple[int, int], Dict[str, int]] = {}
         current_info = (0, 0)
@@ -701,6 +754,8 @@
                 selected_pairs = set(batched_obj_pairs)
             else:
                 for bp_vid, binary_predicates in enumerate(batched_binary_predicates):
+                    if binary_predicates is None:
+                        continue
                     topk_cate_candidates = batched_topk_cate_candidates[bp_vid]
                     for (rel_name, from_obj_name, to_obj_name) in binary_predicates:
                         if (
@@ -925,6 +980,21 @@
         inputs = self.clip_processor(images=image, return_tensors="pt").to(self._device)
         return self._image_features_checkpoint(model, inputs["pixel_values"])
 
+    def _augment_unary_with_negation(self, unary_keywords: List[str]) -> List[str]:
+        """
+        Given unary predicates like ["running", "walking"], add "not running",
+        "not walking" if they are not already present (case-insensitive).
+        """
+        base = [kw for kw in unary_keywords if isinstance(kw, str) and kw.strip()]
+        seen_lower = {kw.lower() for kw in base}
+        augmented = list(base)
+        for kw in base:
+            neg = f"not {kw}"
+            if neg.lower() not in seen_lower:
+                augmented.append(neg)
+                seen_lower.add(neg.lower())
+        return augmented
+
     # ------------------------------------------------------------------ #
     # High-level predict API
     # ------------------------------------------------------------------ #
@@ -942,7 +1012,35 @@
         return_valid_pairs: Optional[bool] = None,
         interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
         debug_visualizations: Optional[bool] = None,
+        auto_add_not_unary: Optional[bool] = None,
+        batched_binary_predicates: Optional[Dict[int, List[Tuple[str, str, str]]]] = None,
+        topk_cate: Optional[int] = None,
     ) -> Dict[str, Any]:
+        if unary_keywords is None:
+            unary_keywords = []
+        else:
+            unary_keywords = list(unary_keywords)
+
+        if binary_keywords is None:
+            binary_keywords = []
+        else:
+            binary_keywords = list(binary_keywords)
+
+        if object_pairs is None:
+            object_pairs = []
+
+        if auto_add_not_unary is None:
+            auto_add_not_unary = getattr(self.config, "auto_add_not_unary", False)
+
+        if auto_add_not_unary:
+            unary_keywords = self._augment_unary_with_negation(unary_keywords)
+
+        forward_extra_kwargs: Dict[str, Any] = {}
+        if batched_binary_predicates is not None:
+            forward_extra_kwargs["batched_binary_predicates"] = batched_binary_predicates
+        if topk_cate is not None:
+            forward_extra_kwargs["topk_cate"] = topk_cate
+
         with torch.no_grad():
             outputs = self.forward(
                 video_frames=video_frames,
@@ -956,6 +1054,7 @@
                 return_valid_pairs=return_valid_pairs,
                 interested_object_pairs=interested_object_pairs,
                 debug_visualizations=debug_visualizations,
+                **forward_extra_kwargs,
             )
 
         formatted_categorical: Dict[int, List[Tuple[float, str]]] = {}
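
Putting the model-side pieces together, a rough sketch of the high-level predict API after this change. Only auto_add_not_unary, batched_binary_predicates, and topk_cate are taken from this diff; the remaining argument names are assumed from the forward() call, and video_frames/masks stand in for real segmentation outputs.

# Sketch: the new predict() kwargs. With auto_add_not_unary=True,
# _augment_unary_with_negation also queries "not running".
outputs = model.predict(
    video_frames=video_frames,  # placeholder: decoded frames
    masks=masks,                # placeholder: per-frame segmentation masks
    categorical_keywords=["person", "dog"],
    unary_keywords=["running"],
    binary_keywords=["behind(person, dog)"],
    auto_add_not_unary=True,
    batched_binary_predicates={0: [("behind(person, dog)", "person", "dog")]},
    topk_cate=1,  # restrict category candidates when filtering object pairs
)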
vine_hf/vine_pipeline.py CHANGED
@@ -107,6 +107,14 @@ class VinePipeline(Pipeline):
             forward_kwargs["binary_keywords"] = kwargs["binary_keywords"]
         if "object_pairs" in kwargs:
             forward_kwargs["object_pairs"] = kwargs["object_pairs"]
+        if "batched_binary_predicates" in kwargs:
+            # New: per-video (rel, from_cat, to_cat) triples for binary filtering
+            forward_kwargs["batched_binary_predicates"] = kwargs["batched_binary_predicates"]
+        if "topk_cate" in kwargs:
+            # New: override topk_cate when binary filtering is requested
+            forward_kwargs["topk_cate"] = kwargs["topk_cate"]
+        if "auto_add_not_unary" in kwargs:
+            forward_kwargs["auto_add_not_unary"] = kwargs["auto_add_not_unary"]
         if "return_flattened_segments" in kwargs:
             forward_kwargs["return_flattened_segments"] = kwargs[
                 "return_flattened_segments"
@@ -126,7 +134,9 @@ class VinePipeline(Pipeline):
         if "self.visualize" in kwargs:
             postprocess_kwargs["self.visualize"] = kwargs["self.visualize"]
         if "binary_confidence_threshold" in kwargs:
-            postprocess_kwargs["binary_confidence_threshold"] = kwargs["binary_confidence_threshold"]
+            postprocess_kwargs["binary_confidence_threshold"] = kwargs[
+                "binary_confidence_threshold"
+            ]
 
         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
 
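
End to end, _sanitize_parameters now routes the three new kwargs to forward, so the Gradio handler in app.py can drive the pipeline roughly like this (a sketch; vine_pipe is the cached VinePipeline and the video path and keyword values are illustrative):

results = vine_pipe(
    inputs="example.mp4",
    categorical_keywords=["person", "dog"],
    unary_keywords=["running"],
    binary_keywords=["behind(person, dog)"],
    object_pairs=[],  # empty list -> vine_model.py auto-generates all pairs
    batched_binary_predicates={0: [("behind(person, dog)", "person", "dog")]},
    topk_cate=1,
    auto_add_not_unary=False,
    binary_confidence_threshold=0.5,  # routed to postprocess; matches the new UI default
)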