yonigozlan (HF Staff) committed
Commit 9ca5268 · 1 Parent(s): 6d0e912

add debug and load with cv2

Files changed (1):
  1. app.py +30 -66
app.py CHANGED
@@ -3,6 +3,7 @@ import gc
 from copy import deepcopy
 from typing import Optional
 
+import cv2
 import gradio as gr
 import numpy as np
 import spaces
@@ -10,7 +11,6 @@ import torch
 from gradio.themes import Soft
 from PIL import Image, ImageDraw
 
-# Prefer local transformers in the workspace
 from transformers import AutoModel, Sam2VideoProcessor
 
 
@@ -32,56 +32,25 @@ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], di
     """Load video frames as PIL Images using transformers.video_utils if available,
     otherwise fall back to OpenCV. Returns (frames, info).
     """
-    try:
-        from transformers.video_utils import load_video  # type: ignore
-
-        frames, info = load_video(video_path_or_url)
-        # Ensure PIL format
-        pil_frames = []
-        for fr in frames:
-            if isinstance(fr, Image.Image):
-                pil_frames.append(fr.convert("RGB"))
-            else:
-                pil_frames.append(Image.fromarray(fr).convert("RGB"))
-        info = info if info is not None else {}
-        # Ensure fps present when possible (fallback to cv2 probe)
-        if "fps" not in info or not info.get("fps"):
-            try:
-                import cv2  # type: ignore
-
-                cap = cv2.VideoCapture(video_path_or_url)
-                fps_val = cap.get(cv2.CAP_PROP_FPS)
-                cap.release()
-                if fps_val and fps_val > 0:
-                    info["fps"] = float(fps_val)
-            except Exception as e:
-                print(f"Failed to render video with cv2: {e}")
-                pass
-        return pil_frames, info
-    except Exception as e:
-        print(f"Failed to load video with transformers.video_utils: {e}")
-    # Fallback to OpenCV
-    try:
-        import cv2  # type: ignore
-
-        cap = cv2.VideoCapture(video_path_or_url)
-        frames = []
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            frames.append(Image.fromarray(frame_rgb))
-        # Gather fps if available
-        fps_val = cap.get(cv2.CAP_PROP_FPS)
-        cap.release()
-        info = {
-            "num_frames": len(frames),
-            "fps": float(fps_val) if fps_val and fps_val > 0 else None,
-        }
-        return frames, info
-    except Exception as e:
-        raise RuntimeError(f"Failed to load video: {e}")
+
+    cap = cv2.VideoCapture(video_path_or_url)
+    frames = []
+    print("loading video frames")
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frames.append(Image.fromarray(frame_rgb))
+    # Gather fps if available
+    fps_val = cap.get(cv2.CAP_PROP_FPS)
+    cap.release()
+    print("loaded video frames")
+    info = {
+        "num_frames": len(frames),
+        "fps": float(fps_val) if fps_val and fps_val > 0 else None,
+    }
+    return frames, info
 
 
 def overlay_masks_on_frame(
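The rewritten loader drops the transformers.video_utils path entirely: it decodes every frame eagerly with OpenCV, converts BGR to RGB, and reports fps as None when the container does not expose one. A minimal usage sketch of the new return shape (the clip path is a placeholder, not from this repo):

    # Hypothetical caller; try_load_video_frames is the function patched above.
    frames, info = try_load_video_frames("clip.mp4")
    print(f"{info['num_frames']} frames @ {info['fps'] or 'unknown'} fps")
    if frames:
        frames[0].save("first_frame.png")  # frames are RGB PIL Images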
@@ -190,6 +159,7 @@ def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoPr
     model = AutoModel.from_pretrained(desired_repo)
     processor = Sam2VideoProcessor.from_pretrained(desired_repo)
     model.to(device, dtype=dtype)
+    print("model loaded")
 
     GLOBAL_STATE.model = model
     GLOBAL_STATE.processor = processor
@@ -197,14 +167,12 @@ def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoPr
     GLOBAL_STATE.dtype = dtype
     GLOBAL_STATE.model_repo_id = desired_repo
 
-    return model, processor, device, dtype
-
 
 def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
     """Ensure the model/processor match the selected repo and inference_session exists.
     If a video is already loaded, re-initialize the inference session when needed.
     """
-    model, processor, device, dtype = load_model_if_needed(GLOBAL_STATE)
+    load_model_if_needed(GLOBAL_STATE)
     desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
     if GLOBAL_STATE.inference_session is None or GLOBAL_STATE.session_repo_id != desired_repo:
         if GLOBAL_STATE.video_frames:
@@ -214,10 +182,10 @@ def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
             GLOBAL_STATE.boxes_by_frame_obj.clear()
             GLOBAL_STATE.composited_frames.clear()
             GLOBAL_STATE.inference_session = None
-        GLOBAL_STATE.inference_session = processor.init_video_session(
-            inference_device=device,
+        GLOBAL_STATE.inference_session = GLOBAL_STATE.processor.init_video_session(
+            inference_device=GLOBAL_STATE.device,
             video_storage_device="cpu",
-            dtype=dtype,
+            dtype=GLOBAL_STATE.dtype,
         )
         GLOBAL_STATE.session_repo_id = desired_repo
 
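Because load_model_if_needed no longer returns (model, processor, device, dtype), every call site now reads those handles back off GLOBAL_STATE, as above. A sketch of the state fields this pattern relies on (field names inferred from their uses in this diff; the real AppState class is defined elsewhere in app.py and is not shown here):

    # Hypothetical reconstruction, for illustration only.
    from dataclasses import dataclass, field
    from typing import Any, Optional

    @dataclass
    class AppState:
        model: Optional[Any] = None
        processor: Optional[Any] = None
        device: Optional[str] = None
        dtype: Optional[Any] = None
        model_repo_id: Optional[str] = None
        session_repo_id: Optional[str] = None
        inference_session: Optional[Any] = None
        video_frames: list = field(default_factory=list)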
@@ -230,7 +198,7 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
     GLOBAL_STATE.masks_by_frame = {}
     GLOBAL_STATE.color_by_obj = {}
 
-    model, processor, device, dtype = load_model_if_needed(GLOBAL_STATE)
+    load_model_if_needed(GLOBAL_STATE)
 
     # Gradio Video may provide a dict with 'name' or a direct file path
     video_path: Optional[str] = None
@@ -262,10 +230,10 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
     # Try to capture original FPS if provided by loader
     GLOBAL_STATE.video_fps = float(fps_in)
     # Initialize session
-    inference_session = processor.init_video_session(
-        inference_device=device,
+    inference_session = GLOBAL_STATE.processor.init_video_session(
+        inference_device=GLOBAL_STATE.device,
         video_storage_device="cpu",
-        dtype=dtype,
+        dtype=GLOBAL_STATE.dtype,
     )
     GLOBAL_STATE.inference_session = inference_session
 
@@ -273,7 +241,7 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
     max_idx = len(frames) - 1
     status = (
         f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
-        f"Device: {device}, dtype: bfloat16"
+        f"Device: {GLOBAL_STATE.device}, dtype: bfloat16"
     )
     return GLOBAL_STATE, 0, max_idx, first_frame, status
 
@@ -520,8 +488,6 @@ def propagate_masks(GLOBAL_STATE: gr.State):
         # Every 15th frame (or last), move slider to current frame to update preview via slider binding
         if processed % 30 == 0 or processed == total:
             yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
-        # else:
-        #     yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
 
     text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
 
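With the commented-out else branch gone, the generator yields a slider update only every 30th processed frame (or on the final one), so Gradio re-renders the preview at a bounded rate rather than once per frame. The throttling pattern in isolation (a sketch, not the app's exact loop):

    # Emit a progress update every Nth step instead of on every step.
    def propagate(total: int, every: int = 30):
        for processed in range(1, total + 1):
            ...  # run one propagation step here
            if processed % every == 0 or processed == total:
                yield f"Propagating masks: {processed}/{total}"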
@@ -752,8 +718,6 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
         out_path = "/tmp/sam2_playback.mp4"
         # Prefer imageio with PyAV/ffmpeg to respect exact fps
         try:
-            import cv2  # type: ignore
-
             fourcc = cv2.VideoWriter_fourcc(*"mp4v")
             writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
             for fr_bgr in frames_np:
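With cv2 now imported at module top, the local import inside the export path is redundant and removed; the writer logic itself is untouched. For reference, the surrounding pattern writes BGR frames to an mp4 at the source fps (a standalone sketch; frames_np, fps, w, and h are taken from the enclosing scope of the diff above):

    # Standalone sketch of the mp4 export pattern shown in this hunk.
    import cv2

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter("/tmp/sam2_playback.mp4", fourcc, fps, (w, h))
    for fr_bgr in frames_np:  # each frame: BGR uint8 array of shape (h, w, 3)
        writer.write(fr_bgr)
    writer.release()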
 