# LASER / app.py
from pathlib import Path
from collections.abc import Mapping, Sequence
from functools import lru_cache
import inspect
import shutil
import tempfile
import os
import sys
# Add src/ to sys.path so LASER, video-sam2, GroundingDINO are importable
current_dir = Path(__file__).resolve().parent
src_dir = current_dir / "src"
if src_dir.is_dir() and str(src_dir) not in sys.path:
sys.path.insert(0, str(src_dir))
import spaces # <-- ZeroGPU integration
import gradio as gr
import torch
from transformers import pipeline  # not used directly below, but harmless to keep
# -----------------------------
# Environment / diagnostics
# -----------------------------
os.environ["GRADIO_TEMP_DIR"] = str(Path(__file__).parent / "gradio_temp")
os.environ["OPENAI_API_KEY"] = "test"
os.environ["OMP_NUM_THREADS"] = "4"
print("All imports finished")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"cuDNN version: {torch.backends.cudnn.version()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
print(
f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB"
)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
os.environ["TORCH_DTYPE"] = "float32"
torch.set_default_dtype(torch.float32)
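# The settings above pin matmul/conv math to true float32 (TF32 disabled). A quick runtime
# check, purely illustrative:
#   print(torch.get_default_dtype())  # expected: torch.float32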
current_dir = Path(__file__).resolve().parent
# For Spaces, checkpoints and configs are expected to live alongside app.py.
# The same layout works when running locally.
# NOTE: SAM2 config uses Hydra, so we use just the filename (it searches in sam2/configs/)
sam_config_path = "sam2_hiera_t.yaml" # Hydra will find this in sam2/configs/
sam_checkpoint_path = str(current_dir / "sam2_hiera_tiny.pt")
gd_config_path = str(current_dir / "GroundingDINO_SwinT_OGC.py")
gd_checkpoint_path = str(current_dir / "groundingdino_swint_ogc.pth")
visualization_dir = str(current_dir / "outputs")
print(
f"Setting up paths: {sam_config_path}, {sam_checkpoint_path}, {gd_config_path}, {gd_checkpoint_path}"
)
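# Hedged addition (not part of the original app): warn early if any of the local checkpoint or
# config files configured above are missing, instead of failing deep inside model loading.
# sam_config_path is excluded because Hydra resolves it by name (see NOTE above).
for _required_path in (sam_checkpoint_path, gd_config_path, gd_checkpoint_path):
    if not Path(_required_path).is_file():
        print(f"Warning: expected checkpoint/config file not found: {_required_path}")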
def _split_top_level_commas(s: str):
"""
Split a string on commas that are NOT inside parentheses.
Example:
"behind(person, dog), bite(dog, frisbee)"
-> ["behind(person, dog)", "bite(dog, frisbee)"]
"""
parts = []
buf = []
depth = 0
for ch in s:
if ch == "(":
depth += 1
buf.append(ch)
elif ch == ")":
if depth > 0:
depth -= 1
buf.append(ch)
elif ch == "," and depth == 0:
part = "".join(buf).strip()
if part:
parts.append(part)
buf = []
else:
buf.append(ch)
if buf:
part = "".join(buf).strip()
if part:
parts.append(part)
return parts
def _extract_categories_from_binary(binary_keywords_str: str) -> list[str]:
"""
Pull candidate category tokens from binary keyword strings, e.g. relation(a, b).
Only returns tokens when parentheses and two comma-separated entries exist.
"""
categories: list[str] = []
for kw in _split_top_level_commas(binary_keywords_str or ""):
lpar = kw.find("(")
rpar = kw.rfind(")")
if lpar == -1 or rpar <= lpar:
continue
inside = kw[lpar + 1 : rpar]
parts = [p.strip() for p in inside.split(",") if p.strip()]
if len(parts) == 2:
categories.extend(parts)
return categories
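# Minimal parsing sketch for the two helpers above, guarded behind a hypothetical env var
# (LASER_SELFTEST is not part of the original app) so it never runs during normal startup:
if os.environ.get("LASER_SELFTEST"):
    assert _split_top_level_commas("behind(person, dog), bite(dog, frisbee)") == [
        "behind(person, dog)",
        "bite(dog, frisbee)",
    ]
    # "chasing" has no "(from, to)" part, so it contributes no categories
    assert _extract_categories_from_binary("behind(person, dog), chasing") == ["person", "dog"]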
def _parse_binary_keywords(binary_keywords_str: str, categorical_keywords: list[str]):
"""
Parse binary keyword string like:
"behind(person, dog), bite(dog, frisbee)"
into:
- binary_keywords_list: list of raw strings (used as CLIP text)
- batched_binary_predicates: {0: [(rel_text, from_cat, to_cat), ...]} or None
- warnings: list of warning strings about invalid/mismatched categories
"""
if not binary_keywords_str:
return [], None, []
cat_map = {
kw.strip().lower(): kw.strip()
for kw in categorical_keywords
if isinstance(kw, str) and kw.strip()
}
entries = _split_top_level_commas(binary_keywords_str)
binary_keywords_list: list[str] = []
predicates: list[tuple[str, str, str]] = []
warnings: list[str] = []
for raw in entries:
kw = raw.strip()
if not kw:
continue
# Always use the full raw keyword as the CLIP text string
binary_keywords_list.append(kw)
lpar = kw.find("(")
rpar = kw.rfind(")")
if (lpar == -1 and rpar != -1) or (lpar != -1 and rpar == -1) or rpar < lpar:
msg = (
f"Binary keyword '{kw}' has mismatched parentheses; expected "
"relation(from_category, to_category)."
)
print(msg)
warnings.append(msg)
continue
if lpar == -1 or rpar <= lpar:
# No explicit (from,to) part; treat as plain relation (no category filter)
continue
inside = kw[lpar + 1 : rpar]
parts = inside.split(",")
if len(parts) != 2:
msg = (
f"Ignoring '(from,to)' part in binary keyword '{kw}': "
f"expected exactly two comma-separated items."
)
print(msg)
warnings.append(msg)
continue
from_raw = parts[0].strip()
to_raw = parts[1].strip()
if not from_raw or not to_raw:
msg = f"Ignoring binary keyword '{kw}': empty from/to category."
print(msg)
warnings.append(msg)
continue
canonical_from = cat_map.get(from_raw.lower())
canonical_to = cat_map.get(to_raw.lower())
if canonical_from is None:
msg = (
f"Binary keyword '{kw}': from-category '{from_raw}' does not "
f"match any categorical keyword {categorical_keywords}."
)
print(msg)
warnings.append(msg)
if canonical_to is None:
msg = (
f"Binary keyword '{kw}': to-category '{to_raw}' does not "
f"match any categorical keyword {categorical_keywords}."
)
print(msg)
warnings.append(msg)
if canonical_from is None or canonical_to is None:
continue
# Store (relation_text, from_category, to_category)
predicates.append((kw, canonical_from, canonical_to))
if not predicates:
return binary_keywords_list, None, warnings
return binary_keywords_list, {0: predicates}, warnings
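# Companion sketch for _parse_binary_keywords (same hypothetical LASER_SELFTEST guard as above):
# the raw relation text is kept verbatim for CLIP, and the (from, to) categories are resolved
# against the categorical keyword list, keyed by batch index 0.
if os.environ.get("LASER_SELFTEST"):
    _kw_list, _preds, _warns = _parse_binary_keywords("behind(person, dog)", ["person", "dog"])
    assert _kw_list == ["behind(person, dog)"]
    assert _preds == {0: [("behind(person, dog)", "person", "dog")]}
    assert _warns == []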
@lru_cache(maxsize=1)
def _load_vine_pipeline():
"""
Lazy-load and cache the LASER (VINE HF) pipeline so we don't re-download/rebuild it on every request.
"""
from vine_hf import VineConfig, VineModel, VinePipeline
config = VineConfig(
segmentation_method="grounding_dino_sam2",
model_name="openai/clip-vit-base-patch32",
use_hf_repo=True,
model_repo="KevinX-Penn28/testing",
box_threshold=0.35,
text_threshold=0.25,
target_fps=1, # default 1 FPS
topk_cate=5,
white_alpha=0.3,
visualization_dir=visualization_dir,
visualize=True,
debug_visualizations=False,
device="cuda",
categorical_pool="max",
auto_add_not_unary=False, # UI will control this per-call
)
model = VineModel(config)
return VinePipeline(
model=model,
tokenizer=None,
sam_config_path=sam_config_path,
sam_checkpoint_path=sam_checkpoint_path,
gd_config_path=gd_config_path,
gd_checkpoint_path=gd_checkpoint_path,
device="cuda",
trust_remote_code=True,
)
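# Note on the caching above: lru_cache(maxsize=1) keeps exactly one pipeline per Python process,
# so repeated Gradio requests reuse the already-loaded weights. If the pipeline ever needs to be
# rebuilt (e.g. after changing VineConfig in a local session), the standard functools hook applies:
#   _load_vine_pipeline.cache_clear()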
@spaces.GPU(duration=120)  # Up to ~2 minutes of H200 ZeroGPU time per call
def process_video(
video_file,
categorical_keywords,
unary_keywords,
binary_keywords,
auto_add_not_unary,
output_fps,
box_threshold,
text_threshold,
binary_confidence_threshold,
):
vine_pipe = _load_vine_pipeline()
# Normalize incoming video input to a file path
if isinstance(video_file, dict):
video_file = (
video_file.get("name")
or video_file.get("filepath")
or video_file.get("data")
)
if not isinstance(video_file, (str, Path)):
raise ValueError(f"Unsupported video input type: {type(video_file)}")
video_path = Path(video_file)
if video_path.suffix.lower() != ".mp4":
msg = (
"Please upload an MP4 file. LASER currently supports MP4 inputs for "
"scene-graph generation."
)
print(msg)
return None, {"error": msg}
video_file = str(video_path)
# Keep original strings for parsing
categorical_keywords_str = categorical_keywords
unary_keywords_str = unary_keywords
binary_keywords_str = binary_keywords
    categorical_keywords = (
        [kw.strip() for kw in categorical_keywords_str.split(",") if kw.strip()]
        if categorical_keywords_str
        else []
    )
    unary_keywords = (
        [kw.strip() for kw in unary_keywords_str.split(",") if kw.strip()]
        if unary_keywords_str
        else []
    )
# Preprocess: pull categories referenced in binary keywords and add any missing ones
added_categories: list[str] = []
extra_cats = _extract_categories_from_binary(binary_keywords_str or "")
if extra_cats:
existing_lower = {kw.lower() for kw in categorical_keywords}
for cat in extra_cats:
if cat and cat.lower() not in existing_lower:
categorical_keywords.append(cat)
existing_lower.add(cat.lower())
added_categories.append(cat)
# Parse binary keywords with category info (if provided)
(
binary_keywords_list,
batched_binary_predicates,
binary_input_warnings,
) = _parse_binary_keywords(binary_keywords_str or "", categorical_keywords)
if added_categories:
binary_input_warnings.append(
"Auto-added categorical keywords from binary relations: "
+ ", ".join(added_categories)
)
skip_binary = len(binary_keywords_list) == 0
# Debug: Print what we're sending to the pipeline
print("\n" + "=" * 80)
print("INPUT TO LASER PIPELINE:")
print(f" categorical_keywords: {categorical_keywords}")
print(f" unary_keywords: {unary_keywords}")
print(f" binary_keywords (raw parsed): {binary_keywords_list}")
print(f" batched_binary_predicates: {batched_binary_predicates}")
print(f" auto_add_not_unary: {auto_add_not_unary}")
print(f" skip_binary: {skip_binary}")
print("=" * 80 + "\n")
    # Object pairs are optional here; an empty list lets vine_model.py auto-generate all pairs
object_pairs: list[tuple[int, int]] = []
extra_forward_kwargs = {}
if batched_binary_predicates is not None and not skip_binary:
# Use category-based filtering of binary pairs
extra_forward_kwargs["batched_binary_predicates"] = batched_binary_predicates
extra_forward_kwargs["topk_cate"] = 1 # as requested
extra_forward_kwargs["auto_add_not_unary"] = bool(auto_add_not_unary)
if skip_binary:
extra_forward_kwargs["disable_binary"] = True
results = vine_pipe(
inputs=video_file,
categorical_keywords=categorical_keywords,
unary_keywords=unary_keywords,
binary_keywords=binary_keywords_list,
object_pairs=object_pairs,
segmentation_method="grounding_dino_sam2",
return_top_k=5,
include_visualizations=True,
debug_visualizations=False,
device="cuda",
box_threshold=box_threshold,
text_threshold=text_threshold,
target_fps=output_fps,
binary_confidence_threshold=binary_confidence_threshold,
**extra_forward_kwargs,
)
# Debug: Print what the pipeline returned
print("\n" + "=" * 80)
print("PIPELINE RESULTS DEBUG:")
print(f" results type: {type(results)}")
if isinstance(results, dict):
print(f" results keys: {list(results.keys())}")
print("=" * 80 + "\n")
    # Record the latest UI settings on the cached pipeline object
    vine_pipe.box_threshold = box_threshold
    vine_pipe.text_threshold = text_threshold
    vine_pipe.target_fps = output_fps
if isinstance(results, Mapping):
results_dict = results
elif isinstance(results, Sequence) and results and isinstance(results[0], Mapping):
results_dict = results[0]
else:
results_dict = {}
visualizations = results_dict.get("visualizations") or {}
vine = visualizations.get("vine") or {}
all_vis = vine.get("all") or {}
result_video_path = all_vis.get("video_path")
if not result_video_path:
candidates = sorted(
Path(visualization_dir).rglob("*.mp4"),
key=lambda p: p.stat().st_mtime,
reverse=True,
)
result_video_path = str(candidates[0]) if candidates else None
summary = results_dict.get("summary") or {}
# Attach any binary category parsing warnings into the summary JSON
if binary_input_warnings:
if "binary_input_warnings" in summary:
summary["binary_input_warnings"].extend(binary_input_warnings)
else:
summary["binary_input_warnings"] = binary_input_warnings
if result_video_path and os.path.exists(result_video_path):
gradio_tmp = (
Path(os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir()))
/ "vine_outputs"
)
gradio_tmp.mkdir(parents=True, exist_ok=True)
dest_path = gradio_tmp / Path(result_video_path).name
try:
shutil.copyfile(result_video_path, dest_path)
video_path_for_ui = str(dest_path)
except Exception as e:
print(f"Warning: failed to copy video to Gradio temp dir: {e}")
video_path_for_ui = str(result_video_path)
else:
video_path_for_ui = None
print(
"Warning: annotated video not found or empty; check visualization settings."
)
return video_path_for_ui, summary
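# Illustrative local invocation (assumption: a short MP4 exists at the given path); this mirrors
# what the Gradio click handler below sends:
#   out_video, summary = process_video(
#       "sample.mp4", "person, dog", "running", "behind(person, dog)",
#       auto_add_not_unary=False, output_fps=1, box_threshold=0.35,
#       text_threshold=0.25, binary_confidence_threshold=0.5,
#   )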
def _video_component(label: str, *, is_output: bool = False):
"""
Build a Gradio Video component that is compatible with older Gradio versions
(no `type`/`sources`/`format` kwargs) and newer ones when available.
"""
kwargs = {"label": label}
sig = inspect.signature(gr.Video.__init__)
# Only set format for OUTPUT components
if is_output and "format" in sig.parameters:
kwargs["format"] = "mp4"
if not is_output:
if "type" in sig.parameters:
kwargs["type"] = "filepath"
if "sources" in sig.parameters:
kwargs["sources"] = ["upload"]
# Restrict to MP4 files only
if "file_types" in sig.parameters:
kwargs["file_types"] = [".mp4"]
if is_output and "autoplay" in sig.parameters:
kwargs["autoplay"] = True
return gr.Video(**kwargs)
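# For reference: _video_component probes gr.Video's signature and only passes kwargs the
# installed Gradio version supports. On a recent release the input call below ends up roughly as
#   gr.Video(label=..., sources=["upload"], file_types=[".mp4"], ...)
# and the output call as
#   gr.Video(label=..., format="mp4", autoplay=True)
# while older releases silently fall back to gr.Video(label=...).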
def _create_blocks():
"""
Build a Blocks context that works across Gradio versions.
"""
blocks_kwargs = {"title": "LASER Scene Graph Demo"}
soft_theme = None
if hasattr(gr, "themes") and hasattr(gr.themes, "Soft"):
try:
soft_theme = gr.themes.Soft()
except Exception:
soft_theme = None
if "theme" in inspect.signature(gr.Blocks).parameters and soft_theme is not None:
blocks_kwargs["theme"] = soft_theme
return gr.Blocks(**blocks_kwargs)
# Create Gradio interface with two-column layout
with _create_blocks() as demo:
gr.Markdown(
"""
# 🎬 LASER: Spatio-temporal Scene Graphs for Video
Turn any MP4 into a spatio-temporal scene graph with LASER, our 454-million-parameter foundation model for scene-graph generation. LASER is trained on 87K+ open-domain videos with a neurosymbolic caption-to-scene alignment pipeline, so it learns fine-grained video semantics without human labels.
Upload an MP4 and sketch the scene graph you care about: specify the objects, actions, and interactions you want, and LASER will assemble a spatio-temporal scene graph plus an annotated video.
"""
)
with gr.Row():
# Left column: Inputs
with gr.Column(scale=1):
gr.Markdown("### Scene Graph Inputs")
video_input = _video_component("Upload Video (MP4 only)", is_output=False)
gr.Markdown("*Note: Only MP4 format is currently supported*")
gr.Markdown("#### Scene Graph Queries")
categorical_input = gr.Textbox(
label="Categorical Keywords",
placeholder="e.g., person, car, dog",
value="person, car, dog",
info="Objects to detect in the video (comma-separated)",
)
unary_input = gr.Textbox(
label="Unary Keywords",
placeholder="e.g., walking, running, standing",
value="walking, running, standing",
info="Single-object actions to detect (comma-separated)",
)
binary_input = gr.Textbox(
label="Binary Keywords",
placeholder="e.g., behind(person, dog), bite(dog, frisbee)",
info=(
"Object-to-object interactions to detect. "
"Use format: relation(from_category, to_category). "
"Example: 'behind(person, dog), bite(dog, frisbee)'. "
"If you omit '(from,to)', the relation will be applied to all object pairs (default behavior). "
"Leave blank to skip binary relation search entirely."
),
)
add_not_unary_checkbox = gr.Checkbox(
label="Also query 'not <unary>' predicates",
value=False,
info="If enabled, for each unary keyword X, also query 'not X'.",
)
gr.Markdown("#### Processing Settings")
fps_input = gr.Number(
label="Output FPS",
value=1,
info="Frames per second for processing (lower = faster)",
)
with gr.Accordion("Advanced Settings", open=False):
box_threshold_input = gr.Slider(
label="Box Threshold",
minimum=0.1,
maximum=0.9,
value=0.35,
step=0.05,
info="Confidence threshold for object detection",
)
text_threshold_input = gr.Slider(
label="Text Threshold",
minimum=0.1,
maximum=0.9,
value=0.25,
step=0.05,
info="Confidence threshold for text-based detection",
)
binary_confidence_input = gr.Slider(
label="Binary Relation Confidence Threshold",
minimum=0.0,
maximum=1.0,
                    value=0.5,
step=0.05,
info="Minimum confidence to show binary relations and object pairs",
)
submit_btn = gr.Button("🚀 Process Video", variant="primary", size="lg")
# Right column: Outputs
with gr.Column(scale=1):
gr.Markdown("### Scene Graph Results")
video_output = _video_component("Annotated Video Output", is_output=True)
gr.Markdown("### Scene Graph Summary")
summary_output = gr.JSON(label="Scene Graph / Detected Events")
gr.Markdown(
"""
---
### How to Use LASER
1. Upload an MP4 (we validate the format for you).
2. Describe the **nodes** of your spatio-temporal scene graph with categorical keywords (objects) and unary keywords (single-object actions).
3. Wire up **binary** relations:
- Use the structured form `relation(from_category, to_category)` (e.g., `behind(person, dog), bite(dog, frisbee)`) to limit relations to those category pairs.
- Or list relation names (`chasing, carrying`) to evaluate all object pairs.
- Leave the field blank to skip binary relations entirely (no pair search or binary predicates).
- Categories referenced inside binary relations are auto-added to the categorical list for you.
4. Optionally enable automatic `'not <unary>'` predicates.
5. Adjust processing settings if needed and click **Process Video** to receive an annotated video plus the serialized scene graph.
More to explore:
- LASER paper (ICLR'25): https://arxiv.org/abs/2304.07647 | Demo: https://huggingface.co/spaces/jiani-huang/LASER | Code: https://github.com/video-fm/LASER
- ESCA paper: https://arxiv.org/abs/2510.15963 | Code: https://github.com/video-fm/ESCA | Model: https://huggingface.co/video-fm/vine_v0 | Dataset: https://huggingface.co/datasets/video-fm/ESCA-video-87K
- Meet us at **NeurIPS 2025** (San Diego, Exhibit Hall C/D/E, Booth #4908 - Wed, Dec 3 - 11:00 a.m.-2:00 p.m. PST) for the foundation model demo, code, and full paper.
"""
)
submit_btn.click(
fn=process_video,
inputs=[
video_input,
categorical_input,
unary_input,
binary_input,
add_not_unary_checkbox,
fps_input,
box_threshold_input,
text_threshold_input,
binary_confidence_input,
],
outputs=[video_output, summary_output],
)
if __name__ == "__main__":
print("Got to main")
demo.launch(share=True, debug=True)