Spaces:
Runtime error
Runtime error
Commit
·
02f1528
1
Parent(s):
69c91fa
Use full gpu for now
Browse files
app.py
CHANGED
|
@@ -4,8 +4,8 @@ from typing import Optional
|
|
| 4 |
|
| 5 |
import gradio as gr
|
| 6 |
import numpy as np
|
| 7 |
-
import spaces
|
| 8 |
import torch
|
|
|
|
| 9 |
from PIL import Image, ImageDraw
|
| 10 |
|
| 11 |
# Prefer local transformers in the workspace
|
|
@@ -171,7 +171,6 @@ def _model_repo_from_key(key: str) -> str:
|
|
| 171 |
return mapping.get(key, mapping["base_plus"])
|
| 172 |
|
| 173 |
|
| 174 |
-
@spaces.GPU()
|
| 175 |
def load_model_if_needed() -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
|
| 176 |
desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
|
| 177 |
if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
|
|
@@ -238,7 +237,6 @@ def ensure_session_for_current_model() -> None:
|
|
| 238 |
GLOBAL_STATE.session_repo_id = desired_repo
|
| 239 |
|
| 240 |
|
| 241 |
-
@spaces.GPU()
|
| 242 |
def init_video_session(video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
|
| 243 |
"""Gradio handler: load video, init session, return state, slider bounds, and first frame."""
|
| 244 |
# Reset ONLY video-related fields, keep model loaded
|
|
@@ -353,7 +351,6 @@ def _ensure_color_for_obj(obj_id: int):
|
|
| 353 |
GLOBAL_STATE.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
|
| 354 |
|
| 355 |
|
| 356 |
-
@spaces.GPU()
|
| 357 |
def on_image_click(
|
| 358 |
img: Image.Image | np.ndarray,
|
| 359 |
state: AppState,
|
|
@@ -485,7 +482,6 @@ def on_image_click(
|
|
| 485 |
return update_frame_display(GLOBAL_STATE, int(frame_idx))
|
| 486 |
|
| 487 |
|
| 488 |
-
@spaces.GPU()
|
| 489 |
def propagate_masks(state: AppState, progress=gr.Progress()):
|
| 490 |
if state is None or state.inference_session is None:
|
| 491 |
yield "Load a video first."
|
|
@@ -567,7 +563,9 @@ def reset_session() -> tuple[AppState, Image.Image, int, int, str]:
|
|
| 567 |
return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status
|
| 568 |
|
| 569 |
|
| 570 |
-
|
|
|
|
|
|
|
| 571 |
state = gr.State(GLOBAL_STATE)
|
| 572 |
|
| 573 |
gr.Markdown("""
|
|
|
|
| 4 |
|
| 5 |
import gradio as gr
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
import torch
|
| 8 |
+
from gradio.themes import Soft
|
| 9 |
from PIL import Image, ImageDraw
|
| 10 |
|
| 11 |
# Prefer local transformers in the workspace
|
|
|
|
| 171 |
return mapping.get(key, mapping["base_plus"])
|
| 172 |
|
| 173 |
|
|
|
|
| 174 |
def load_model_if_needed() -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
|
| 175 |
desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
|
| 176 |
if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
|
|
|
|
| 237 |
GLOBAL_STATE.session_repo_id = desired_repo
|
| 238 |
|
| 239 |
|
|
|
|
| 240 |
def init_video_session(video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
|
| 241 |
"""Gradio handler: load video, init session, return state, slider bounds, and first frame."""
|
| 242 |
# Reset ONLY video-related fields, keep model loaded
|
|
|
|
| 351 |
GLOBAL_STATE.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
|
| 352 |
|
| 353 |
|
|
|
|
| 354 |
def on_image_click(
|
| 355 |
img: Image.Image | np.ndarray,
|
| 356 |
state: AppState,
|
|
|
|
| 482 |
return update_frame_display(GLOBAL_STATE, int(frame_idx))
|
| 483 |
|
| 484 |
|
|
|
|
| 485 |
def propagate_masks(state: AppState, progress=gr.Progress()):
|
| 486 |
if state is None or state.inference_session is None:
|
| 487 |
yield "Load a video first."
|
|
|
|
| 563 |
return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status
|
| 564 |
|
| 565 |
|
| 566 |
+
theme = Soft(primary_hue="indigo", secondary_hue="rose", neutral_hue="slate")
|
| 567 |
+
|
| 568 |
+
with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
|
| 569 |
state = gr.State(GLOBAL_STATE)
|
| 570 |
|
| 571 |
gr.Markdown("""
|