from pathlib import Path
from collections.abc import Mapping, Sequence
from functools import lru_cache
import inspect
import shutil
import tempfile
import os
import sys

# Add src/ to sys.path so LASER, video-sam2, GroundingDINO are importable
current_dir = Path(__file__).resolve().parent
src_dir = current_dir / "src"
if src_dir.is_dir() and str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))

import spaces  # <-- ZeroGPU integration
import gradio as gr
import torch
from transformers import pipeline  # not strictly necessary, but fine

# -----------------------------
# Environment / diagnostics
# -----------------------------
os.environ["GRADIO_TEMP_DIR"] = str(Path(__file__).parent / "gradio_temp")
os.environ["OPENAI_API_KEY"] = "test"
os.environ["OMP_NUM_THREADS"] = "4"

print("All imports finished")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"cuDNN version: {torch.backends.cudnn.version()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(
            f"  Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB"
        )

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
os.environ["TORCH_DTYPE"] = "float32"
torch.set_default_dtype(torch.float32)

current_dir = Path(__file__).resolve().parent

# For Spaces, assume checkpoints live alongside app.py or in a "checkpoints" subdir.
# If you keep them next to app.py locally, this still works.
# NOTE: SAM2 config uses Hydra, so we use just the filename (it searches in sam2/configs/)
sam_config_path = "sam2_hiera_t.yaml"  # Hydra will find this in sam2/configs/
sam_checkpoint_path = str(current_dir / "sam2_hiera_tiny.pt")
gd_config_path = str(current_dir / "GroundingDINO_SwinT_OGC.py")
gd_checkpoint_path = str(current_dir / "groundingdino_swint_ogc.pth")
visualization_dir = str(current_dir / "outputs")

print(
    f"Setting up paths: {sam_config_path}, {sam_checkpoint_path}, {gd_config_path}, {gd_checkpoint_path}"
)


def _split_top_level_commas(s: str):
    """
    Split a string on commas that are NOT inside parentheses.

    Example:
        "behind(person, dog), bite(dog, frisbee)"
        -> ["behind(person, dog)", "bite(dog, frisbee)"]
    """
    parts = []
    buf = []
    depth = 0
    for ch in s:
        if ch == "(":
            depth += 1
            buf.append(ch)
        elif ch == ")":
            if depth > 0:
                depth -= 1
            buf.append(ch)
        elif ch == "," and depth == 0:
            part = "".join(buf).strip()
            if part:
                parts.append(part)
            buf = []
        else:
            buf.append(ch)
    if buf:
        part = "".join(buf).strip()
        if part:
            parts.append(part)
    return parts
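
# Illustrative: nested parentheses are kept intact (the depth counter ignores
# their commas), so only top-level commas split:
#   _split_top_level_commas("hold(person, cup(red)), wave(person, person)")
#   -> ["hold(person, cup(red))", "wave(person, person)"]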


def _extract_categories_from_binary(binary_keywords_str: str) -> list[str]:
    """
    Pull candidate category tokens from binary keyword strings, e.g. relation(a, b).
    Only returns tokens when parentheses and two comma-separated entries exist.
    """
    categories: list[str] = []
    for kw in _split_top_level_commas(binary_keywords_str or ""):
        lpar = kw.find("(")
        rpar = kw.rfind(")")
        if lpar == -1 or rpar <= lpar:
            continue
        inside = kw[lpar + 1 : rpar]
        parts = [p.strip() for p in inside.split(",") if p.strip()]
        if len(parts) == 2:
            categories.extend(parts)
    return categories
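
# Illustrative: plain relation names contribute no categories, structured ones do:
#   _extract_categories_from_binary("chasing, bite(dog, frisbee)")
#   -> ["dog", "frisbee"]
# (duplicates are possible here; the caller de-duplicates case-insensitively)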


def _parse_binary_keywords(binary_keywords_str: str, categorical_keywords: list[str]):
    """
    Parse a binary keyword string like:
        "behind(person, dog), bite(dog, frisbee)"
    into:
    - binary_keywords_list: list of raw strings (used as CLIP text)
    - batched_binary_predicates: {0: [(rel_text, from_cat, to_cat), ...]} or None
    - warnings: list of warning strings about invalid/mismatched categories
    """
    if not binary_keywords_str:
        return [], None, []
    cat_map = {
        kw.strip().lower(): kw.strip()
        for kw in categorical_keywords
        if isinstance(kw, str) and kw.strip()
    }
    entries = _split_top_level_commas(binary_keywords_str)
    binary_keywords_list: list[str] = []
    predicates: list[tuple[str, str, str]] = []
    warnings: list[str] = []
    for raw in entries:
        kw = raw.strip()
        if not kw:
            continue
        # Always use the full raw keyword as the CLIP text string
        binary_keywords_list.append(kw)
        lpar = kw.find("(")
        rpar = kw.rfind(")")
        if (lpar == -1 and rpar != -1) or (lpar != -1 and rpar == -1) or rpar < lpar:
            msg = (
                f"Binary keyword '{kw}' has mismatched parentheses; expected "
                "relation(from_category, to_category)."
            )
            print(msg)
            warnings.append(msg)
            continue
        if lpar == -1 or rpar <= lpar:
            # No explicit (from,to) part; treat as a plain relation (no category filter)
            continue
        inside = kw[lpar + 1 : rpar]
        parts = inside.split(",")
        if len(parts) != 2:
            msg = (
                f"Ignoring '(from,to)' part in binary keyword '{kw}': "
                f"expected exactly two comma-separated items."
            )
            print(msg)
            warnings.append(msg)
            continue
        from_raw = parts[0].strip()
        to_raw = parts[1].strip()
        if not from_raw or not to_raw:
            msg = f"Ignoring binary keyword '{kw}': empty from/to category."
            print(msg)
            warnings.append(msg)
            continue
        canonical_from = cat_map.get(from_raw.lower())
        canonical_to = cat_map.get(to_raw.lower())
        if canonical_from is None:
            msg = (
                f"Binary keyword '{kw}': from-category '{from_raw}' does not "
                f"match any categorical keyword {categorical_keywords}."
            )
            print(msg)
            warnings.append(msg)
        if canonical_to is None:
            msg = (
                f"Binary keyword '{kw}': to-category '{to_raw}' does not "
                f"match any categorical keyword {categorical_keywords}."
            )
            print(msg)
            warnings.append(msg)
        if canonical_from is None or canonical_to is None:
            continue
        # Store (relation_text, from_category, to_category)
        predicates.append((kw, canonical_from, canonical_to))
    if not predicates:
        return binary_keywords_list, None, warnings
    return binary_keywords_list, {0: predicates}, warnings
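
# Illustrative: with categorical keywords ["person", "dog", "frisbee"],
#   _parse_binary_keywords("behind(person, dog), chasing", ["person", "dog", "frisbee"])
# returns
#   (["behind(person, dog)", "chasing"],
#    {0: [("behind(person, dog)", "person", "dog")]},
#    [])
# "chasing" is still scored as CLIP text but carries no category filter.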


@lru_cache(maxsize=1)
def _load_vine_pipeline():
    """
    Lazy-load and cache the LASER (VINE HF) pipeline so we don't re-download or
    rebuild it on every request.
    """
    from vine_hf import VineConfig, VineModel, VinePipeline

    config = VineConfig(
        segmentation_method="grounding_dino_sam2",
        model_name="openai/clip-vit-base-patch32",
        use_hf_repo=True,
        model_repo="KevinX-Penn28/testing",
        box_threshold=0.35,
        text_threshold=0.25,
        target_fps=1,  # default 1 FPS
        topk_cate=5,
        white_alpha=0.3,
        visualization_dir=visualization_dir,
        visualize=True,
        debug_visualizations=False,
        device="cuda",
        categorical_pool="max",
        auto_add_not_unary=False,  # UI will control this per-call
    )
    model = VineModel(config)
    return VinePipeline(
        model=model,
        tokenizer=None,
        sam_config_path=sam_config_path,
        sam_checkpoint_path=sam_checkpoint_path,
        gd_config_path=gd_config_path,
        gd_checkpoint_path=gd_checkpoint_path,
        device="cuda",
        trust_remote_code=True,
    )
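
# Because of @lru_cache(maxsize=1), every request reuses the same pipeline object:
#   _load_vine_pipeline() is _load_vine_pipeline()  # -> True
# so the heavy SAM2 / GroundingDINO / CLIP checkpoints load once per worker.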


# Up to ~5 minutes of H200 ZeroGPU time per call (duration is in seconds)
@spaces.GPU(duration=300)
def process_video(
    video_file,
    categorical_keywords,
    unary_keywords,
    binary_keywords,
    auto_add_not_unary,
    output_fps,
    box_threshold,
    text_threshold,
    binary_confidence_threshold,
):
    vine_pipe = _load_vine_pipeline()

    # Normalize incoming video input to a file path
    if isinstance(video_file, dict):
        video_file = (
            video_file.get("name")
            or video_file.get("filepath")
            or video_file.get("data")
        )
    if not isinstance(video_file, (str, Path)):
        raise ValueError(f"Unsupported video input type: {type(video_file)}")
    video_path = Path(video_file)
    if video_path.suffix.lower() != ".mp4":
        msg = (
            "Please upload an MP4 file. LASER currently supports MP4 inputs for "
            "scene-graph generation."
        )
        print(msg)
        return None, {"error": msg}
    video_file = str(video_path)

    # Keep original strings for parsing
    categorical_keywords_str = categorical_keywords
    unary_keywords_str = unary_keywords
    binary_keywords_str = binary_keywords
    categorical_keywords = (
        [kw.strip() for kw in categorical_keywords_str.split(",")]
        if categorical_keywords_str
        else []
    )
    unary_keywords = (
        [kw.strip() for kw in unary_keywords_str.split(",")]
        if unary_keywords_str
        else []
    )

    # Preprocess: pull categories referenced in binary keywords and add any missing ones
    added_categories: list[str] = []
    extra_cats = _extract_categories_from_binary(binary_keywords_str or "")
    if extra_cats:
        existing_lower = {kw.lower() for kw in categorical_keywords}
        for cat in extra_cats:
            if cat and cat.lower() not in existing_lower:
                categorical_keywords.append(cat)
                existing_lower.add(cat.lower())
                added_categories.append(cat)

    # Parse binary keywords with category info (if provided)
    (
        binary_keywords_list,
        batched_binary_predicates,
        binary_input_warnings,
    ) = _parse_binary_keywords(binary_keywords_str or "", categorical_keywords)
    if added_categories:
        binary_input_warnings.append(
            "Auto-added categorical keywords from binary relations: "
            + ", ".join(added_categories)
        )
    skip_binary = len(binary_keywords_list) == 0

    # Debug: Print what we're sending to the pipeline
    print("\n" + "=" * 80)
    print("INPUT TO LASER PIPELINE:")
    print(f"  categorical_keywords: {categorical_keywords}")
    print(f"  unary_keywords: {unary_keywords}")
    print(f"  binary_keywords (raw parsed): {binary_keywords_list}")
    print(f"  batched_binary_predicates: {batched_binary_predicates}")
    print(f"  auto_add_not_unary: {auto_add_not_unary}")
    print(f"  skip_binary: {skip_binary}")
    print("=" * 80 + "\n")

    # Object pairs is now optional - empty list will auto-generate all pairs in vine_model.py
    object_pairs: list[tuple[int, int]] = []
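    # Illustrative (behavior lives in vine_model.py): with three tracked objects,
    # an empty object_pairs list presumably expands to every ordered pair, e.g.
    # (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1).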

    extra_forward_kwargs = {}
    if batched_binary_predicates is not None and not skip_binary:
        # Use category-based filtering of binary pairs
        extra_forward_kwargs["batched_binary_predicates"] = batched_binary_predicates
        extra_forward_kwargs["topk_cate"] = 1  # as requested
    extra_forward_kwargs["auto_add_not_unary"] = bool(auto_add_not_unary)
    if skip_binary:
        extra_forward_kwargs["disable_binary"] = True

    results = vine_pipe(
        inputs=video_file,
        categorical_keywords=categorical_keywords,
        unary_keywords=unary_keywords,
        binary_keywords=binary_keywords_list,
        object_pairs=object_pairs,
        segmentation_method="grounding_dino_sam2",
        return_top_k=5,
        include_visualizations=True,
        debug_visualizations=False,
        device="cuda",
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_fps=output_fps,
        binary_confidence_threshold=binary_confidence_threshold,
        **extra_forward_kwargs,
    )

    # Debug: Print what the pipeline returned
    print("\n" + "=" * 80)
    print("PIPELINE RESULTS DEBUG:")
    print(f"  results type: {type(results)}")
    if isinstance(results, dict):
        print(f"  results keys: {list(results.keys())}")
    print("=" * 80 + "\n")

    vine_pipe.box_threshold = box_threshold
    vine_pipe.text_threshold = text_threshold
    vine_pipe.target_fps = output_fps

    if isinstance(results, Mapping):
        results_dict = results
    elif isinstance(results, Sequence) and results and isinstance(results[0], Mapping):
        results_dict = results[0]
    else:
        results_dict = {}

    visualizations = results_dict.get("visualizations") or {}
    vine = visualizations.get("vine") or {}
    all_vis = vine.get("all") or {}
    result_video_path = all_vis.get("video_path")
    if not result_video_path:
        candidates = sorted(
            Path(visualization_dir).rglob("*.mp4"),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        result_video_path = str(candidates[0]) if candidates else None

    summary = results_dict.get("summary") or {}
    # Attach any binary category parsing warnings into the summary JSON
    if binary_input_warnings:
        if "binary_input_warnings" in summary:
            summary["binary_input_warnings"].extend(binary_input_warnings)
        else:
            summary["binary_input_warnings"] = binary_input_warnings

    if result_video_path and os.path.exists(result_video_path):
        gradio_tmp = (
            Path(os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir()))
            / "vine_outputs"
        )
        gradio_tmp.mkdir(parents=True, exist_ok=True)
        dest_path = gradio_tmp / Path(result_video_path).name
        try:
            shutil.copyfile(result_video_path, dest_path)
            video_path_for_ui = str(dest_path)
        except Exception as e:
            print(f"Warning: failed to copy video to Gradio temp dir: {e}")
            video_path_for_ui = str(result_video_path)
    else:
        video_path_for_ui = None
        print(
            "Warning: annotated video not found or empty; check visualization settings."
        )
    return video_path_for_ui, summary
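
# process_video returns (video_path_for_ui, summary); the submit_btn.click()
# wiring below maps this pair onto [video_output, summary_output].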


def _video_component(label: str, *, is_output: bool = False):
    """
    Build a Gradio Video component that is compatible with older Gradio versions
    (no `type`/`sources`/`format` kwargs) and newer ones when available.
    """
    kwargs = {"label": label}
    sig = inspect.signature(gr.Video.__init__)
    # Only set format for OUTPUT components
    if is_output and "format" in sig.parameters:
        kwargs["format"] = "mp4"
    if not is_output:
        if "type" in sig.parameters:
            kwargs["type"] = "filepath"
        if "sources" in sig.parameters:
            kwargs["sources"] = ["upload"]
        # Restrict to MP4 files only
        if "file_types" in sig.parameters:
            kwargs["file_types"] = [".mp4"]
    if is_output and "autoplay" in sig.parameters:
        kwargs["autoplay"] = True
    return gr.Video(**kwargs)
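
# Illustrative: on a Gradio release that accepts all probed kwargs,
#   _video_component("Upload Video (MP4 only)")
# would presumably resolve to gr.Video(label=..., type="filepath",
# sources=["upload"], file_types=[".mp4"]), while older releases silently
# skip whichever kwargs their gr.Video.__init__ does not accept.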


def _create_blocks():
    """
    Build a Blocks context that works across Gradio versions.
    """
    blocks_kwargs = {"title": "LASER Scene Graph Demo"}
    soft_theme = None
    if hasattr(gr, "themes") and hasattr(gr.themes, "Soft"):
        try:
            soft_theme = gr.themes.Soft()
        except Exception:
            soft_theme = None
    if "theme" in inspect.signature(gr.Blocks).parameters and soft_theme is not None:
        blocks_kwargs["theme"] = soft_theme
    return gr.Blocks(**blocks_kwargs)


# Create Gradio interface with two-column layout
with _create_blocks() as demo:
    gr.Markdown(
        """
# 🎬 LASER: Spatio-temporal Scene Graphs for Video

Turn any MP4 into a spatio-temporal scene graph with LASER, our 454-million-parameter foundation model for scene-graph generation. LASER is trained on 87K+ open-domain videos using a neurosymbolic caption-to-scene alignment pipeline, so it learns fine-grained video semantics without human labels.

Upload an MP4 and sketch the scene graph you care about: specify the objects, actions, and interactions you want, and LASER will assemble a spatio-temporal scene graph plus an annotated video.
"""
    )
    with gr.Row():
        # Left column: Inputs
        with gr.Column(scale=1):
            gr.Markdown("### Scene Graph Inputs")
            video_input = _video_component("Upload Video (MP4 only)", is_output=False)
            gr.Markdown("*Note: Only MP4 format is currently supported*")

            gr.Markdown("#### Scene Graph Queries")
            categorical_input = gr.Textbox(
                label="Categorical Keywords",
                placeholder="e.g., person, car, dog",
                value="person, car, dog",
                info="Objects to detect in the video (comma-separated)",
            )
            unary_input = gr.Textbox(
                label="Unary Keywords",
                placeholder="e.g., walking, running, standing",
                value="walking, running, standing",
                info="Single-object actions to detect (comma-separated)",
            )
            binary_input = gr.Textbox(
                label="Binary Keywords",
                placeholder="e.g., behind(person, dog), bite(dog, frisbee)",
                info=(
                    "Object-to-object interactions to detect. "
                    "Use format: relation(from_category, to_category). "
                    "Example: 'behind(person, dog), bite(dog, frisbee)'. "
                    "If you omit '(from,to)', the relation will be applied to all object pairs (default behavior). "
                    "Leave blank to skip binary relation search entirely."
                ),
            )
            add_not_unary_checkbox = gr.Checkbox(
                label="Also query 'not <unary>' predicates",
                value=False,
                info="If enabled, for each unary keyword X, also query 'not X'.",
            )

            gr.Markdown("#### Processing Settings")
            fps_input = gr.Number(
                label="Output FPS",
                value=1,
                info="Frames per second for processing (lower = faster)",
            )
            with gr.Accordion("Advanced Settings", open=False):
                box_threshold_input = gr.Slider(
                    label="Box Threshold",
                    minimum=0.1,
                    maximum=0.9,
                    value=0.35,
                    step=0.05,
                    info="Confidence threshold for object detection",
                )
                text_threshold_input = gr.Slider(
                    label="Text Threshold",
                    minimum=0.1,
                    maximum=0.9,
                    value=0.25,
                    step=0.05,
                    info="Confidence threshold for text-based detection",
                )
                binary_confidence_input = gr.Slider(
                    label="Binary Relation Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    info="Minimum confidence to show binary relations and object pairs",
                )
            submit_btn = gr.Button("🚀 Process Video", variant="primary", size="lg")

        # Right column: Outputs
        with gr.Column(scale=1):
            gr.Markdown("### Scene Graph Results")
            video_output = _video_component("Annotated Video Output", is_output=True)
            gr.Markdown("### Scene Graph Summary")
            summary_output = gr.JSON(label="Scene Graph / Detected Events")

    gr.Markdown(
        """
---
### How to Use LASER
1. Upload an MP4 (we validate the format for you).
2. Describe the **nodes** of your spatio-temporal scene graph with categorical keywords (objects) and unary keywords (single-object actions).
3. Wire up **binary** relations:
   - Use the structured form `relation(from_category, to_category)` (e.g., `behind(person, dog), bite(dog, frisbee)`) to limit relations to those category pairs.
   - Or list relation names (`chasing, carrying`) to evaluate all object pairs.
   - Leave the field blank to skip binary relations entirely (no pair search or binary predicates).
   - Categories referenced inside binary relations are auto-added to the categorical list for you.
4. Optionally enable automatic `'not <unary>'` predicates.
5. Adjust processing settings if needed and click **Process Video** to receive an annotated video plus the serialized scene graph.

More to explore:
- LASER paper (ICLR'25): https://arxiv.org/abs/2304.07647 | Demo: https://huggingface.co/spaces/jiani-huang/LASER | Code: https://github.com/video-fm/LASER
- ESCA paper: https://arxiv.org/abs/2510.15963 | Code: https://github.com/video-fm/ESCA | Model: https://huggingface.co/video-fm/vine_v0 | Dataset: https://huggingface.co/datasets/video-fm/ESCA-video-87K
- Meet us at **NeurIPS 2025** (San Diego, Exhibit Hall C/D/E, Booth #4908 - Wed, Dec 3 - 11:00 a.m.-2:00 p.m. PST) for the foundation model demo, code, and full paper.
"""
    )

    submit_btn.click(
        fn=process_video,
        inputs=[
            video_input,
            categorical_input,
            unary_input,
            binary_input,
            add_not_unary_checkbox,
            fps_input,
            box_threshold_input,
            text_threshold_input,
            binary_confidence_input,
        ],
        outputs=[video_output, summary_output],
    )

if __name__ == "__main__":
    print("Got to main")
    demo.launch(share=True, debug=True)