yonigozlan (HF Staff) committed
Commit 02f1528 · Parent(s): 69c91fa

Use full gpu for now

Files changed (1): app.py (+4 -6)
app.py CHANGED
```diff
@@ -4,8 +4,8 @@ from typing import Optional
 
 import gradio as gr
 import numpy as np
-import spaces
 import torch
+from gradio.themes import Soft
 from PIL import Image, ImageDraw
 
 # Prefer local transformers in the workspace
@@ -171,7 +171,6 @@ def _model_repo_from_key(key: str) -> str:
     return mapping.get(key, mapping["base_plus"])
 
 
-@spaces.GPU()
 def load_model_if_needed() -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
     desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
     if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
@@ -238,7 +237,6 @@ def ensure_session_for_current_model() -> None:
     GLOBAL_STATE.session_repo_id = desired_repo
 
 
-@spaces.GPU()
 def init_video_session(video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
     """Gradio handler: load video, init session, return state, slider bounds, and first frame."""
     # Reset ONLY video-related fields, keep model loaded
@@ -353,7 +351,6 @@ def _ensure_color_for_obj(obj_id: int):
     GLOBAL_STATE.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
 
 
-@spaces.GPU()
 def on_image_click(
     img: Image.Image | np.ndarray,
     state: AppState,
@@ -485,7 +482,6 @@ def on_image_click(
     return update_frame_display(GLOBAL_STATE, int(frame_idx))
 
 
-@spaces.GPU()
 def propagate_masks(state: AppState, progress=gr.Progress()):
     if state is None or state.inference_session is None:
         yield "Load a video first."
@@ -567,7 +563,9 @@ def reset_session() -> tuple[AppState, Image.Image, int, int, str]:
     return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status
 
 
-with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation") as demo:
+theme = Soft(primary_hue="indigo", secondary_hue="rose", neutral_hue="slate")
+
+with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
     state = gr.State(GLOBAL_STATE)
 
     gr.Markdown("""
```
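The removals above strip the ZeroGPU integration: without `import spaces` and the `@spaces.GPU()` decorators, the app no longer requests on-demand GPU slices and instead assumes a dedicated ("full") GPU, matching the commit message. A minimal sketch of the pattern being dropped, assuming the Hugging Face `spaces` package on ZeroGPU hardware (the function body and duration value are illustrative, not from this app):

```python
# Sketch of the ZeroGPU pattern this commit removes. On ZeroGPU Spaces,
# the decorator attaches a GPU to the process only while the decorated
# function runs; on a dedicated GPU Space the device is always present,
# so both the import and the decorators become unnecessary.
import spaces
import torch


@spaces.GPU(duration=120)  # duration (seconds) is optional; illustrative here
def double_on_gpu(x: torch.Tensor) -> torch.Tensor:
    # Inside the call, CUDA is guaranteed to be available.
    return (x.to("cuda") * 2).cpu()
```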
 
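The `load_model_if_needed` hunk shows the app caching one model/processor pair in global state and reloading only when the selected checkpoint changes. A standalone sketch of that lazy-loading idea, assuming a transformers build that ships `Sam2VideoModel` and `Sam2VideoProcessor` (as the app's type hints indicate); the repo id, dtype, and helper name are illustrative:

```python
import torch
from transformers import Sam2VideoModel, Sam2VideoProcessor

# One cached pair per process; reload only when the checkpoint changes.
_cache: dict[str, tuple[Sam2VideoModel, Sam2VideoProcessor]] = {}


def load_once(repo_id: str = "facebook/sam2.1-hiera-base-plus"):
    if repo_id not in _cache:
        model = Sam2VideoModel.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
        processor = Sam2VideoProcessor.from_pretrained(repo_id)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _cache[repo_id] = (model.to(device).eval(), processor)
    return _cache[repo_id]
```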
  """Gradio handler: load video, init session, return state, slider bounds, and first frame."""
242
  # Reset ONLY video-related fields, keep model loaded
 
351
  GLOBAL_STATE.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
352
 
353
 
 
354
  def on_image_click(
355
  img: Image.Image | np.ndarray,
356
  state: AppState,
 
482
  return update_frame_display(GLOBAL_STATE, int(frame_idx))
483
 
484
 
 
485
  def propagate_masks(state: AppState, progress=gr.Progress()):
486
  if state is None or state.inference_session is None:
487
  yield "Load a video first."
 
563
  return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status
564
 
565
 
566
+ theme = Soft(primary_hue="indigo", secondary_hue="rose", neutral_hue="slate")
567
+
568
+ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
569
  state = gr.State(GLOBAL_STATE)
570
 
571
  gr.Markdown("""
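The additions on the new side wire a Gradio `Soft` theme into the Blocks app. A self-contained sketch of the same theming API with the hues this commit picks (the demo content is a placeholder):

```python
import gradio as gr
from gradio.themes import Soft

# Same hue choices as the commit; any Gradio color name works here.
theme = Soft(primary_hue="indigo", secondary_hue="rose", neutral_hue="slate")

with gr.Blocks(title="Theme preview", theme=theme) as demo:
    gr.Markdown("SAM2 demo chrome rendered with the Soft theme.")

if __name__ == "__main__":
    demo.launch()
```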