yonigozlan HF Staff committed on
Commit
41c6ac5
·
1 Parent(s): 02f1528
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -5,7 +5,6 @@ from typing import Optional
5
  import gradio as gr
6
  import numpy as np
7
  import torch
8
- from gradio.themes import Soft
9
  from PIL import Image, ImageDraw
10
 
11
  # Prefer local transformers in the workspace
@@ -233,6 +232,8 @@ def ensure_session_for_current_model() -> None:
233
  GLOBAL_STATE.inference_session = processor.init_video_session(
234
  video=GLOBAL_STATE.video_frames,
235
  inference_device=device,
 
 
236
  )
237
  GLOBAL_STATE.session_repo_id = desired_repo
238
 
@@ -277,6 +278,7 @@ def init_video_session(video: str | dict) -> tuple[AppState, int, int, Image.Ima
277
  video=frames,
278
  inference_device=device,
279
  video_storage_device="cpu",
 
280
  )
281
  GLOBAL_STATE.inference_session = inference_session
282
 
@@ -482,7 +484,7 @@ def on_image_click(
482
  return update_frame_display(GLOBAL_STATE, int(frame_idx))
483
 
484
 
485
- def propagate_masks(state: AppState, progress=gr.Progress()):
486
  if state is None or state.inference_session is None:
487
  yield "Load a video first."
488
  return
@@ -494,7 +496,6 @@ def propagate_masks(state: AppState, progress=gr.Progress()):
494
  total = max(1, GLOBAL_STATE.num_frames)
495
  processed = 0
496
 
497
- # Initial status for first run visibility
498
  yield f"Propagating masks: {processed}/{total}"
499
 
500
  device_type = "cuda" if GLOBAL_STATE.device == "cuda" else "cpu"
@@ -516,8 +517,6 @@ def propagate_masks(state: AppState, progress=gr.Progress()):
516
  GLOBAL_STATE.composited_frames.pop(frame_idx, None)
517
 
518
  processed += 1
519
- progress((processed, total), f"Propagating masks: {processed}/{total}")
520
- # Stream status updates so users see progress text
521
  yield f"Propagating masks: {processed}/{total}"
522
 
523
  yield f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
@@ -563,9 +562,7 @@ def reset_session() -> tuple[AppState, Image.Image, int, int, str]:
563
  return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status
564
 
565
 
566
- theme = Soft(primary_hue="indigo", secondary_hue="rose", neutral_hue="slate")
567
-
568
- with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme) as demo:
569
  state = gr.State(GLOBAL_STATE)
570
 
571
  gr.Markdown("""
@@ -743,7 +740,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
743
 
744
  render_btn.click(_render_video, inputs=[state], outputs=[playback_video])
745
 
746
- propagate_btn.click(propagate_masks, inputs=[state], outputs=[propagate_status], show_progress=True)
747
 
748
  reset_btn.click(
749
  reset_session,
 
5
  import gradio as gr
6
  import numpy as np
7
  import torch
 
8
  from PIL import Image, ImageDraw
9
 
10
  # Prefer local transformers in the workspace
 
232
  GLOBAL_STATE.inference_session = processor.init_video_session(
233
  video=GLOBAL_STATE.video_frames,
234
  inference_device=device,
235
+ video_storage_device="cpu",
236
+ torch_dtype=dtype,
237
  )
238
  GLOBAL_STATE.session_repo_id = desired_repo
239
 
 
278
  video=frames,
279
  inference_device=device,
280
  video_storage_device="cpu",
281
+ torch_dtype=dtype,
282
  )
283
  GLOBAL_STATE.inference_session = inference_session
284
 
 
484
  return update_frame_display(GLOBAL_STATE, int(frame_idx))
485
 
486
 
487
+ def propagate_masks(state: AppState):
488
  if state is None or state.inference_session is None:
489
  yield "Load a video first."
490
  return
 
496
  total = max(1, GLOBAL_STATE.num_frames)
497
  processed = 0
498
 
 
499
  yield f"Propagating masks: {processed}/{total}"
500
 
501
  device_type = "cuda" if GLOBAL_STATE.device == "cuda" else "cpu"
 
517
  GLOBAL_STATE.composited_frames.pop(frame_idx, None)
518
 
519
  processed += 1
 
 
520
  yield f"Propagating masks: {processed}/{total}"
521
 
522
  yield f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
 
562
  return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status
563
 
564
 
565
+ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme="shivi/calm_seafoam") as demo:
 
 
566
  state = gr.State(GLOBAL_STATE)
567
 
568
  gr.Markdown("""
 
740
 
741
  render_btn.click(_render_video, inputs=[state], outputs=[playback_video])
742
 
743
+ propagate_btn.click(propagate_masks, inputs=[state], outputs=[propagate_status])
744
 
745
  reset_btn.click(
746
  reset_session,