yonigozlan HF Staff committed on
Commit
e8e6a3a
·
1 Parent(s): 2974ac3
Files changed (7)
  1. .gitattributes +7 -0
  2. README.md +6 -5
  3. app.py +839 -0
  4. deers.mp4 +3 -0
  5. foot.mp4 +3 -0
  6. penguins.mp4 +3 -0
  7. requirements.txt +9 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tennis.mp4 filter=lfs diff=lfs merge=lfs -text
+ basket.mp4 filter=lfs diff=lfs merge=lfs -text
+ football.mp4 filter=lfs diff=lfs merge=lfs -text
+ hurdles.mp4 filter=lfs diff=lfs merge=lfs -text
+ deers.mp4 filter=lfs diff=lfs merge=lfs -text
+ foot.mp4 filter=lfs diff=lfs merge=lfs -text
+ penguins.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,14 @@
  ---
- title: Edgetam
- emoji: 🏒
- colorFrom: blue
- colorTo: red
+ title: Segment Anything 2 Video Tracking
+ emoji: 👀
+ colorFrom: purple
+ colorTo: indigo
  sdk: gradio
- sdk_version: 5.47.1
+ sdk_version: 5.42.0
  app_file: app.py
  pinned: false
  license: apache-2.0
+ short_description: Segment any objects and track them through a video with SAM2
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,839 @@
+ import colorsys
+ import gc
+ from typing import Optional
+
+ import gradio as gr
+ import numpy as np
+ import spaces
+ import torch
+ from gradio.themes import Soft
+ from PIL import Image, ImageDraw
+
+ # Prefer local transformers in the workspace
+ from transformers import AutoModel, Sam2VideoProcessor
+
+
+ def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
+     """Generate a deterministic pastel RGB color for a given object id.
+
+     Uses golden ratio to distribute hues; low-medium saturation, high value.
+     """
+     golden_ratio_conjugate = 0.61803398875
+     # Map obj_id (1-based) to hue in [0,1)
+     hue = (obj_id * golden_ratio_conjugate) % 1.0
+     saturation = 0.45
+     value = 1.0
+     r_f, g_f, b_f = colorsys.hsv_to_rgb(hue, saturation, value)
+     return int(r_f * 255), int(g_f * 255), int(b_f * 255)
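A quick sanity check of the hue spacing above (the values follow directly from the golden-ratio constant in the function):

```python
# Consecutive object ids land on well-separated hues, so overlaid masks
# for different objects stay visually distinct.
for obj_id in range(1, 4):
    hue = (obj_id * 0.61803398875) % 1.0
    print(obj_id, round(hue, 3))  # 1 -> 0.618, 2 -> 0.236, 3 -> 0.854
```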
+
+
+ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
+     """Load video frames as PIL Images using transformers.video_utils if available,
+     otherwise fall back to OpenCV. Returns (frames, info).
+     """
+     try:
+         from transformers.video_utils import load_video  # type: ignore
+
+         frames, info = load_video(video_path_or_url)
+         # Ensure PIL format
+         pil_frames = []
+         for fr in frames:
+             if isinstance(fr, Image.Image):
+                 pil_frames.append(fr.convert("RGB"))
+             else:
+                 pil_frames.append(Image.fromarray(fr).convert("RGB"))
+         info = info if info is not None else {}
+         # Ensure fps present when possible (fallback to cv2 probe)
+         if "fps" not in info or not info.get("fps"):
+             try:
+                 import cv2  # type: ignore
+
+                 cap = cv2.VideoCapture(video_path_or_url)
+                 fps_val = cap.get(cv2.CAP_PROP_FPS)
+                 cap.release()
+                 if fps_val and fps_val > 0:
+                     info["fps"] = float(fps_val)
+             except Exception:
+                 pass
+         return pil_frames, info
+     except Exception:
+         # Fallback to OpenCV
+         try:
+             import cv2  # type: ignore
+
+             cap = cv2.VideoCapture(video_path_or_url)
+             frames = []
+             while cap.isOpened():
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+                 frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 frames.append(Image.fromarray(frame_rgb))
+             # Gather fps if available
+             fps_val = cap.get(cv2.CAP_PROP_FPS)
+             cap.release()
+             info = {
+                 "num_frames": len(frames),
+                 "fps": float(fps_val) if fps_val and fps_val > 0 else None,
+             }
+             return frames, info
+         except Exception as e:
+             raise RuntimeError(f"Failed to load video: {e}")
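A minimal usage sketch for the loader above, assuming one of the example clips bundled with this Space (e.g. `deers.mp4`) sits next to `app.py`:

```python
# Load a clip with the helper above; `info` may or may not carry "fps",
# which is why the app probes OpenCV as a fallback.
frames, info = try_load_video_frames("./deers.mp4")
print(f"{len(frames)} frames of size {frames[0].size}, fps={info.get('fps')}")
```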
+
+
+ def overlay_masks_on_frame(
+     frame: Image.Image,
+     masks_per_object: dict[int, np.ndarray],
+     color_by_obj: dict[int, tuple[int, int, int]],
+     alpha: float = 0.5,
+ ) -> Image.Image:
+     """Overlay per-object soft masks onto the RGB frame.
+
+     masks_per_object: mapping of obj_id -> (H, W) float mask in [0,1]
+     color_by_obj: mapping of obj_id -> (R, G, B)
+     """
+     base = np.array(frame).astype(np.float32) / 255.0  # H, W, 3 in [0,1]
+     height, width = base.shape[:2]
+     overlay = base.copy()
+
+     for obj_id, mask in masks_per_object.items():
+         if mask is None:
+             continue
+         if mask.dtype != np.float32:
+             mask = mask.astype(np.float32)
+         # Ensure shape is H x W
+         if mask.ndim == 3:
+             mask = mask.squeeze()
+         mask = np.clip(mask, 0.0, 1.0)
+         color = np.array(color_by_obj.get(obj_id, (255, 0, 0)), dtype=np.float32) / 255.0
+         # Blend: overlay = (1 - a*m)*overlay + (a*m)*color
+         a = alpha
+         m = mask[..., None]
+         overlay = (1.0 - a * m) * overlay + (a * m) * color
+
+     out = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
+     return Image.fromarray(out)
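To make the per-pixel blend `overlay = (1 - a*m)*overlay + (a*m)*color` concrete, a small numeric check on a single pixel (values chosen only for illustration):

```python
import numpy as np

# A dark gray pixel, a fully confident mask (m = 1.0), alpha = 0.5 and a red
# overlay color: the result sits halfway between the pixel and the color.
base = np.array([0.2, 0.2, 0.2], dtype=np.float32)
color = np.array([1.0, 0.0, 0.0], dtype=np.float32)
a, m = 0.5, 1.0
blended = (1.0 - a * m) * base + (a * m) * color
print(blended)  # [0.6 0.1 0.1]; with m = 0.0 the pixel is left untouched
```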
+
+
+ def get_device_and_dtype() -> tuple[str, torch.dtype]:
+     device = "cpu"
+     dtype = torch.bfloat16
+     return device, dtype
+
+
+ class AppState:
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.video_frames: list[Image.Image] = []
+         self.inference_session = None
+         self.model: Optional[AutoModel] = None
+         self.processor: Optional[Sam2VideoProcessor] = None
+         self.device: str = "cuda"
+         self.dtype: torch.dtype = torch.bfloat16
+         self.video_fps: float | None = None
+         self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
+         self.color_by_obj: dict[int, tuple[int, int, int]] = {}
+         self.clicks_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int]]]] = {}
+         self.boxes_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int, int]]]] = {}
+         # Cache of composited frames (original + masks + clicks)
+         self.composited_frames: dict[int, Image.Image] = {}
+         # UI state for click handler
+         self.current_frame_idx: int = 0
+         self.current_obj_id: int = 1
+         self.current_label: str = "positive"
+         self.current_clear_old: bool = True
+         self.current_prompt_type: str = "Points"  # or "Boxes"
+         self.pending_box_start: tuple[int, int] | None = None
+         self.pending_box_start_frame_idx: int | None = None
+         self.pending_box_start_obj_id: int | None = None
+         self.is_switching_model: bool = False
+         # Model selection
+         self.model_repo_key: str = "tiny"
+         self.model_repo_id: str | None = None
+         self.session_repo_id: str | None = None
+
+     @property
+     def num_frames(self) -> int:
+         return len(self.video_frames)
+
+
+ def _model_repo_from_key(key: str) -> str:
+     mapping = {
+         "tiny": "facebook/sam2.1-hiera-tiny",
+         "small": "facebook/sam2.1-hiera-small",
+         "base_plus": "facebook/sam2.1-hiera-base-plus",
+         "large": "facebook/sam2.1-hiera-large",
+         "EdgeTAM": "../EdgeTAM-hf",
+     }
+     return mapping.get(key, mapping["base_plus"])
+
+
+ def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoProcessor, str, torch.dtype]:
+     desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
+     if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
+         if GLOBAL_STATE.model_repo_id == desired_repo:
+             return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
+         # Different repo requested: dispose current and reload
+         try:
+             del GLOBAL_STATE.model
+         except Exception:
+             pass
+         try:
+             del GLOBAL_STATE.processor
+         except Exception:
+             pass
+         GLOBAL_STATE.model = None
+         GLOBAL_STATE.processor = None
+     print(f"Loading model from {desired_repo}")
+     device, dtype = get_device_and_dtype()
+     # free up the gpu memory
+     torch.cuda.empty_cache()
+     gc.collect()
+     model = AutoModel.from_pretrained(desired_repo)
+     processor = Sam2VideoProcessor.from_pretrained(desired_repo)
+     model.to(device, dtype=dtype)
+
+     GLOBAL_STATE.model = model
+     GLOBAL_STATE.processor = processor
+     GLOBAL_STATE.device = device
+     GLOBAL_STATE.dtype = dtype
+     GLOBAL_STATE.model_repo_id = desired_repo
+
+     return model, processor, device, dtype
+
+
+ def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
+     """Ensure the model/processor match the selected repo and inference_session exists.
+     If a video is already loaded, re-initialize the inference session when needed.
+     """
+     model, processor, device, dtype = load_model_if_needed(GLOBAL_STATE)
+     desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
+     if GLOBAL_STATE.inference_session is None or GLOBAL_STATE.session_repo_id != desired_repo:
+         if GLOBAL_STATE.video_frames:
+             # Clear session-related UI caches when switching model
+             GLOBAL_STATE.masks_by_frame.clear()
+             GLOBAL_STATE.clicks_by_frame_obj.clear()
+             GLOBAL_STATE.boxes_by_frame_obj.clear()
+             GLOBAL_STATE.composited_frames.clear()
+             # Dispose previous session cleanly
+             try:
+                 if GLOBAL_STATE.inference_session is not None:
+                     GLOBAL_STATE.inference_session.reset_inference_session()
+             except Exception:
+                 pass
+             GLOBAL_STATE.inference_session = None
+             gc.collect()
+             try:
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+             except Exception:
+                 pass
+             GLOBAL_STATE.inference_session = processor.init_video_session(
+                 video=GLOBAL_STATE.video_frames,
+                 inference_device=device,
+                 video_storage_device="cpu",
+                 dtype=dtype,
+             )
+             GLOBAL_STATE.session_repo_id = desired_repo
+
+
+ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
+     """Gradio handler: load video, init session, return state, slider bounds, and first frame."""
+     # Reset ONLY video-related fields, keep model loaded
+     GLOBAL_STATE.video_frames = []
+     GLOBAL_STATE.inference_session = None
+     GLOBAL_STATE.masks_by_frame = {}
+     GLOBAL_STATE.color_by_obj = {}
+
+     model, processor, device, dtype = load_model_if_needed(GLOBAL_STATE)
+
+     # Gradio Video may provide a dict with 'name' or a direct file path
+     video_path: Optional[str] = None
+     if isinstance(video, dict):
+         video_path = video.get("name") or video.get("path") or video.get("data")
+     elif isinstance(video, str):
+         video_path = video
+     else:
+         video_path = None
+
+     if not video_path:
+         raise gr.Error("Invalid video input.")
+
+     frames, info = try_load_video_frames(video_path)
+     if len(frames) == 0:
+         raise gr.Error("No frames could be loaded from the video.")
+
+     # Enforce max duration of 8 seconds (trim if longer)
+     MAX_SECONDS = 8.0
+     trimmed_note = ""
+     fps_in = None
+     if isinstance(info, dict) and info.get("fps"):
+         try:
+             fps_in = float(info["fps"]) or None
+         except Exception:
+             fps_in = None
+     if fps_in is not None:
+         max_frames_allowed = int(MAX_SECONDS * fps_in)
+         if len(frames) > max_frames_allowed:
+             frames = frames[:max_frames_allowed]
+             trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
+             if isinstance(info, dict):
+                 info["num_frames"] = len(frames)
+     else:
+         # Fallback when FPS unknown: assume ~30 FPS and cap to 240 frames (~8s)
+         max_frames_allowed = 240
+         if len(frames) > max_frames_allowed:
+             frames = frames[:max_frames_allowed]
+             trimmed_note = " (trimmed to 240 frames ~8s @30fps)"
+             if isinstance(info, dict):
+                 info["num_frames"] = len(frames)
+
+     GLOBAL_STATE.video_frames = frames
+     # Try to capture original FPS if provided by loader
+     GLOBAL_STATE.video_fps = None
+     if isinstance(info, dict) and info.get("fps"):
+         try:
+             GLOBAL_STATE.video_fps = float(info["fps"]) or None
+         except Exception:
+             GLOBAL_STATE.video_fps = None
+
+     # Initialize session
+     inference_session = processor.init_video_session(
+         video=frames,
+         inference_device=device,
+         video_storage_device="cpu",
+         dtype=dtype,
+     )
+     GLOBAL_STATE.inference_session = inference_session
+
+     first_frame = frames[0]
+     max_idx = len(frames) - 1
+     status = (
+         f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
+         f"Device: {device}, dtype: bfloat16"
+     )
+     return GLOBAL_STATE, 0, max_idx, first_frame, status
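The 8-second cap above is plain frame-rate arithmetic; for example:

```python
# At 25 fps the cap keeps int(8.0 * 25) = 200 frames; when the fps is unknown
# the app assumes ~30 fps and keeps at most 240 frames.
MAX_SECONDS = 8.0
print(int(MAX_SECONDS * 25.0))   # 200
print(int(MAX_SECONDS * 29.97))  # 239
```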
+
+
+ def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
+     if state is None or state.video_frames is None or len(state.video_frames) == 0:
+         return None
+     frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
+     frame = state.video_frames[frame_idx]
+     masks = state.masks_by_frame.get(frame_idx, {})
+     out_img = frame
+     if len(masks) != 0:
+         out_img = overlay_masks_on_frame(out_img, masks, state.color_by_obj, alpha=0.65)
+
+     # Draw crosses for conditioning frames only (frames with recorded clicks)
+     clicks_map = state.clicks_by_frame_obj.get(frame_idx)
+     if clicks_map:
+         draw = ImageDraw.Draw(out_img)
+         cross_half = 6
+         for obj_id, pts in clicks_map.items():
+             for x, y, lbl in pts:
+                 color = (0, 255, 0) if int(lbl) == 1 else (255, 0, 0)
+                 # horizontal
+                 draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2)
+                 # vertical
+                 draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2)
+     # Draw temporary cross for first corner in box mode
+     if (
+         state.pending_box_start is not None
+         and state.pending_box_start_frame_idx == frame_idx
+         and state.pending_box_start_obj_id is not None
+     ):
+         draw = ImageDraw.Draw(out_img)
+         x, y = state.pending_box_start
+         cross_half = 6
+         color = state.color_by_obj.get(state.pending_box_start_obj_id, (255, 255, 255))
+         draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2)
+         draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2)
+     # Draw boxes for conditioning frames
+     box_map = state.boxes_by_frame_obj.get(frame_idx)
+     if box_map:
+         draw = ImageDraw.Draw(out_img)
+         for obj_id, boxes in box_map.items():
+             color = state.color_by_obj.get(obj_id, (255, 255, 255))
+             for x1, y1, x2, y2 in boxes:
+                 draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)
+     # Save to cache and return
+     state.composited_frames[frame_idx] = out_img
+     return out_img
+
+
+ def update_frame_display(state: AppState, frame_idx: int) -> Image.Image:
+     if state is None or state.video_frames is None or len(state.video_frames) == 0:
+         return None
+     frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
+     # Serve from cache when available
+     cached = state.composited_frames.get(frame_idx)
+     if cached is not None:
+         return cached
+     return compose_frame(state, frame_idx)
+
+
+ def _ensure_color_for_obj(state: AppState, obj_id: int):
+     if obj_id not in state.color_by_obj:
+         state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
+
+
+ def on_image_click(
+     img: Image.Image | np.ndarray,
+     state: AppState,
+     frame_idx: int,
+     obj_id: int,
+     label: str,
+     clear_old: bool,
+     evt: gr.SelectData,
+ ) -> Image.Image:
+     if state is None or state.inference_session is None:
+         return img  # no-op preview when not ready
+     if state.is_switching_model:
+         # Gracefully ignore input during model switch; return current preview unchanged
+         return update_frame_display(state, int(frame_idx))
+
+     # Parse click coordinates from event
+     x = y = None
+     if evt is not None:
+         # Try different gradio event data shapes for robustness
+         try:
+             if hasattr(evt, "index") and isinstance(evt.index, (list, tuple)) and len(evt.index) == 2:
+                 x, y = int(evt.index[0]), int(evt.index[1])
+             elif hasattr(evt, "value") and isinstance(evt.value, dict) and "x" in evt.value and "y" in evt.value:
+                 x, y = int(evt.value["x"]), int(evt.value["y"])
+         except Exception:
+             x = y = None
+
+     if x is None or y is None:
+         raise gr.Error("Could not read click coordinates.")
+
+     _ensure_color_for_obj(state, int(obj_id))
+
+     processor = state.processor
+     model = state.model
+     inference_session = state.inference_session
+
+     if state.current_prompt_type == "Boxes":
+         # Two-click box input
+         if state.pending_box_start is None:
+             # For boxes, always clear old inputs (points) for this object on this frame
+             frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {})
+             frame_clicks[int(obj_id)] = []
+             state.composited_frames.pop(int(frame_idx), None)
+             state.pending_box_start = (int(x), int(y))
+             state.pending_box_start_frame_idx = int(frame_idx)
+             state.pending_box_start_obj_id = int(obj_id)
+             # Invalidate cache so temporary cross is drawn
+             state.composited_frames.pop(int(frame_idx), None)
+             return update_frame_display(state, int(frame_idx))
+         else:
+             x1, y1 = state.pending_box_start
+             x2, y2 = int(x), int(y)
+             # Clear temporary state and invalidate cache
+             state.pending_box_start = None
+             state.pending_box_start_frame_idx = None
+             state.pending_box_start_obj_id = None
+             state.composited_frames.pop(int(frame_idx), None)
+             x_min, y_min = min(x1, x2), min(y1, y2)
+             x_max, y_max = max(x1, x2), max(y1, y2)
+
+             processor.add_inputs_to_inference_session(
+                 inference_session=inference_session,
+                 frame_idx=int(frame_idx),
+                 obj_ids=int(obj_id),
+                 input_boxes=[[[x_min, y_min, x_max, y_max]]],
+                 clear_old_inputs=True,  # For boxes, always clear old inputs
+             )
+
+             frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
+             obj_boxes = frame_boxes.setdefault(int(obj_id), [])
+             # For boxes, always clear old inputs
+             obj_boxes.clear()
+             obj_boxes.append((x_min, y_min, x_max, y_max))
+             state.composited_frames.pop(int(frame_idx), None)
+     else:
+         # Points mode
+         label_int = 1 if str(label).lower().startswith("pos") else 0
+         # If clear_old is enabled, clear prior boxes for this object on this frame
+         if bool(clear_old):
+             frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
+             frame_boxes[int(obj_id)] = []
+             state.composited_frames.pop(int(frame_idx), None)
+         processor.add_inputs_to_inference_session(
+             inference_session=inference_session,
+             frame_idx=int(frame_idx),
+             obj_ids=int(obj_id),
+             input_points=[[[[int(x), int(y)]]]],
+             input_labels=[[[int(label_int)]]],
+             clear_old_inputs=bool(clear_old),
+         )
+
+         frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {})
+         obj_clicks = frame_clicks.setdefault(int(obj_id), [])
+         if bool(clear_old):
+             obj_clicks.clear()
+         obj_clicks.append((int(x), int(y), int(label_int)))
+         state.composited_frames.pop(int(frame_idx), None)
+
+     # Forward on that frame
+     with torch.inference_mode():
+         outputs = model(
+             inference_session=inference_session,
+             frame_idx=int(frame_idx),
+         )
+
+     H = inference_session.video_height
+     W = inference_session.video_width
+     # Detach and move off GPU as early as possible to reduce GPU memory pressure
+     pred_masks = outputs.pred_masks.detach().cpu()
+     video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
+
+     # Map returned masks to object ids. For single object forward, it's [1, 1, H, W]
+     # But to be safe, iterate over session.obj_ids order.
+     masks_for_frame: dict[int, np.ndarray] = {}
+     obj_ids_order = list(inference_session.obj_ids)
+     for i, oid in enumerate(obj_ids_order):
+         mask_i = video_res_masks[i]
+         # mask_i shape could be (1, H, W) or (H, W); squeeze to 2D
+         mask_2d = mask_i.cpu().numpy().squeeze()
+         masks_for_frame[int(oid)] = mask_2d
+
+     state.masks_by_frame[int(frame_idx)] = masks_for_frame
+     # Invalidate cache for this frame to force recomposition
+     state.composited_frames.pop(int(frame_idx), None)
+
+     # Return updated preview
+     return update_frame_display(state, int(frame_idx))
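The nesting of `input_points`, `input_labels` and `input_boxes` in the calls above is easy to misread; judging from how the handler builds them, the levels appear to be frame-batch → object → point (or box) → coordinates. A hedged sketch, reusing the `processor` and `inference_session` already set up by the app (the two-point variant is an extrapolation of the single-click call, not something the UI currently sends):

```python
# One positive click on object 1, exactly as the Points branch builds it:
# points are [batch][object][point][x, y], labels are [batch][object][point].
processor.add_inputs_to_inference_session(
    inference_session=inference_session,
    frame_idx=0,
    obj_ids=1,
    input_points=[[[[210, 140]]]],
    input_labels=[[[1]]],
    clear_old_inputs=True,
)

# Presumably a positive + negative pair for the same object extends the innermost
# lists: input_points=[[[[210, 140], [260, 180]]]], input_labels=[[[1, 0]]].

# A box prompt for object 1, mirroring the Boxes branch: [batch][object][x1, y1, x2, y2].
processor.add_inputs_to_inference_session(
    inference_session=inference_session,
    frame_idx=0,
    obj_ids=1,
    input_boxes=[[[60, 40, 320, 260]]],
    clear_old_inputs=True,
)
```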
+
+
+ @spaces.GPU()
+ def propagate_masks(GLOBAL_STATE: gr.State):
+     if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
+         yield "Load a video first.", gr.update()
+         return
+
+     processor = GLOBAL_STATE.processor
+     model = GLOBAL_STATE.model
+     inference_session = GLOBAL_STATE.inference_session
+     # set inference device to cuda to use zero gpu
+     inference_session.inference_device = "cuda"
+     inference_session.cache.inference_device = "cuda"
+     model.to("cuda")
+
+     total = max(1, GLOBAL_STATE.num_frames)
+     processed = 0
+
+     # Initial status; no slider change yet
+     yield f"Propagating masks: {processed}/{total}", gr.update()
+
+     last_frame_idx = 0
+     with torch.inference_mode():
+         for sam2_video_output in model.propagate_in_video_iterator(inference_session):
+             H = inference_session.video_height
+             W = inference_session.video_width
+             pred_masks = sam2_video_output.pred_masks.detach().cpu()
+             video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
+
+             frame_idx = int(sam2_video_output.frame_idx)
+             last_frame_idx = frame_idx
+             masks_for_frame: dict[int, np.ndarray] = {}
+             obj_ids_order = list(inference_session.obj_ids)
+             for i, oid in enumerate(obj_ids_order):
+                 mask_2d = video_res_masks[i].cpu().numpy().squeeze()
+                 masks_for_frame[int(oid)] = mask_2d
+             GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
+             # Invalidate cache for that frame to force recomposition
+             GLOBAL_STATE.composited_frames.pop(frame_idx, None)
+
+             processed += 1
+             # Every 15th frame (or last), move slider to current frame to update preview via slider binding
+             if processed % 15 == 0 or processed == total:
+                 yield f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
+             else:
+                 yield f"Propagating masks: {processed}/{total}", gr.update()
+
+     model.to("cpu")
+     inference_session.inference_device = "cpu"
+     inference_session.cache.inference_device = "cpu"
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     # Final status; ensure slider points to last processed frame
+     yield (
+         f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
+         gr.update(value=last_frame_idx),
+     )
+
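The same calls the two handlers above rely on can be strung together outside Gradio. A rough offline sketch, assuming the helpers defined earlier in this file are in scope, a clip such as `deers.mp4` is available, and CPU inference is acceptable:

```python
import torch
from transformers import AutoModel, Sam2VideoProcessor

device, dtype = "cpu", torch.bfloat16  # the app only moves to CUDA while propagating

model = AutoModel.from_pretrained("facebook/sam2.1-hiera-tiny").to(device, dtype=dtype)
processor = Sam2VideoProcessor.from_pretrained("facebook/sam2.1-hiera-tiny")

frames, _ = try_load_video_frames("./deers.mp4")
session = processor.init_video_session(
    video=frames, inference_device=device, video_storage_device="cpu", dtype=dtype
)

# One positive click on frame 0 for object 1 (coordinates are arbitrary), then track it.
processor.add_inputs_to_inference_session(
    inference_session=session, frame_idx=0, obj_ids=1,
    input_points=[[[[210, 140]]]], input_labels=[[[1]]],
)
H, W = session.video_height, session.video_width
with torch.inference_mode():
    model(inference_session=session, frame_idx=0)
    for out in model.propagate_in_video_iterator(session):
        masks = processor.post_process_masks([out.pred_masks.cpu()], original_sizes=[[H, W]])[0]
        print(out.frame_idx, tuple(masks.shape))
```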
+
+ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
+     # Reset only session-related state, keep uploaded video and model
+     if not GLOBAL_STATE.video_frames:
+         # Nothing loaded; keep behavior
+         return GLOBAL_STATE, None, 0, 0, "Session reset. Load a new video."
+
+     # Clear prompts and caches
+     GLOBAL_STATE.masks_by_frame.clear()
+     GLOBAL_STATE.clicks_by_frame_obj.clear()
+     GLOBAL_STATE.boxes_by_frame_obj.clear()
+     GLOBAL_STATE.composited_frames.clear()
+     GLOBAL_STATE.pending_box_start = None
+     GLOBAL_STATE.pending_box_start_frame_idx = None
+     GLOBAL_STATE.pending_box_start_obj_id = None
+
+     # Dispose and re-init inference session for current model with existing frames
+     try:
+         if GLOBAL_STATE.inference_session is not None:
+             GLOBAL_STATE.inference_session.reset_inference_session()
+     except Exception:
+         pass
+     GLOBAL_STATE.inference_session = None
+     gc.collect()
+     try:
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+     except Exception:
+         pass
+     ensure_session_for_current_model(GLOBAL_STATE)
+
+     # Keep current slider index if possible
+     current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
+     current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
+     preview_img = update_frame_display(GLOBAL_STATE, current_idx)
+     slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
+     slider_value = gr.update(value=current_idx)
+     status = "Session reset. Prompts cleared; video preserved."
+     # clear and reload model and processor
+     return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status
+
+
+ theme = Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")
+
+ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
+     GLOBAL_STATE = gr.State(AppState())
+
+     gr.Markdown(
+         """
+         ### SAM2 Video Tracking · powered by Hugging Face 🤗 Transformers
+         Segment and track objects across a video with SAM2 (Segment Anything 2). This demo runs the official implementation from the Hugging Face Transformers library for interactive, promptable video segmentation.
+         """
+     )
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown(
+                 """
+                 **Quick start**
+                 - **Load a video**: Upload your own or pick an example below.
+                 - **Checkpoint**: Tiny / Small / Base+ / Large (trade speed vs. accuracy).
+                 - **Points mode**: Select an Object ID and point label (positive/negative), then click the frame to add guidance. You can add **multiple points per object** and define **multiple objects** across frames.
+                 - **Boxes mode**: Click two opposite corners to draw a box. Old inputs for that object are cleared automatically.
+                 """
+             )
+         with gr.Column():
+             gr.Markdown(
+                 """
+                 **Working with results**
+                 - **Preview**: Use the slider to navigate frames and see the current masks.
+                 - **Propagate**: Click "Propagate across video" to track all defined objects through the entire video. The preview follows progress periodically to keep things responsive.
+                 - **Export**: Render an MP4 for smooth playback using the original video FPS.
+                 - **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video).
+                 """
+             )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             video_in = gr.Video(label="Upload video", sources=["upload", "webcam"], interactive=True)
+             ckpt_radio = gr.Radio(
+                 choices=["tiny", "small", "base_plus", "large", "EdgeTAM"],
+                 value="tiny",
+                 label="SAM2.1 checkpoint",
+             )
+             ckpt_progress = gr.Markdown(visible=False)
+             load_status = gr.Markdown(visible=True)
+             reset_btn = gr.Button("Reset Session", variant="secondary")
+         with gr.Column(scale=2):
+             preview = gr.Image(label="Preview", interactive=True)
+             with gr.Row():
+                 frame_slider = gr.Slider(label="Frame", minimum=0, maximum=0, step=1, value=0, interactive=True)
+                 with gr.Column(scale=0):
+                     propagate_btn = gr.Button("Propagate across video", variant="primary")
+                     propagate_status = gr.Markdown(visible=True)
+     with gr.Row():
+         obj_id_inp = gr.Number(value=1, precision=0, label="Object ID", scale=0)
+         label_radio = gr.Radio(choices=["positive", "negative"], value="positive", label="Point label")
+         clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
+         prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
+
+     # Wire events
+     def _on_video_change(GLOBAL_STATE: gr.State, video):
+         GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video)
+         return (
+             GLOBAL_STATE,
+             gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
+             first_frame,
+             status,
+         )
+
+     video_in.change(
+         _on_video_change,
+         inputs=[GLOBAL_STATE, video_in],
+         outputs=[GLOBAL_STATE, frame_slider, preview, load_status],
+         show_progress=True,
+     )
+
+     # (moved) Examples are defined above the render button
+     # Each example row must match the number of inputs (GLOBAL_STATE, video_in)
+     examples_list = [
+         [None, "./deers.mp4"],
+         [None, "./penguins.mp4"],
+         [None, "./foot.mp4"],
+     ]
+     with gr.Row():
+         gr.Examples(
+             examples=examples_list,
+             inputs=[GLOBAL_STATE, video_in],
+             fn=_on_video_change,
+             outputs=[GLOBAL_STATE, frame_slider, preview, load_status],
+             label="Examples",
+             cache_examples=False,
+             examples_per_page=5,
+         )
+     # Examples (place before the render MP4 button) - defined after handler below
+
+     with gr.Row():
+         render_btn = gr.Button("Render MP4 for smooth playback", variant="primary")
+         playback_video = gr.Video(label="Rendered Playback", interactive=False)
+
+     def _on_ckpt_change(s: AppState, key: str):
+         if s is not None and key:
+             key = str(key)
+             if key != s.model_repo_key:
+                 # Update and drop current model to reload lazily next time
+                 s.is_switching_model = True
+                 s.model_repo_key = key
+                 s.model_repo_id = None
+                 s.model = None
+                 s.processor = None
+         # Stream progress text while loading (first yield shows text)
+         yield gr.update(visible=True, value=f"Loading checkpoint: {key}...")
+         ensure_session_for_current_model(s)
+         if s is not None:
+             s.is_switching_model = False
+         # Final yield hides the text
+         yield gr.update(visible=False, value="")
+
+     ckpt_radio.change(_on_ckpt_change, inputs=[GLOBAL_STATE, ckpt_radio], outputs=[ckpt_progress])
+
+     def _sync_frame_idx(state_in: AppState, idx: int):
+         if state_in is not None:
+             state_in.current_frame_idx = int(idx)
+         return update_frame_display(state_in, int(idx))
+
+     frame_slider.change(
+         _sync_frame_idx,
+         inputs=[GLOBAL_STATE, frame_slider],
+         outputs=preview,
+     )
+
+     def _sync_obj_id(s: AppState, oid):
+         if s is not None and oid is not None:
+             s.current_obj_id = int(oid)
+         return gr.update()
+
+     obj_id_inp.change(_sync_obj_id, inputs=[GLOBAL_STATE, obj_id_inp], outputs=[])
+
+     def _sync_label(s: AppState, lab: str):
+         if s is not None and lab is not None:
+             s.current_label = str(lab)
+         return gr.update()
+
+     label_radio.change(_sync_label, inputs=[GLOBAL_STATE, label_radio], outputs=[])
+
+     def _sync_prompt_type(s: AppState, val: str):
+         if s is not None and val is not None:
+             s.current_prompt_type = str(val)
+             s.pending_box_start = None
+         is_points = str(val).lower() == "points"
+         # Show labels only for points; hide and disable clear_old when boxes
+         updates = [
+             gr.update(visible=is_points),
+             gr.update(interactive=is_points) if is_points else gr.update(value=True, interactive=False),
+         ]
+         return updates
+
+     prompt_type.change(
+         _sync_prompt_type,
+         inputs=[GLOBAL_STATE, prompt_type],
+         outputs=[label_radio, clear_old_chk],
+     )
+
+     # Image click to add a point and run forward on that frame
+     preview.select(
+         on_image_click, [preview, GLOBAL_STATE, frame_slider, obj_id_inp, label_radio, clear_old_chk], preview
+     )
+
+     # Playback via MP4 rendering only
+
+     # Render a smooth MP4 using imageio/pyav (fallbacks to imageio v2 / OpenCV)
+     def _render_video(s: AppState):
+         if s is None or s.num_frames == 0:
+             raise gr.Error("Load a video first.")
+         fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12
+         # Compose all frames (cache will help if already prepared)
+         frames_np = []
+         first = compose_frame(s, 0)
+         h, w = first.size[1], first.size[0]
+         for idx in range(s.num_frames):
+             img = s.composited_frames.get(idx)
+             if img is None:
+                 img = compose_frame(s, idx)
+             frames_np.append(np.array(img)[:, :, ::-1])  # BGR for cv2
+             # Periodically release CPU mem to reduce pressure
+             if (idx + 1) % 60 == 0:
+                 gc.collect()
+         out_path = "/tmp/sam2_playback.mp4"
+         # Prefer imageio with PyAV/ffmpeg to respect exact fps
+         try:
+             import imageio.v3 as iio  # type: ignore
+
+             iio.imwrite(out_path, [fr[:, :, ::-1] for fr in frames_np], plugin="pyav", fps=fps)
+             return out_path
+         except Exception:
+             # Fallbacks
+             try:
+                 import imageio.v2 as imageio  # type: ignore
+
+                 imageio.mimsave(out_path, [fr[:, :, ::-1] for fr in frames_np], fps=fps)
+                 return out_path
+             except Exception:
+                 try:
+                     import cv2  # type: ignore
+
+                     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+                     writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
+                     for fr_bgr in frames_np:
+                         writer.write(fr_bgr)
+                     writer.release()
+                     return out_path
+                 except Exception as e:
+                     raise gr.Error(f"Failed to render video: {e}")
+
+     render_btn.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video])
+
+     # While propagating, we stream two outputs: status text and slider value updates
+     propagate_btn.click(
+         propagate_masks,
+         inputs=[GLOBAL_STATE],
+         outputs=[propagate_status, frame_slider],
+     )
+
+     reset_btn.click(
+         reset_session,
+         inputs=GLOBAL_STATE,
+         outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status],
+     )
+
+
+ demo.queue(api_open=False).launch()
deers.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e60c4974bbfff98d16e8f264a54d9f84084c5591fdb8455d64449561eb74714
+ size 3401495
foot.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e7f86a74b9fa12322024ce4e60c27a2c86acf65abfa32b0a3e3dc44163de96b
+ size 2359941
penguins.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7a7776418857bd05405fa055cce364f122eafd418be489e88ff7955b4dfd427a
+ size 4573098
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio
+ git+https://github.com/SangbumChoi/transformers.git@sam2
+ torch
+ torchvision
+ pillow
+ opencv-python
+ imageio[pyav]
+ spaces
+