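"""Gradio demo for the VINE pipeline on Hugging Face Spaces (ZeroGPU).

The app takes an uploaded video plus comma-separated categorical / unary / binary
keywords and object-index pairs, runs the cached VinePipeline (Grounding DINO + SAM2
segmentation), and returns an annotated video together with a JSON summary of the
detected events.
"""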
from pathlib import Path
from collections.abc import Mapping, Sequence
from functools import lru_cache
import inspect
import shutil
import tempfile
import os
import sys
import spaces # <-- ZeroGPU integration
import gradio as gr
import torch
from transformers import pipeline  # noqa: F401 (not used directly in this app)
# -----------------------------
# Environment / diagnostics
# -----------------------------
os.environ["GRADIO_TEMP_DIR"] = str(Path(__file__).parent / "gradio_temp")
os.environ["OPENAI_API_KEY"] = "test"
os.environ["OMP_NUM_THREADS"] = "4"
print("All imports finished")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"cuDNN version: {torch.backends.cudnn.version()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
print(
f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB"
)
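# Force full float32: disable TF32 matmul/cuDNN paths and set the default dtype.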
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
os.environ["TORCH_DTYPE"] = "float32"
torch.set_default_dtype(torch.float32)
current_dir = Path(__file__).resolve().parent
# For Spaces, assume checkpoints live alongside app.py or in a "checkpoints" subdir.
# If you keep them next to app.py locally, this still works.
# NOTE: SAM2 config uses Hydra, so we use just the filename (it searches in sam2/configs/)
sam_config_path = "sam2_hiera_t.yaml" # Hydra will find this in sam2/configs/
sam_checkpoint_path = str(current_dir / "sam2_hiera_tiny.pt")
gd_config_path = str(current_dir / "GroundingDINO_SwinT_OGC.py")
gd_checkpoint_path = str(current_dir / "groundingdino_swint_ogc.pth")
visualization_dir = str(current_dir / "outputs")
print(
f"Setting up paths: {sam_config_path}, {sam_checkpoint_path}, {gd_config_path}, {gd_checkpoint_path}"
)
@lru_cache(maxsize=1)
def _load_vine_pipeline():
"""
Lazy-load and cache the Vine pipeline so we don't re-download/rebuild it on every request.
"""
from vine_hf import VineConfig, VineModel, VinePipeline
config = VineConfig(
segmentation_method="grounding_dino_sam2",
model_name="openai/clip-vit-base-patch32",
use_hf_repo=True,
model_repo="KevinX-Penn28/testing",
box_threshold=0.35,
text_threshold=0.25,
target_fps=1, # default 1 FPS
topk_cate=5,
white_alpha=0.3,
visualization_dir=visualization_dir,
visualize=True,
debug_visualizations=False,
device="cuda",
categorical_pool="max",
)
model = VineModel(config)
return VinePipeline(
model=model,
tokenizer=None,
sam_config_path=sam_config_path,
sam_checkpoint_path=sam_checkpoint_path,
gd_config_path=gd_config_path,
gd_checkpoint_path=gd_checkpoint_path,
device="cuda",
trust_remote_code=True,
)
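# NOTE: lru_cache(maxsize=1) builds the VineConfig above only once per process,
# so per-request thresholds and FPS are passed as call-time arguments in process_video().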
@spaces.GPU(duration=300) # Up to ~5 minutes of H200 ZeroGPU time per call
def process_video(
video_file,
categorical_keywords,
unary_keywords,
binary_keywords,
object_pairs,
output_fps,
box_threshold,
text_threshold,
):
vine_pipe = _load_vine_pipeline()
# Normalize incoming video input to a file path
    if isinstance(video_file, dict):
        # Newer Gradio versions may pass a FileData-style dict; try the common keys.
        video_file = (
            video_file.get("path")
            or video_file.get("name")
            or video_file.get("filepath")
            or video_file.get("data")
        )
if not isinstance(video_file, (str, Path)):
raise ValueError(f"Unsupported video input type: {type(video_file)}")
    categorical_keywords = (
        [kw.strip() for kw in categorical_keywords.split(",") if kw.strip()]
        if categorical_keywords
        else []
    )
    unary_keywords = (
        [kw.strip() for kw in unary_keywords.split(",") if kw.strip()]
        if unary_keywords
        else []
    )
    binary_keywords = (
        [kw.strip() for kw in binary_keywords.split(",") if kw.strip()]
        if binary_keywords
        else []
    )
    object_pairs = (
        [
            tuple(map(int, pair.strip().split("-")))
            for pair in object_pairs.split(",")
            if pair.strip()
        ]
        if object_pairs
        else []
    )
results = vine_pipe(
inputs=video_file,
categorical_keywords=categorical_keywords,
unary_keywords=unary_keywords,
binary_keywords=binary_keywords,
object_pairs=object_pairs,
segmentation_method="grounding_dino_sam2",
return_top_k=5,
include_visualizations=True,
debug_visualizations=False,
device="cuda",
box_threshold=box_threshold,
text_threshold=text_threshold,
target_fps=output_fps,
)
    # Also persist the request settings on the cached pipeline instance
    # (it is reused across requests because of lru_cache).
    vine_pipe.box_threshold = box_threshold
    vine_pipe.text_threshold = text_threshold
    vine_pipe.target_fps = output_fps
if isinstance(results, Mapping):
results_dict = results
elif isinstance(results, Sequence) and results and isinstance(results[0], Mapping):
results_dict = results[0]
else:
results_dict = {}
visualizations = results_dict.get("visualizations") or {}
vine = visualizations.get("vine") or {}
all_vis = vine.get("all") or {}
result_video_path = all_vis.get("video_path")
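    # Fall back to the newest .mp4 written under visualization_dir if the result
    # dict does not expose a video path directly.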
if not result_video_path:
candidates = sorted(
Path(visualization_dir).rglob("*.mp4"),
key=lambda p: p.stat().st_mtime,
reverse=True,
)
result_video_path = str(candidates[0]) if candidates else None
summary = results_dict.get("summary") or {}
if result_video_path and os.path.exists(result_video_path):
gradio_tmp = Path(
os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
) / "vine_outputs"
gradio_tmp.mkdir(parents=True, exist_ok=True)
dest_path = gradio_tmp / Path(result_video_path).name
try:
shutil.copyfile(result_video_path, dest_path)
video_path_for_ui = str(dest_path)
except Exception as e:
print(f"Warning: failed to copy video to Gradio temp dir: {e}")
video_path_for_ui = str(result_video_path)
else:
video_path_for_ui = None
print(
"Warning: annotated video not found or empty; check visualization settings."
)
return video_path_for_ui, summary
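# Local smoke test (hypothetical file name and keywords; adjust to your own data):
#   annotated_path, summary = process_video(
#       "sample.mp4", "person, car", "walking, running", "", "", 1, 0.35, 0.25
#   )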
def _video_component(label: str, *, is_output: bool = False):
"""
Build a Gradio Video component that is compatible with older Gradio versions
(no `type`/`sources`/`format` kwargs) and newer ones when available.
"""
kwargs = {"label": label}
sig = inspect.signature(gr.Video.__init__)
# Only set format for OUTPUT components
if is_output and "format" in sig.parameters:
kwargs["format"] = "mp4"
if not is_output:
if "type" in sig.parameters:
kwargs["type"] = "filepath"
if "sources" in sig.parameters:
kwargs["sources"] = ["upload"]
if is_output and "autoplay" in sig.parameters:
kwargs["autoplay"] = True
return gr.Video(**kwargs)
def _create_blocks():
"""
Build a Blocks context that works across Gradio versions.
"""
blocks_kwargs = {"title": "VINE Demo"}
soft_theme = None
if hasattr(gr, "themes") and hasattr(gr.themes, "Soft"):
try:
soft_theme = gr.themes.Soft()
except Exception:
soft_theme = None
if "theme" in inspect.signature(gr.Blocks).parameters and soft_theme is not None:
blocks_kwargs["theme"] = soft_theme
return gr.Blocks(**blocks_kwargs)
# Create Gradio interface
with _create_blocks() as demo:
video_input = _video_component("Upload Video", is_output=False)
categorical_input = gr.Textbox(
label="Categorical Keywords (comma-separated)",
value="person, car, tree, background",
)
unary_input = gr.Textbox(
label="Unary Keywords (comma-separated)", value="walking, running, standing"
)
binary_input = gr.Textbox(
label="Binary Keywords (comma-separated)",
placeholder="e.g., chasing, carrying",
)
pairs_input = gr.Textbox(
label="Object Pairs (comma-separated indices)",
placeholder="e.g., 0-1,0-2 for pairs of objects",
)
fps_input = gr.Number(
label="Output FPS (affects processing speed)", value=1 # default 1 FPS
)
with gr.Accordion("Advanced Settings", open=False):
box_threshold_input = gr.Slider(
label="Box Threshold", minimum=0.1, maximum=0.9, value=0.35, step=0.05
)
text_threshold_input = gr.Slider(
label="Text Threshold", minimum=0.1, maximum=0.9, value=0.25, step=0.05
)
submit_btn = gr.Button("Process Video", variant="primary")
video_output = _video_component("Output Video with Annotations", is_output=True)
json_output = gr.JSON(label="Summary of Detected Events")
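    # Wire the button to the GPU-decorated handler; outputs go to the video player and JSON panel.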
submit_btn.click(
fn=process_video,
inputs=[
video_input,
categorical_input,
unary_input,
binary_input,
pairs_input,
fps_input,
box_threshold_input,
text_threshold_input,
],
outputs=[video_output, json_output],
)
if __name__ == "__main__":
    print("Launching Gradio demo")
    # share=True has no effect on Spaces but is handy for local runs; debug=True prints full tracebacks.
    demo.launch(share=True, debug=True)