yonigozlan (HF Staff) committed
Commit 06c9ffc · 1 parent: 189fb9e

update app

Files changed (1)
  1. app.py +184 -44
app.py CHANGED
@@ -8,11 +8,17 @@ import spaces
 import torch
 from PIL import Image, ImageDraw

+# Prefer local transformers in the workspace
 from transformers import Sam2VideoModel, Sam2VideoProcessor


 def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
+    """Generate a deterministic pastel RGB color for a given object id.
+
+    Uses golden ratio to distribute hues; low-medium saturation, high value.
+    """
     golden_ratio_conjugate = 0.61803398875
+    # Map obj_id (1-based) to hue in [0,1)
     hue = (obj_id * golden_ratio_conjugate) % 1.0
     saturation = 0.45
     value = 1.0
@@ -21,10 +27,14 @@ def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:


 def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
+    """Load video frames as PIL Images using transformers.video_utils if available,
+    otherwise fall back to OpenCV. Returns (frames, info).
+    """
     try:
         from transformers.video_utils import load_video  # type: ignore

         frames, info = load_video(video_path_or_url)
+        # Ensure PIL format
         pil_frames = []
         for fr in frames:
             if isinstance(fr, Image.Image):
@@ -32,6 +42,7 @@ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
             else:
                 pil_frames.append(Image.fromarray(fr).convert("RGB"))
         info = info if info is not None else {}
+        # Ensure fps present when possible (fallback to cv2 probe)
         if "fps" not in info or not info.get("fps"):
             try:
                 import cv2  # type: ignore
@@ -45,6 +56,7 @@ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
                 pass
         return pil_frames, info
     except Exception:
+        # Fallback to OpenCV
        try:
            import cv2  # type: ignore
@@ -56,6 +68,7 @@ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
                    break
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
+            # Gather fps if available
            fps_val = cap.get(cv2.CAP_PROP_FPS)
            cap.release()
            info = {
@@ -71,28 +84,40 @@ def overlay_masks_on_frame(
     frame: Image.Image,
     masks_per_object: dict[int, np.ndarray],
     color_by_obj: dict[int, tuple[int, int, int]],
-    alpha: float = 0.65,
+    alpha: float = 0.5,
 ) -> Image.Image:
-    base = np.array(frame).astype(np.float32) / 255.0
+    """Overlay per-object soft masks onto the RGB frame.
+
+    masks_per_object: mapping of obj_id -> (H, W) float mask in [0,1]
+    color_by_obj: mapping of obj_id -> (R, G, B)
+    """
+    base = np.array(frame).astype(np.float32) / 255.0  # H, W, 3 in [0,1]
+    height, width = base.shape[:2]
     overlay = base.copy()
+
     for obj_id, mask in masks_per_object.items():
         if mask is None:
             continue
         if mask.dtype != np.float32:
             mask = mask.astype(np.float32)
+        # Ensure shape is H x W
         if mask.ndim == 3:
             mask = mask.squeeze()
         mask = np.clip(mask, 0.0, 1.0)
         color = np.array(color_by_obj.get(obj_id, (255, 0, 0)), dtype=np.float32) / 255.0
+        # Blend: overlay = (1 - a*m)*overlay + (a*m)*color
+        a = alpha
         m = mask[..., None]
-        overlay = (1.0 - alpha * m) * overlay + (alpha * m) * color
+        overlay = (1.0 - a * m) * overlay + (a * m) * color
+
     out = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
     return Image.fromarray(out)


 def get_device_and_dtype() -> tuple[str, torch.dtype]:
-    # Force CPU-only on Spaces with zero GPU
-    return "cpu", torch.float32
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = torch.bfloat16
+    return device, dtype


 class AppState:
@@ -105,22 +130,25 @@ class AppState:
         self.model: Optional[Sam2VideoModel] = None
         self.processor: Optional[Sam2VideoProcessor] = None
         self.device: str = "cpu"
-        self.dtype: torch.dtype = torch.float32
+        self.dtype: torch.dtype = torch.bfloat16
         self.video_fps: float | None = None
         self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
         self.color_by_obj: dict[int, tuple[int, int, int]] = {}
         self.clicks_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int]]]] = {}
         self.boxes_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int, int]]]] = {}
+        # Cache of composited frames (original + masks + clicks)
         self.composited_frames: dict[int, Image.Image] = {}
+        # UI state for click handler
         self.current_frame_idx: int = 0
         self.current_obj_id: int = 1
         self.current_label: str = "positive"
         self.current_clear_old: bool = True
-        self.current_prompt_type: str = "Points"
+        self.current_prompt_type: str = "Points"  # or "Boxes"
         self.pending_box_start: tuple[int, int] | None = None
         self.pending_box_start_frame_idx: int | None = None
         self.pending_box_start_obj_id: int | None = None
         self.is_switching_model: bool = False
+        # Model selection
         self.model_repo_key: str = "tiny"
         self.model_repo_id: str | None = None
         self.session_repo_id: str | None = None
@@ -149,6 +177,7 @@ def load_model_if_needed() -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
     if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
         if GLOBAL_STATE.model_repo_id == desired_repo:
             return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
+        # Different repo requested: dispose current and reload
        try:
            del GLOBAL_STATE.model
        except Exception:
@@ -159,28 +188,37 @@ def load_model_if_needed() -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
            pass
        GLOBAL_STATE.model = None
        GLOBAL_STATE.processor = None
-
+    print(f"Loading model from {desired_repo}")
     device, dtype = get_device_and_dtype()
+
     model = Sam2VideoModel.from_pretrained(desired_repo, torch_dtype=dtype)
     processor = Sam2VideoProcessor.from_pretrained(desired_repo)
+
     model.to(device)
+
     GLOBAL_STATE.model = model
     GLOBAL_STATE.processor = processor
     GLOBAL_STATE.device = device
     GLOBAL_STATE.dtype = dtype
     GLOBAL_STATE.model_repo_id = desired_repo
+
     return model, processor, device, dtype


 def ensure_session_for_current_model() -> None:
+    """Ensure the model/processor match the selected repo and inference_session exists.
+    If a video is already loaded, re-initialize the inference session when needed.
+    """
     model, processor, device, dtype = load_model_if_needed()
     desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
     if GLOBAL_STATE.inference_session is None or GLOBAL_STATE.session_repo_id != desired_repo:
         if GLOBAL_STATE.video_frames:
+            # Clear session-related UI caches when switching model
             GLOBAL_STATE.masks_by_frame.clear()
             GLOBAL_STATE.clicks_by_frame_obj.clear()
             GLOBAL_STATE.boxes_by_frame_obj.clear()
             GLOBAL_STATE.composited_frames.clear()
+            # Dispose previous session cleanly
            try:
                if GLOBAL_STATE.inference_session is not None:
                    GLOBAL_STATE.inference_session.reset_inference_session()
@@ -188,22 +226,29 @@ def ensure_session_for_current_model() -> None:
                pass
            GLOBAL_STATE.inference_session = None
            gc.collect()
+            try:
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+            except Exception:
+                pass
            GLOBAL_STATE.inference_session = processor.init_video_session(
                video=GLOBAL_STATE.video_frames,
                inference_device=device,
-                video_storage_device="cpu",
            )
            GLOBAL_STATE.session_repo_id = desired_repo


-def init_video_session(video: str | dict):
+def init_video_session(video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
+    """Gradio handler: load video, init session, return state, slider bounds, and first frame."""
+    # Reset ONLY video-related fields, keep model loaded
     GLOBAL_STATE.video_frames = []
     GLOBAL_STATE.inference_session = None
     GLOBAL_STATE.masks_by_frame = {}
     GLOBAL_STATE.color_by_obj = {}

-    load_model_if_needed()
+    model, processor, device, dtype = load_model_if_needed()

+    # Gradio Video may provide a dict with 'name' or a direct file path
     video_path: Optional[str] = None
     if isinstance(video, dict):
         video_path = video.get("name") or video.get("path") or video.get("data")
@@ -211,6 +256,7 @@ def init_video_session(video: str | dict):
         video_path = video
     else:
         video_path = None
+
     if not video_path:
         raise gr.Error("Invalid video input.")

@@ -219,6 +265,7 @@ def init_video_session(video: str | dict):
         raise gr.Error("No frames could be loaded from the video.")

     GLOBAL_STATE.video_frames = frames
+    # Try to capture original FPS if provided by loader
     GLOBAL_STATE.video_fps = None
     if isinstance(info, dict) and info.get("fps"):
         try:
@@ -226,8 +273,7 @@ def init_video_session(video: str | dict):
         except Exception:
             GLOBAL_STATE.video_fps = None

-    processor = GLOBAL_STATE.processor
-    device = GLOBAL_STATE.device
+    # Initialize session
     inference_session = processor.init_video_session(
         video=frames,
         inference_device=device,
@@ -237,7 +283,9 @@ def init_video_session(video: str | dict):

     first_frame = frames[0]
     max_idx = len(frames) - 1
-    status = f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps. Device: {device}, dtype: {GLOBAL_STATE.dtype}"
+    status = (
+        f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps. Device: {device}, dtype: bfloat16"
+    )
     return GLOBAL_STATE, 0, max_idx, first_frame, status


@@ -251,6 +299,7 @@ def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
     if len(masks) != 0:
         out_img = overlay_masks_on_frame(out_img, masks, state.color_by_obj, alpha=0.65)

+    # Draw crosses for conditioning frames only (frames with recorded clicks)
     clicks_map = state.clicks_by_frame_obj.get(frame_idx)
     if clicks_map:
         draw = ImageDraw.Draw(out_img)
@@ -258,17 +307,11 @@ def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
         for obj_id, pts in clicks_map.items():
             for x, y, lbl in pts:
                 color = (0, 255, 0) if int(lbl) == 1 else (255, 0, 0)
+                # horizontal
                 draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2)
+                # vertical
                 draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2)
-
-    box_map = state.boxes_by_frame_obj.get(frame_idx)
-    if box_map:
-        draw = ImageDraw.Draw(out_img)
-        for obj_id, boxes in box_map.items():
-            color = state.color_by_obj.get(obj_id, (255, 255, 255))
-            for x1, y1, x2, y2 in boxes:
-                draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)
-
+    # Draw temporary cross for first corner in box mode
     if (
         state.pending_box_start is not None
         and state.pending_box_start_frame_idx == frame_idx
@@ -280,7 +323,15 @@ def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
         color = state.color_by_obj.get(state.pending_box_start_obj_id, (255, 255, 255))
         draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2)
         draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2)
-
+    # Draw boxes for conditioning frames
+    box_map = state.boxes_by_frame_obj.get(frame_idx)
+    if box_map:
+        draw = ImageDraw.Draw(out_img)
+        for obj_id, boxes in box_map.items():
+            color = state.color_by_obj.get(obj_id, (255, 255, 255))
+            for x1, y1, x2, y2 in boxes:
+                draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)
+    # Save to cache and return
     state.composited_frames[frame_idx] = out_img
     return out_img

@@ -289,6 +340,7 @@ def update_frame_display(state: AppState, frame_idx: int) -> Image.Image:
     if state is None or state.video_frames is None or len(state.video_frames) == 0:
         return None
     frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
+    # Serve from cache when available
     cached = state.composited_frames.get(frame_idx)
     if cached is not None:
         return cached
@@ -309,14 +361,17 @@ def on_image_click(
     label: str,
     clear_old: bool,
     evt: gr.SelectData,
-):
+) -> Image.Image:
     if state is None or state.inference_session is None:
-        return img
+        return img  # no-op preview when not ready
     if state.is_switching_model:
+        # Gracefully ignore input during model switch; return current preview unchanged
         return update_frame_display(state, int(frame_idx))

+    # Parse click coordinates from event
     x = y = None
     if evt is not None:
+        # Try different gradio event data shapes for robustness
         try:
             if hasattr(evt, "index") and isinstance(evt.index, (list, tuple)) and len(evt.index) == 2:
                 x, y = int(evt.index[0]), int(evt.index[1])
@@ -324,16 +379,20 @@ def on_image_click(
                 x, y = int(evt.value["x"]), int(evt.value["y"])
         except Exception:
             x = y = None
+
     if x is None or y is None:
-        return update_frame_display(state, int(frame_idx))
+        raise gr.Error("Could not read click coordinates.")

     _ensure_color_for_obj(int(obj_id))
+
     processor = GLOBAL_STATE.processor
     model = GLOBAL_STATE.model
     inference_session = GLOBAL_STATE.inference_session

     if state.current_prompt_type == "Boxes":
+        # Two-click box input
        if state.pending_box_start is None:
+            # If clear_old is enabled, clear prior points for this object on this frame
            if bool(clear_old):
                frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {})
                frame_clicks[int(obj_id)] = []
@@ -341,11 +400,13 @@ def on_image_click(
            state.pending_box_start = (int(x), int(y))
            state.pending_box_start_frame_idx = int(frame_idx)
            state.pending_box_start_obj_id = int(obj_id)
+            # Invalidate cache so temporary cross is drawn
            state.composited_frames.pop(int(frame_idx), None)
            return update_frame_display(state, int(frame_idx))
        else:
            x1, y1 = state.pending_box_start
            x2, y2 = int(x), int(y)
+            # Clear temporary state and invalidate cache
            state.pending_box_start = None
            state.pending_box_start_frame_idx = None
            state.pending_box_start_obj_id = None
@@ -368,7 +429,9 @@ def on_image_click(
            obj_boxes.append((x_min, y_min, x_max, y_max))
            state.composited_frames.pop(int(frame_idx), None)
     else:
+        # Points mode
         label_int = 1 if str(label).lower().startswith("pos") else 0
+        # If clear_old is enabled, clear prior boxes for this object on this frame
         if bool(clear_old):
             frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
             frame_boxes[int(obj_id)] = []
@@ -381,6 +444,7 @@ def on_image_click(
             input_labels=[[[int(label_int)]]],
             clear_old_inputs=bool(clear_old),
         )
+
         frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {})
         obj_clicks = frame_clicks.setdefault(int(obj_id), [])
         if bool(clear_old):
@@ -388,21 +452,35 @@ def on_image_click(
         obj_clicks.append((int(x), int(y), int(label_int)))
         state.composited_frames.pop(int(frame_idx), None)

-    with torch.inference_mode():
-        outputs = model(inference_session=inference_session, frame_idx=int(frame_idx))
+    # Forward on that frame
+    device_type = "cuda" if GLOBAL_STATE.device == "cuda" else "cpu"
+    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=GLOBAL_STATE.dtype):
+        outputs = model(
+            inference_session=inference_session,
+            frame_idx=int(frame_idx),
+        )

     H = inference_session.video_height
     W = inference_session.video_width
+    # Detach and move off GPU as early as possible to reduce GPU memory pressure
     pred_masks = outputs.pred_masks.detach().cpu()
     video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
+
+    # Map returned masks to object ids. For single object forward, it's [1, 1, H, W]
+    # But to be safe, iterate over session.obj_ids order.
     masks_for_frame: dict[int, np.ndarray] = {}
     obj_ids_order = list(inference_session.obj_ids)
     for i, oid in enumerate(obj_ids_order):
         mask_i = video_res_masks[i]
+        # mask_i shape could be (1, H, W) or (H, W); squeeze to 2D
         mask_2d = mask_i.cpu().numpy().squeeze()
         masks_for_frame[int(oid)] = mask_2d
+
     GLOBAL_STATE.masks_by_frame[int(frame_idx)] = masks_for_frame
+    # Invalidate cache for this frame to force recomposition
     GLOBAL_STATE.composited_frames.pop(int(frame_idx), None)
+
+    # Return updated preview
     return update_frame_display(GLOBAL_STATE, int(frame_idx))


@@ -411,18 +489,25 @@ def propagate_masks(state: AppState, progress=gr.Progress()):
     if state is None or state.inference_session is None:
         yield "Load a video first."
         return
+
     processor = GLOBAL_STATE.processor
     model = GLOBAL_STATE.model
     inference_session = GLOBAL_STATE.inference_session
+
     total = max(1, GLOBAL_STATE.num_frames)
     processed = 0
+
+    # Initial status for first run visibility
     yield f"Propagating masks: {processed}/{total}"
-    with torch.inference_mode():
+
+    device_type = "cuda" if GLOBAL_STATE.device == "cuda" else "cpu"
+    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=GLOBAL_STATE.dtype):
        for sam2_video_output in model.propagate_in_video_iterator(inference_session):
            H = inference_session.video_height
            W = inference_session.video_width
            pred_masks = sam2_video_output.pred_masks.detach().cpu()
            video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
+
            frame_idx = int(sam2_video_output.frame_idx)
            masks_for_frame: dict[int, np.ndarray] = {}
            obj_ids_order = list(inference_session.obj_ids)
@@ -430,16 +515,24 @@ def propagate_masks(state: AppState, progress=gr.Progress()):
                mask_2d = video_res_masks[i].cpu().numpy().squeeze()
                masks_for_frame[int(oid)] = mask_2d
            GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
+            # Invalidate cache for that frame to force recomposition
            GLOBAL_STATE.composited_frames.pop(frame_idx, None)
+
            processed += 1
            progress((processed, total), f"Propagating masks: {processed}/{total}")
+            # Stream status updates so users see progress text
            yield f"Propagating masks: {processed}/{total}"
+
     yield f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."


-def reset_session():
+def reset_session() -> tuple[AppState, Image.Image, int, int, str]:
+    # Reset only session-related state, keep uploaded video and model
     if not GLOBAL_STATE.video_frames:
+        # Nothing loaded; keep behavior
         return GLOBAL_STATE, None, 0, 0, "Session reset. Load a new video."
+
+    # Clear prompts and caches
     GLOBAL_STATE.masks_by_frame.clear()
     GLOBAL_STATE.clicks_by_frame_obj.clear()
     GLOBAL_STATE.boxes_by_frame_obj.clear()
@@ -447,6 +540,8 @@ def reset_session():
     GLOBAL_STATE.pending_box_start = None
     GLOBAL_STATE.pending_box_start_frame_idx = None
     GLOBAL_STATE.pending_box_start_obj_id = None
+
+    # Dispose and re-init inference session for current model with existing frames
     try:
         if GLOBAL_STATE.inference_session is not None:
             GLOBAL_STATE.inference_session.reset_inference_session()
@@ -454,7 +549,14 @@ def reset_session():
         pass
     GLOBAL_STATE.inference_session = None
     gc.collect()
+    try:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+    except Exception:
+        pass
     ensure_session_for_current_model()
+
+    # Keep current slider index if possible
     current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
     current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
     preview_img = update_frame_display(GLOBAL_STATE, current_idx)
@@ -464,14 +566,12 @@ def reset_session():
     return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status


-with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation (CPU)") as demo:
+with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation") as demo:
     state = gr.State(GLOBAL_STATE)

-    gr.Markdown(
-        """
-        **SAM2 Video (Transformers)** — CPU-only Space. Upload a video, click to add positive/negative points per object or draw two-click boxes, preview masks, then propagate across the video. Use the slider to scrub frames.
-        """
-    )
+    gr.Markdown("""
+    **SAM2 Video (Transformers)** — Upload a video, click to add positive/negative points per object, preview masks on the clicked frame, then propagate across the video. Use the slider to scrub frames.
+    """)

     with gr.Row():
         with gr.Column(scale=1):
@@ -485,7 +585,8 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation (CPU)") as demo:
             load_status = gr.Markdown(visible=True)
             reset_btn = gr.Button("Reset Session", variant="secondary")
             examples_list = [
-                ["./tennis.mp4"],
+                ["/home/ubuntu/models_implem/tennis.mp4"],
+                ["/home/ubuntu/models_implem/tennis.mp4"],
             ]
         with gr.Column(scale=2):
             preview = gr.Image(label="Preview", interactive=True)
@@ -504,13 +605,23 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation (CPU)") as demo:
             render_btn = gr.Button("Render MP4 for smooth playback")
             playback_video = gr.Video(label="Rendered Playback", interactive=False)

+    # Wire events
     def _on_video_change(video):
         s, min_idx, max_idx, first_frame, status = init_video_session(video)
-        return s, gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True), first_frame, status
+        return (
+            s,
+            gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
+            first_frame,
+            status,
+        )

     video_in.change(
-        _on_video_change, inputs=[video_in], outputs=[state, frame_slider, preview, load_status], show_progress=True
+        _on_video_change,
+        inputs=[video_in],
+        outputs=[state, frame_slider, preview, load_status],
+        show_progress=True,
     )
+
     gr.Examples(
         examples=examples_list,
         inputs=[video_in],
@@ -525,21 +636,26 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation (CPU)") as demo:
         if s is not None and key:
             key = str(key)
             if key != s.model_repo_key:
+                # Update and drop current model to reload lazily next time
                 s.is_switching_model = True
                 s.model_repo_key = key
                 s.model_repo_id = None
                 s.model = None
                 s.processor = None
+        # Stream progress text while loading (first yield shows text)
        yield gr.update(visible=True, value=f"Loading checkpoint: {key}...")
        ensure_session_for_current_model()
        if s is not None:
            s.is_switching_model = False
+        # Final yield hides the text
        yield gr.update(visible=False, value="")

     ckpt_radio.change(_on_ckpt_change, inputs=[state, ckpt_radio], outputs=[ckpt_progress])

+    # Also retrigger session re-init if a video already loaded
     def _rebind_session_after_ckpt(s: AppState):
         ensure_session_for_current_model()
+        # Reset pending box corner to avoid mismatched state
         if s is not None:
             s.pending_box_start = None
         return gr.update()
@@ -551,7 +667,11 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation (CPU)") as demo:
         state_in.current_frame_idx = int(idx)
         return update_frame_display(state_in, int(idx))

-    frame_slider.change(_sync_frame_idx, inputs=[state, frame_slider], outputs=preview)
+    frame_slider.change(
+        _sync_frame_idx,
+        inputs=[state, frame_slider],
+        outputs=preview,
+    )

     def _sync_obj_id(s: AppState, oid):
         if s is not None and oid is not None:
@@ -576,34 +696,54 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation (CPU)") as demo:

     prompt_type.change(_sync_prompt_type, inputs=[state, prompt_type], outputs=[label_radio])

+    # Image click to add a point and run forward on that frame
     preview.select(on_image_click, [preview, state, frame_slider, obj_id_inp, label_radio, clear_old_chk], preview)

+    # Playback via MP4 rendering only
+
+    # Render a smooth MP4 using imageio/pyav (fallbacks to imageio v2 / OpenCV)
     def _render_video(s: AppState):
         if s is None or s.num_frames == 0:
             raise gr.Error("Load a video first.")
         fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12
+        # Compose all frames (cache will help if already prepared)
         frames_np = []
+        first = compose_frame(s, 0)
+        h, w = first.size[1], first.size[0]
        for idx in range(s.num_frames):
            img = s.composited_frames.get(idx)
            if img is None:
                img = compose_frame(s, idx)
-            frames_np.append(np.array(img)[:, :, ::-1])
+            frames_np.append(np.array(img)[:, :, ::-1])  # BGR for cv2
+            # Periodically release CPU mem to reduce pressure
            if (idx + 1) % 60 == 0:
                gc.collect()
        out_path = "/tmp/sam2_playback.mp4"
+        # Prefer imageio with PyAV/ffmpeg to respect exact fps
        try:
            import imageio.v3 as iio  # type: ignore

            iio.imwrite(out_path, [fr[:, :, ::-1] for fr in frames_np], plugin="pyav", fps=fps)
            return out_path
        except Exception:
+            # Fallbacks
            try:
                import imageio.v2 as imageio  # type: ignore

                imageio.mimsave(out_path, [fr[:, :, ::-1] for fr in frames_np], fps=fps)
                return out_path
-            except Exception as e:
-                raise gr.Error(f"Failed to render video: {e}")
+            except Exception:
+                try:
+                    import cv2  # type: ignore
+
+                    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+                    writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
+                    for fr_bgr in frames_np:
+                        writer.write(fr_bgr)
+                    writer.release()
+                    return out_path
+                except Exception as e:
+                    raise gr.Error(f"Failed to render video: {e}")

     render_btn.click(_render_video, inputs=[state], outputs=[playback_video])