Spaces: Running on Zero

Commit · d3c563b
Parent: 21f4849
"copying in saved code"

Files changed:
- app.py (+257 -30)
- vine_hf/__pycache__/__init__.cpython-310.pyc (+0 -0, binary)
- vine_hf/__pycache__/vine_config.cpython-310.pyc (+0 -0, binary)
- vine_hf/__pycache__/vine_model.cpython-310.pyc (+0 -0, binary)
- vine_hf/vine_config.py (+2 -0)
- vine_hf/vine_model.py (+105 -6)
- vine_hf/vine_pipeline.py (+11 -1)
app.py
CHANGED
@@ -60,12 +60,155 @@ print(
 )
 
 
+def _split_top_level_commas(s: str):
+    """
+    Split a string on commas that are NOT inside parentheses.
+
+    Example:
+        "behind(person, dog), bite(dog, frisbee)"
+        -> ["behind(person, dog)", "bite(dog, frisbee)"]
+    """
+    parts = []
+    buf = []
+    depth = 0
+    for ch in s:
+        if ch == "(":
+            depth += 1
+            buf.append(ch)
+        elif ch == ")":
+            if depth > 0:
+                depth -= 1
+            buf.append(ch)
+        elif ch == "," and depth == 0:
+            part = "".join(buf).strip()
+            if part:
+                parts.append(part)
+            buf = []
+        else:
+            buf.append(ch)
+    if buf:
+        part = "".join(buf).strip()
+        if part:
+            parts.append(part)
+    return parts
+
+
+def _extract_categories_from_binary(binary_keywords_str: str) -> list[str]:
+    """
+    Pull candidate category tokens from binary keyword strings, e.g. relation(a, b).
+    Only returns tokens when parentheses and two comma-separated entries exist.
+    """
+    categories: list[str] = []
+    for kw in _split_top_level_commas(binary_keywords_str or ""):
+        lpar = kw.find("(")
+        rpar = kw.rfind(")")
+        if lpar == -1 or rpar <= lpar:
+            continue
+        inside = kw[lpar + 1 : rpar]
+        parts = [p.strip() for p in inside.split(",") if p.strip()]
+        if len(parts) == 2:
+            categories.extend(parts)
+    return categories
+
+
+def _parse_binary_keywords(binary_keywords_str: str, categorical_keywords: list[str]):
+    """
+    Parse a binary keyword string like:
+        "behind(person, dog), bite(dog, frisbee)"
+    into:
+      - binary_keywords_list: list of raw strings (used as CLIP text)
+      - batched_binary_predicates: {0: [(rel_text, from_cat, to_cat), ...]} or None
+      - warnings: list of warning strings about invalid/mismatched categories
+    """
+    if not binary_keywords_str:
+        return [], None, []
+
+    cat_map = {
+        kw.strip().lower(): kw.strip()
+        for kw in categorical_keywords
+        if isinstance(kw, str) and kw.strip()
+    }
+
+    entries = _split_top_level_commas(binary_keywords_str)
+    binary_keywords_list: list[str] = []
+    predicates: list[tuple[str, str, str]] = []
+    warnings: list[str] = []
+
+    for raw in entries:
+        kw = raw.strip()
+        if not kw:
+            continue
+        # Always use the full raw keyword as the CLIP text string
+        binary_keywords_list.append(kw)
+
+        lpar = kw.find("(")
+        rpar = kw.rfind(")")
+        if (lpar == -1 and rpar != -1) or (lpar != -1 and rpar == -1) or rpar < lpar:
+            msg = (
+                f"Binary keyword '{kw}' has mismatched parentheses; expected "
+                "relation(from_category, to_category)."
+            )
+            print(msg)
+            warnings.append(msg)
+            continue
+
+        if lpar == -1 or rpar <= lpar:
+            # No explicit (from,to) part; treat as plain relation (no category filter)
+            continue
+
+        inside = kw[lpar + 1 : rpar]
+        parts = inside.split(",")
+        if len(parts) != 2:
+            msg = (
+                f"Ignoring '(from,to)' part in binary keyword '{kw}': "
+                f"expected exactly two comma-separated items."
+            )
+            print(msg)
+            warnings.append(msg)
+            continue
+
+        from_raw = parts[0].strip()
+        to_raw = parts[1].strip()
+        if not from_raw or not to_raw:
+            msg = f"Ignoring binary keyword '{kw}': empty from/to category."
+            print(msg)
+            warnings.append(msg)
+            continue
+
+        canonical_from = cat_map.get(from_raw.lower())
+        canonical_to = cat_map.get(to_raw.lower())
+
+        if canonical_from is None:
+            msg = (
+                f"Binary keyword '{kw}': from-category '{from_raw}' does not "
+                f"match any categorical keyword {categorical_keywords}."
+            )
+            print(msg)
+            warnings.append(msg)
+        if canonical_to is None:
+            msg = (
+                f"Binary keyword '{kw}': to-category '{to_raw}' does not "
+                f"match any categorical keyword {categorical_keywords}."
+            )
+            print(msg)
+            warnings.append(msg)
+
+        if canonical_from is None or canonical_to is None:
+            continue
+
+        # Store (relation_text, from_category, to_category)
+        predicates.append((kw, canonical_from, canonical_to))
+
+    if not predicates:
+        return binary_keywords_list, None, warnings
+
+    return binary_keywords_list, {0: predicates}, warnings
+
+
 @lru_cache(maxsize=1)
 def _load_vine_pipeline():
     """
-    Lazy-load and cache the
+    Lazy-load and cache the LASER (VINE HF) pipeline so we don't re-download/rebuild it on every request.
     """
     from vine_hf import VineConfig, VineModel, VinePipeline
 
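For a sense of how the new helpers behave, here is a small illustrative sketch (inputs invented; the expected values follow directly from the function bodies above):

```python
# Illustrative only - not part of the commit.
raw = "behind(person, dog), bite(dog, frisbee), chasing"

_split_top_level_commas(raw)
# -> ["behind(person, dog)", "bite(dog, frisbee)", "chasing"]
#    (commas inside parentheses are not split points)

_extract_categories_from_binary(raw)
# -> ["person", "dog", "dog", "frisbee"]  ("chasing" has no (from,to) part)

kws, preds, warns = _parse_binary_keywords(raw, ["person", "dog", "frisbee"])
# kws   -> all three raw strings, used verbatim as CLIP text
# preds -> {0: [("behind(person, dog)", "person", "dog"),
#               ("bite(dog, frisbee)", "dog", "frisbee")]}
# warns -> []  ("chasing" is kept as a plain relation with no category filter)
```
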
@@ -84,6 +227,7 @@ def _load_vine_pipeline():
         debug_visualizations=False,
         device="cuda",
         categorical_pool="max",
+        auto_add_not_unary=False,  # UI will control this per-call
     )
     model = VineModel(config)
     return VinePipeline(
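A quick note on the caching here (sketch, not part of the commit): because the loader is wrapped in `@lru_cache(maxsize=1)`, the heavy model build runs once per process.

```python
# Standard functools.lru_cache semantics: repeated calls return the same object.
pipe_a = _load_vine_pipeline()
pipe_b = _load_vine_pipeline()
assert pipe_a is pipe_b  # built once, then served from the cache
```
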
@@ -104,6 +248,7 @@ def process_video(
     categorical_keywords,
     unary_keywords,
     binary_keywords,
+    auto_add_not_unary,
     output_fps,
     box_threshold,
     text_threshold,
@@ -121,34 +266,86 @@ def process_video(
     if not isinstance(video_file, (str, Path)):
         raise ValueError(f"Unsupported video input type: {type(video_file)}")
 
+    video_path = Path(video_file)
+    if video_path.suffix.lower() != ".mp4":
+        msg = (
+            "Please upload an MP4 file. LASER currently supports MP4 inputs for "
+            "scene-graph generation."
+        )
+        print(msg)
+        return None, {"error": msg}
+    video_file = str(video_path)
+
+    # Keep original strings for parsing
+    categorical_keywords_str = categorical_keywords
+    unary_keywords_str = unary_keywords
+    binary_keywords_str = binary_keywords
+
     categorical_keywords = (
-        [kw.strip() for kw in
-        if
+        [kw.strip() for kw in categorical_keywords_str.split(",")]
+        if categorical_keywords_str
         else []
     )
     unary_keywords = (
-        [kw.strip() for kw in
-
-
-        [kw.strip() for kw in binary_keywords.split(",")] if binary_keywords else []
+        [kw.strip() for kw in unary_keywords_str.split(",")]
+        if unary_keywords_str
+        else []
     )
 
+    # Preprocess: pull categories referenced in binary keywords and add any missing ones
+    added_categories: list[str] = []
+    extra_cats = _extract_categories_from_binary(binary_keywords_str or "")
+    if extra_cats:
+        existing_lower = {kw.lower() for kw in categorical_keywords}
+        for cat in extra_cats:
+            if cat and cat.lower() not in existing_lower:
+                categorical_keywords.append(cat)
+                existing_lower.add(cat.lower())
+                added_categories.append(cat)
+
+    # Parse binary keywords with category info (if provided)
+    (
+        binary_keywords_list,
+        batched_binary_predicates,
+        binary_input_warnings,
+    ) = _parse_binary_keywords(binary_keywords_str or "", categorical_keywords)
+    if added_categories:
+        binary_input_warnings.append(
+            "Auto-added categorical keywords from binary relations: "
+            + ", ".join(added_categories)
+        )
+
+    skip_binary = len(binary_keywords_list) == 0
+
     # Debug: Print what we're sending to the pipeline
     print("\n" + "=" * 80)
-    print("INPUT TO
+    print("INPUT TO LASER PIPELINE:")
     print(f" categorical_keywords: {categorical_keywords}")
     print(f" unary_keywords: {unary_keywords}")
-    print(f" binary_keywords: {
+    print(f" binary_keywords (raw parsed): {binary_keywords_list}")
+    print(f" batched_binary_predicates: {batched_binary_predicates}")
+    print(f" auto_add_not_unary: {auto_add_not_unary}")
+    print(f" skip_binary: {skip_binary}")
     print("=" * 80 + "\n")
 
     # Object pairs is now optional - empty list will auto-generate all pairs in vine_model.py
-    object_pairs = []
+    object_pairs: list[tuple[int, int]] = []
+
+    extra_forward_kwargs = {}
+    if batched_binary_predicates is not None and not skip_binary:
+        # Use category-based filtering of binary pairs
+        extra_forward_kwargs["batched_binary_predicates"] = batched_binary_predicates
+        extra_forward_kwargs["topk_cate"] = 1  # as requested
+
+    extra_forward_kwargs["auto_add_not_unary"] = bool(auto_add_not_unary)
+    if skip_binary:
+        extra_forward_kwargs["disable_binary"] = True
 
     results = vine_pipe(
         inputs=video_file,
         categorical_keywords=categorical_keywords,
         unary_keywords=unary_keywords,
-        binary_keywords=
+        binary_keywords=binary_keywords_list,
         object_pairs=object_pairs,
         segmentation_method="grounding_dino_sam2",
         return_top_k=5,
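To make the new wiring concrete, a sketch (not part of the commit) of what `extra_forward_kwargs` should contain for two hypothetical inputs, per the logic above:

```python
# Input: binary_keywords = "behind(person, dog)" with categories ["person", "dog"]
expected_structured = {
    "batched_binary_predicates": {0: [("behind(person, dog)", "person", "dog")]},
    "topk_cate": 1,
    "auto_add_not_unary": False,
}

# Input: binary_keywords = "" (blank field) -> skip_binary is True,
# so binary relation search is disabled entirely.
expected_blank = {
    "auto_add_not_unary": False,
    "disable_binary": True,
}
```
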
@@ -159,6 +356,7 @@ def process_video(
         text_threshold=text_threshold,
         target_fps=output_fps,
         binary_confidence_threshold=binary_confidence_threshold,
+        **extra_forward_kwargs,
     )
 
     # Debug: Print what the pipeline returned
@@ -193,6 +391,13 @@ def process_video(
     result_video_path = str(candidates[0]) if candidates else None
     summary = results_dict.get("summary") or {}
 
+    # Attach any binary category parsing warnings into the summary JSON
+    if binary_input_warnings:
+        if "binary_input_warnings" in summary:
+            summary["binary_input_warnings"].extend(binary_input_warnings)
+        else:
+            summary["binary_input_warnings"] = binary_input_warnings
+
     if result_video_path and os.path.exists(result_video_path):
         gradio_tmp = (
             Path(os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir()))
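An illustrative shape (not part of the commit) of the summary returned to the UI after warnings are attached; the warning string mirrors the format produced by `_parse_binary_keywords`:

```python
summary = {
    # ...fields produced by the pipeline...
    "binary_input_warnings": [
        "Binary keyword 'behind(person' has mismatched parentheses; expected "
        "relation(from_category, to_category).",
    ],
}
```
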
@@ -246,7 +451,7 @@ def _create_blocks():
     """
     Build a Blocks context that works across Gradio versions.
     """
-    blocks_kwargs = {"title": "
+    blocks_kwargs = {"title": "LASER Scene Graph Demo"}
    soft_theme = None
 
     if hasattr(gr, "themes") and hasattr(gr.themes, "Soft"):
@@ -265,21 +470,23 @@ def _create_blocks():
 with _create_blocks() as demo:
     gr.Markdown(
         """
-        # 🎬
+        # 🎬 LASER: Spatio-temporal Scene Graphs for Video
 
-
+        Turn any MP4 into a spatio-temporal scene graph with LASER - our 100-million parameter foundation model for scene-graph generation. LASER trains on 87K+ open-domain videos using a neurosymbolic caption-to-scene alignment pipeline, so it learns fine-grained video semantics without human labels.
+
+        Upload an MP4 and sketch the scene graph you care about: specify the objects, actions, and interactions you want, and LASER will assemble a spatio-temporal scene graph plus an annotated video.
         """
     )
 
     with gr.Row():
         # Left column: Inputs
         with gr.Column(scale=1):
-            gr.Markdown("###
+            gr.Markdown("### Scene Graph Inputs")
 
             video_input = _video_component("Upload Video (MP4 only)", is_output=False)
             gr.Markdown("*Note: Only MP4 format is currently supported*")
 
-            gr.Markdown("####
+            gr.Markdown("#### Scene Graph Queries")
             categorical_input = gr.Textbox(
                 label="Categorical Keywords",
                 placeholder="e.g., person, car, dog",
@@ -294,8 +501,20 @@ with _create_blocks() as demo:
             )
             binary_input = gr.Textbox(
                 label="Binary Keywords",
-                placeholder="e.g.,
-                info=
+                placeholder="e.g., behind(person, dog), bite(dog, frisbee)",
+                info=(
+                    "Object-to-object interactions to detect. "
+                    "Use format: relation(from_category, to_category). "
+                    "Example: 'behind(person, dog), bite(dog, frisbee)'. "
+                    "If you omit '(from,to)', the relation will be applied to all object pairs (default behavior). "
+                    "Leave blank to skip binary relation search entirely."
+                ),
+            )
+
+            add_not_unary_checkbox = gr.Checkbox(
+                label="Also query 'not <unary>' predicates",
+                value=False,
+                info="If enabled, for each unary keyword X, also query 'not X'.",
             )
 
             gr.Markdown("#### Processing Settings")
@@ -326,7 +545,7 @@ with _create_blocks() as demo:
                 label="Binary Relation Confidence Threshold",
                 minimum=0.0,
                 maximum=1.0,
-                value
+                value=.5,
                 step=0.05,
                 info="Minimum confidence to show binary relations and object pairs",
             )
@@ -335,24 +554,31 @@ with _create_blocks() as demo:
 
     # Right column: Outputs
     with gr.Column(scale=1):
-        gr.Markdown("### Results")
+        gr.Markdown("### Scene Graph Results")
 
         video_output = _video_component("Annotated Video Output", is_output=True)
 
-        gr.Markdown("###
-        summary_output = gr.JSON(label="
+        gr.Markdown("### Scene Graph Summary")
+        summary_output = gr.JSON(label="Scene Graph / Detected Events")
 
     gr.Markdown(
         """
         ---
-        ### How to Use
-        1. Upload an MP4
-        2.
-        3.
-
-
-
-
+        ### How to Use LASER
+        1. Upload an MP4 (we validate the format for you).
+        2. Describe the **nodes** of your spatio-temporal scene graph with categorical keywords (objects) and unary keywords (single-object actions).
+        3. Wire up **binary** relations:
+           - Use the structured form `relation(from_category, to_category)` (e.g., `behind(person, dog), bite(dog, frisbee)`) to limit relations to those category pairs.
+           - Or list relation names (`chasing, carrying`) to evaluate all object pairs.
+           - Leave the field blank to skip binary relations entirely (no pair search or binary predicates).
+           - Categories referenced inside binary relations are auto-added to the categorical list for you.
+        4. Optionally enable automatic `'not <unary>'` predicates.
+        5. Adjust processing settings if needed and click **Process Video** to receive an annotated video plus the serialized scene graph.
+
+        More to explore:
+        - LASER paper (ICLR'25): https://arxiv.org/abs/2304.07647 | Demo: https://huggingface.co/spaces/jiani-huang/LASER | Code: https://github.com/video-fm/LASER
+        - ESCA paper: https://arxiv.org/abs/2510.15963 | Code: https://github.com/video-fm/ESCA | Model: https://huggingface.co/video-fm/vine_v0 | Dataset: https://huggingface.co/datasets/video-fm/ESCA-video-87K
+        - Meet us at **NeurIPS 2025** (San Diego, Exhibit Hall C/D/E, Booth #4908 - Wed, Dec 3 - 11:00 a.m.-2:00 p.m. PST) for the foundation model demo, code, and full paper.
         """
     )
 
@@ -363,6 +589,7 @@ with _create_blocks() as demo:
             categorical_input,
             unary_input,
             binary_input,
+            add_not_unary_checkbox,
             fps_input,
             box_threshold_input,
             text_threshold_input,

vine_hf/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/vine_hf/__pycache__/__init__.cpython-310.pyc and b/vine_hf/__pycache__/__init__.cpython-310.pyc differ

vine_hf/__pycache__/vine_config.cpython-310.pyc
CHANGED
Binary files a/vine_hf/__pycache__/vine_config.cpython-310.pyc and b/vine_hf/__pycache__/vine_config.cpython-310.pyc differ

vine_hf/__pycache__/vine_model.cpython-310.pyc
CHANGED
Binary files a/vine_hf/__pycache__/vine_model.cpython-310.pyc and b/vine_hf/__pycache__/vine_model.cpython-310.pyc differ

vine_hf/vine_config.py
CHANGED
@@ -41,6 +41,7 @@ class VineConfig(PretrainedConfig):
         interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
         debug_visualizations: bool = False,
         device: Optional[Union[str, int]] = None,
+        auto_add_not_unary: bool = False,
         **kwargs: Any,
     ):
         self.model_name = model_name
@@ -77,6 +78,7 @@ class VineConfig(PretrainedConfig):
         self.return_valid_pairs = return_valid_pairs
         self.interested_object_pairs = interested_object_pairs or []
         self.debug_visualizations = debug_visualizations
+        self.auto_add_not_unary = auto_add_not_unary
 
         if isinstance(device, int):
             self._device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"

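A minimal sketch of the new flag in use (illustrative; the remaining constructor arguments are assumed to keep their defaults):

```python
from vine_hf import VineConfig, VineModel

# auto_add_not_unary defaults to False; it is stored on the config so that
# VineModel.predict can fall back to it when the caller passes None.
config = VineConfig(device="cuda", auto_add_not_unary=True)
assert config.auto_add_not_unary is True
model = VineModel(config)
```
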
vine_hf/vine_model.py
CHANGED
@@ -326,6 +326,7 @@ class VineModel(PreTrainedModel):
         debug_visualizations: Optional[bool] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
+        disable_binary = kwargs.pop("disable_binary", False)
         if unary_keywords is None:
             unary_keywords = []
         if binary_keywords is None:
@@ -353,6 +354,8 @@ class VineModel(PreTrainedModel):
         multi_class = kwargs.pop("multi_class", getattr(self.config, "multi_class", False))
         output_logit = kwargs.pop("output_logit", getattr(self.config, "output_logit", False))
         output_embeddings = kwargs.pop("output_embeddings", False)
+        batched_binary_predicates_arg = kwargs.pop("batched_binary_predicates", None)
+        skip_binary = disable_binary or len(binary_keywords) == 0
 
         batched_video_ids = [0]
 
@@ -385,12 +388,12 @@ class VineModel(PreTrainedModel):
 
         batched_names = [list(categorical_keywords)]
         batched_unary_kws = [list(unary_keywords)]
-        batched_binary_kws = [list(binary_keywords)]
+        batched_binary_kws = [list(binary_keywords)] if not skip_binary else [[]]
 
         batched_obj_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
 
         # Auto-generate all object pairs if binary_keywords provided but object_pairs is empty
-        if not object_pairs and binary_keywords:
+        if not object_pairs and binary_keywords and not skip_binary:
             # Get all unique object IDs across all frames
             all_object_ids = set()
             for frame_masks in masks.values():
@@ -404,7 +407,10 @@ class VineModel(PreTrainedModel):
                 if from_oid != to_oid:
                     object_pairs.append((from_oid, to_oid))
 
-            print(
+            print(
+                f"Auto-generated {len(object_pairs)} bidirectional object "
+                f"pairs for binary relation detection: {object_pairs}"
+            )
 
         if object_pairs:
             for frame_id, frame_masks in masks.items():
@@ -416,12 +422,34 @@ class VineModel(PreTrainedModel):
                     batched_obj_pairs.append((0, frame_id, (from_oid, to_oid)))
 
         batched_video_splits = [0]
-        batched_binary_predicates = [None]
 
+        # Prepare binary predicates per video (single-video setup)
+        if batched_binary_predicates_arg is None:
+            batched_binary_predicates = [None]
+        elif skip_binary:
+            batched_binary_predicates = [None]
+        else:
+            if isinstance(batched_binary_predicates_arg, dict):
+                preds_for_vid0 = batched_binary_predicates_arg.get(0, [])
+                if preds_for_vid0:
+                    batched_binary_predicates = [preds_for_vid0]
+                else:
+                    batched_binary_predicates = [None]
+            else:
+                if isinstance(batched_binary_predicates_arg, (list, tuple)) and len(
+                    batched_binary_predicates_arg
+                ) > 0:
+                    batched_binary_predicates = [list(batched_binary_predicates_arg)]
+                else:
+                    batched_binary_predicates = [None]
+
-        def fill_empty(batched_kw):
+        def fill_empty(batched_kw, *, allow_dummy: bool = True):
             new_batched = []
             for kw_ls in batched_kw:
                 if len(kw_ls) == 0:
+                    if not allow_dummy:
+                        new_batched.append([])
+                        continue
                     new_batched.append([dummy_str])
                 else:
                     new_batched.append(list(kw_ls))
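For reference, a sketch (not part of the commit) of how the normalization above maps the `batched_binary_predicates` argument in the single-video setup:

```python
triple = ("behind(person, dog)", "person", "dog")

# None                  -> [None]     (no category filtering)
# {} or {0: []}         -> [None]     (dict without predicates for video 0)
# {0: [triple]}         -> [[triple]] (dict form, predicates kept)
# [triple] or (triple,) -> [[triple]] (bare list/tuple wrapped for video 0)
# Anything with skip_binary=True -> [None]
```
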
@@ -429,7 +457,7 @@ class VineModel(PreTrainedModel):
 
         batched_names = fill_empty(batched_names)
         batched_unary_kws = fill_empty(batched_unary_kws)
-        batched_binary_kws = fill_empty(batched_binary_kws)
+        batched_binary_kws = fill_empty(batched_binary_kws, allow_dummy=not skip_binary)
 
         dummy_prob = torch.tensor(0.0, device=self._device)
 
@@ -673,6 +701,31 @@ class VineModel(PreTrainedModel):
             batched_obj_per_cate[vid_id] = obj_per_cate
 
         # Step 4: binary pairs
+        if skip_binary:
+            batched_image_binary_probs = [{} for _ in range(batch_size)]
+            batched_obj_pair_features: Dict[int, torch.Tensor] = {
+                vid: torch.tensor([]) for vid in range(batch_size)
+            }
+            result: Dict[str, Any] = {
+                "categorical_probs": batched_image_cate_probs,
+                "unary_probs": batched_image_unary_probs,
+                "binary_probs": batched_image_binary_probs,
+                "dummy_prob": dummy_prob,
+            }
+
+            if output_embeddings:
+                embeddings_dict = {
+                    "cate_obj_clip_features": batched_obj_cate_features,
+                    "cate_obj_name_features": batched_obj_name_features,
+                    "unary_obj_features": batched_obj_unary_features,
+                    "unary_nl_features": batched_unary_nl_features,
+                    "binary_obj_pair_features": batched_obj_pair_features,
+                    "binary_nl_features": batched_binary_nl_features,
+                }
+                result.update(embeddings_dict)
+
+            return result
+
         batched_cropped_obj_pairs: Dict[int, List[np.ndarray]] = {}
         frame_splits: Dict[Tuple[int, int], Dict[str, int]] = {}
         current_info = (0, 0)
@@ -701,6 +754,8 @@ class VineModel(PreTrainedModel):
                 selected_pairs = set(batched_obj_pairs)
             else:
                 for bp_vid, binary_predicates in enumerate(batched_binary_predicates):
+                    if binary_predicates is None:
+                        continue
                     topk_cate_candidates = batched_topk_cate_candidates[bp_vid]
                     for (rel_name, from_obj_name, to_obj_name) in binary_predicates:
                         if (
@@ -925,6 +980,21 @@ class VineModel(PreTrainedModel):
         inputs = self.clip_processor(images=image, return_tensors="pt").to(self._device)
         return self._image_features_checkpoint(model, inputs["pixel_values"])
 
+    def _augment_unary_with_negation(self, unary_keywords: List[str]) -> List[str]:
+        """
+        Given unary predicates like ["running", "walking"], add "not running",
+        "not walking" if they are not already present (case-insensitive).
+        """
+        base = [kw for kw in unary_keywords if isinstance(kw, str) and kw.strip()]
+        seen_lower = {kw.lower() for kw in base}
+        augmented = list(base)
+        for kw in base:
+            neg = f"not {kw}"
+            if neg.lower() not in seen_lower:
+                augmented.append(neg)
+                seen_lower.add(neg.lower())
+        return augmented
+
     # ------------------------------------------------------------------ #
     # High-level predict API
     # ------------------------------------------------------------------ #
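The helper's behavior on a sample input (sketch; assumes `model` is a constructed VineModel):

```python
model._augment_unary_with_negation(["running", "walking"])
# -> ["running", "walking", "not running", "not walking"]
# An already-present negation (case-insensitively) is not added again.
```
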
@@ -942,7 +1012,35 @@ class VineModel(PreTrainedModel):
         return_valid_pairs: Optional[bool] = None,
         interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
         debug_visualizations: Optional[bool] = None,
+        auto_add_not_unary: Optional[bool] = None,
+        batched_binary_predicates: Optional[Dict[int, List[Tuple[str, str, str]]]] = None,
+        topk_cate: Optional[int] = None,
     ) -> Dict[str, Any]:
+        if unary_keywords is None:
+            unary_keywords = []
+        else:
+            unary_keywords = list(unary_keywords)
+
+        if binary_keywords is None:
+            binary_keywords = []
+        else:
+            binary_keywords = list(binary_keywords)
+
+        if object_pairs is None:
+            object_pairs = []
+
+        if auto_add_not_unary is None:
+            auto_add_not_unary = getattr(self.config, "auto_add_not_unary", False)
+
+        if auto_add_not_unary:
+            unary_keywords = self._augment_unary_with_negation(unary_keywords)
+
+        forward_extra_kwargs: Dict[str, Any] = {}
+        if batched_binary_predicates is not None:
+            forward_extra_kwargs["batched_binary_predicates"] = batched_binary_predicates
+        if topk_cate is not None:
+            forward_extra_kwargs["topk_cate"] = topk_cate
+
         with torch.no_grad():
             outputs = self.forward(
                 video_frames=video_frames,
@@ -956,6 +1054,7 @@ class VineModel(PreTrainedModel):
                 return_valid_pairs=return_valid_pairs,
                 interested_object_pairs=interested_object_pairs,
                 debug_visualizations=debug_visualizations,
+                **forward_extra_kwargs,
             )
 
         formatted_categorical: Dict[int, List[Tuple[float, str]]] = {}
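A hypothetical call into the extended predict API. Only the three new keyword arguments come from this diff; the remaining argument names and values are assumptions for illustration:

```python
outputs = model.predict(
    video_frames=video_frames,  # assumed: frames produced by the preprocessing step
    categorical_keywords=["person", "dog"],
    unary_keywords=["running"],
    binary_keywords=["behind(person, dog)"],
    auto_add_not_unary=True,    # unary list becomes ["running", "not running"]
    batched_binary_predicates={0: [("behind(person, dog)", "person", "dog")]},
    topk_cate=1,                # forwarded to forward() only because it is not None
)
```
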
vine_hf/vine_pipeline.py
CHANGED
@@ -107,6 +107,14 @@ class VinePipeline(Pipeline):
             forward_kwargs["binary_keywords"] = kwargs["binary_keywords"]
         if "object_pairs" in kwargs:
             forward_kwargs["object_pairs"] = kwargs["object_pairs"]
+        if "batched_binary_predicates" in kwargs:
+            # New: per-video (rel, from_cat, to_cat) triples for binary filtering
+            forward_kwargs["batched_binary_predicates"] = kwargs["batched_binary_predicates"]
+        if "topk_cate" in kwargs:
+            # New: override topk_cate when binary filtering is requested
+            forward_kwargs["topk_cate"] = kwargs["topk_cate"]
+        if "auto_add_not_unary" in kwargs:
+            forward_kwargs["auto_add_not_unary"] = kwargs["auto_add_not_unary"]
         if "return_flattened_segments" in kwargs:
             forward_kwargs["return_flattened_segments"] = kwargs[
                 "return_flattened_segments"
@@ -126,7 +134,9 @@ class VinePipeline(Pipeline):
         if "self.visualize" in kwargs:
             postprocess_kwargs["self.visualize"] = kwargs["self.visualize"]
         if "binary_confidence_threshold" in kwargs:
-            postprocess_kwargs["binary_confidence_threshold"] = kwargs[
+            postprocess_kwargs["binary_confidence_threshold"] = kwargs[
+                "binary_confidence_threshold"
+            ]
 
         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
 
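Putting the pieces together, an end-to-end sketch mirroring app.py's call pattern (keyword names come from this diff; the values are illustrative):

```python
results = vine_pipe(
    inputs="demo.mp4",
    categorical_keywords=["person", "dog"],
    unary_keywords=["running"],
    binary_keywords=["behind(person, dog)"],
    object_pairs=[],  # empty -> vine_model.py auto-generates all pairs
    batched_binary_predicates={0: [("behind(person, dog)", "person", "dog")]},
    topk_cate=1,
    auto_add_not_unary=False,
    binary_confidence_threshold=0.5,  # routed to the postprocess kwargs above
)
```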