Spaces: Running on Zero
from pathlib import Path
from collections.abc import Mapping, Sequence
from functools import lru_cache
import inspect
import shutil
import tempfile
import os
import sys

import spaces  # <-- ZeroGPU integration
import gradio as gr
import torch
from transformers import pipeline  # not strictly necessary, but fine
# -----------------------------
# Environment / diagnostics
# -----------------------------
os.environ["GRADIO_TEMP_DIR"] = str(Path(__file__).parent / "gradio_temp")
os.environ["OPENAI_API_KEY"] = "test"
os.environ["OMP_NUM_THREADS"] = "4"

print("All imports finished")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"cuDNN version: {torch.backends.cudnn.version()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(
            f"  Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB"
        )

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
os.environ["TORCH_DTYPE"] = "float32"
torch.set_default_dtype(torch.float32)
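
# Note: disabling TF32 and forcing a float32 default trades speed for
# reproducible, full-precision math; TF32 could be re-enabled on
# H100/H200-class GPUs if throughput matters more than bit-exactness.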
current_dir = Path(__file__).resolve().parent

# For Spaces, assume checkpoints live alongside app.py or in a "checkpoints" subdir.
# If you keep them next to app.py locally, this still works.
# NOTE: SAM2 config uses Hydra, so we use just the filename (it searches in sam2/configs/)
sam_config_path = "sam2_hiera_t.yaml"  # Hydra will find this in sam2/configs/
sam_checkpoint_path = str(current_dir / "sam2_hiera_tiny.pt")
gd_config_path = str(current_dir / "GroundingDINO_SwinT_OGC.py")
gd_checkpoint_path = str(current_dir / "groundingdino_swint_ogc.pth")
visualization_dir = str(current_dir / "outputs")

print(
    f"Setting up paths: {sam_config_path}, {sam_checkpoint_path}, {gd_config_path}, {gd_checkpoint_path}"
)
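
# Expected files next to app.py, per the paths above:
#   sam2_hiera_tiny.pt, GroundingDINO_SwinT_OGC.py, groundingdino_swint_ogc.pth
# (sam2_hiera_t.yaml itself is resolved by Hydra from the sam2 package's configs/.)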
@lru_cache(maxsize=1)
def _load_vine_pipeline():
    """
    Lazy-load and cache the Vine pipeline so we don't re-download/rebuild it on every request.
    """
    from vine_hf import VineConfig, VineModel, VinePipeline

    config = VineConfig(
        segmentation_method="grounding_dino_sam2",
        model_name="openai/clip-vit-base-patch32",
        use_hf_repo=True,
        model_repo="KevinX-Penn28/testing",
        box_threshold=0.35,
        text_threshold=0.25,
        target_fps=1,  # default 1 FPS
        topk_cate=5,
        white_alpha=0.3,
        visualization_dir=visualization_dir,
        visualize=True,
        debug_visualizations=False,
        device="cuda",
        categorical_pool="max",
    )
    model = VineModel(config)
    return VinePipeline(
        model=model,
        tokenizer=None,
        sam_config_path=sam_config_path,
        sam_checkpoint_path=sam_checkpoint_path,
        gd_config_path=gd_config_path,
        gd_checkpoint_path=gd_checkpoint_path,
        device="cuda",
        trust_remote_code=True,
    )
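
# Thanks to the lru_cache decorator, the heavy models are built once per process
# and reused across requests, e.g.:
#   pipe = _load_vine_pipeline()   # first call: downloads/builds models
#   pipe is _load_vine_pipeline()  # True on subsequent calls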
# Up to ~5 minutes of H200 ZeroGPU time per call
@spaces.GPU(duration=300)
def process_video(
    video_file,
    categorical_keywords,
    unary_keywords,
    binary_keywords,
    object_pairs,
    output_fps,
    box_threshold,
    text_threshold,
):
    vine_pipe = _load_vine_pipeline()

    # Normalize incoming video input to a file path
    if isinstance(video_file, dict):
        video_file = (
            video_file.get("name")
            or video_file.get("filepath")
            or video_file.get("data")
        )
    if not isinstance(video_file, (str, Path)):
        raise ValueError(f"Unsupported video input type: {type(video_file)}")
    categorical_keywords = (
        [kw.strip() for kw in categorical_keywords.split(",") if kw.strip()]
        if categorical_keywords
        else []
    )
    unary_keywords = (
        [kw.strip() for kw in unary_keywords.split(",") if kw.strip()]
        if unary_keywords
        else []
    )
    binary_keywords = (
        [kw.strip() for kw in binary_keywords.split(",") if kw.strip()]
        if binary_keywords
        else []
    )
    object_pairs = (
        [
            tuple(map(int, pair.split("-")))
            for pair in object_pairs.split(",")
            if pair.strip()
        ]
        if object_pairs
        else []
    )
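    # Parsing examples:
    #   "person, car" -> ["person", "car"]
    #   "0-1, 0-2"    -> [(0, 1), (0, 2)]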
    # Apply the UI settings to the pipeline before invoking it; assignments
    # made after the call would not affect the current request.
    vine_pipe.box_threshold = box_threshold
    vine_pipe.text_threshold = text_threshold
    vine_pipe.target_fps = output_fps
    results = vine_pipe(
        inputs=video_file,
        categorical_keywords=categorical_keywords,
        unary_keywords=unary_keywords,
        binary_keywords=binary_keywords,
        object_pairs=object_pairs,
        segmentation_method="grounding_dino_sam2",
        return_top_k=5,
        include_visualizations=True,
        debug_visualizations=False,
        device="cuda",
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_fps=output_fps,
    )
    if isinstance(results, Mapping):
        results_dict = results
    elif isinstance(results, Sequence) and results and isinstance(results[0], Mapping):
        results_dict = results[0]
    else:
        results_dict = {}

    visualizations = results_dict.get("visualizations") or {}
    vine = visualizations.get("vine") or {}
    all_vis = vine.get("all") or {}
    result_video_path = all_vis.get("video_path")
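    # If the pipeline didn't report a video path, fall back to the newest .mp4
    # written under the visualization directory.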
    if not result_video_path:
        candidates = sorted(
            Path(visualization_dir).rglob("*.mp4"),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        result_video_path = str(candidates[0]) if candidates else None
    summary = results_dict.get("summary") or {}

    if result_video_path and os.path.exists(result_video_path):
        gradio_tmp = (
            Path(os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir()))
            / "vine_outputs"
        )
        gradio_tmp.mkdir(parents=True, exist_ok=True)
        dest_path = gradio_tmp / Path(result_video_path).name
        try:
            shutil.copyfile(result_video_path, dest_path)
            video_path_for_ui = str(dest_path)
        except Exception as e:
            print(f"Warning: failed to copy video to Gradio temp dir: {e}")
            video_path_for_ui = str(result_video_path)
    else:
        video_path_for_ui = None
        print(
            "Warning: annotated video not found or empty; check visualization settings."
        )
    return video_path_for_ui, summary
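
# process_video returns (video_path_or_None, summary_dict), matching the
# [video_output, json_output] components wired to the submit button below.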
def _video_component(label: str, *, is_output: bool = False):
    """
    Build a Gradio Video component that is compatible with older Gradio versions
    (no `type`/`sources`/`format` kwargs) and newer ones when available.
    """
    kwargs = {"label": label}
    sig = inspect.signature(gr.Video.__init__)
    # Only set format for OUTPUT components
    if is_output and "format" in sig.parameters:
        kwargs["format"] = "mp4"
    if not is_output:
        if "type" in sig.parameters:
            kwargs["type"] = "filepath"
        if "sources" in sig.parameters:
            kwargs["sources"] = ["upload"]
    if is_output and "autoplay" in sig.parameters:
        kwargs["autoplay"] = True
    return gr.Video(**kwargs)
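
# The helper passes only kwargs the installed gr.Video actually accepts, so on
# a recent Gradio it might resolve to, e.g.:
#   gr.Video(label="Upload Video", sources=["upload"])           # input
#   gr.Video(label="Output ...", format="mp4", autoplay=True)    # output
# while older versions simply get gr.Video(label=...).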
def _create_blocks():
    """
    Build a Blocks context that works across Gradio versions.
    """
    blocks_kwargs = {"title": "VINE Demo"}
    soft_theme = None
    if hasattr(gr, "themes") and hasattr(gr.themes, "Soft"):
        try:
            soft_theme = gr.themes.Soft()
        except Exception:
            soft_theme = None
    if "theme" in inspect.signature(gr.Blocks).parameters and soft_theme is not None:
        blocks_kwargs["theme"] = soft_theme
    return gr.Blocks(**blocks_kwargs)
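
# Same version-guard idea as _video_component: when the installed Gradio
# supports themes, this yields gr.Blocks(title="VINE Demo", theme=gr.themes.Soft());
# otherwise just gr.Blocks(title="VINE Demo").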
# Create Gradio interface
with _create_blocks() as demo:
    video_input = _video_component("Upload Video", is_output=False)
    categorical_input = gr.Textbox(
        label="Categorical Keywords (comma-separated)",
        value="person, car, tree, background",
    )
    unary_input = gr.Textbox(
        label="Unary Keywords (comma-separated)", value="walking, running, standing"
    )
    binary_input = gr.Textbox(
        label="Binary Keywords (comma-separated)",
        placeholder="e.g., chasing, carrying",
    )
    pairs_input = gr.Textbox(
        label="Object Pairs (comma-separated indices)",
        placeholder="e.g., 0-1,0-2 for pairs of objects",
    )
    fps_input = gr.Number(
        label="Output FPS (affects processing speed)", value=1  # default 1 FPS
    )
    with gr.Accordion("Advanced Settings", open=False):
        box_threshold_input = gr.Slider(
            label="Box Threshold", minimum=0.1, maximum=0.9, value=0.35, step=0.05
        )
        text_threshold_input = gr.Slider(
            label="Text Threshold", minimum=0.1, maximum=0.9, value=0.25, step=0.05
        )
    submit_btn = gr.Button("Process Video", variant="primary")
    video_output = _video_component("Output Video with Annotations", is_output=True)
    json_output = gr.JSON(label="Summary of Detected Events")

    submit_btn.click(
        fn=process_video,
        inputs=[
            video_input,
            categorical_input,
            unary_input,
            binary_input,
            pairs_input,
            fps_input,
            box_threshold_input,
            text_threshold_input,
        ],
        outputs=[video_output, json_output],
    )
if __name__ == "__main__":
    print("Got to main")
    # share=True is not supported on Hugging Face Spaces, so launch without it.
    demo.launch(debug=True)