Mirko Trasciatti commited on
Commit
fb2fd45
·
1 Parent(s): 2feeac4

Add YOLO-driven kick detection and chart

Browse files
Files changed (1) hide show
  1. app.py +507 -22
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import colorsys
2
  import gc
3
  from copy import deepcopy
@@ -161,6 +163,236 @@ def detect_person_box(
161
  return x_min, y_min, x_max, y_max, conf
162
 
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
165
  """Generate a deterministic pastel RGB color for a given object id.
166
 
@@ -305,6 +537,25 @@ class AppState:
305
  self.player_obj_id: int | None = None
306
  self.player_detection_frame: int | None = None
307
  self.player_detection_conf: float | None = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
  def __repr__(self):
310
  return f"AppState(video_frames={self.video_frames}, inference_session={self.inference_session is not None}, model={self.model is not None}, processor={self.processor is not None}, device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, clicks_by_frame_obj={self.clicks_by_frame_obj}, boxes_by_frame_obj={self.boxes_by_frame_obj}, composited_frames={self.composited_frames}, current_frame_idx={self.current_frame_idx}, current_obj_id={self.current_obj_id}, current_label={self.current_label}, current_clear_old={self.current_clear_old}, current_prompt_type={self.current_prompt_type}, pending_box_start={self.pending_box_start}, pending_box_start_frame_idx={self.pending_box_start_frame_idx}, pending_box_start_obj_id={self.pending_box_start_obj_id}, is_switching_model={self.is_switching_model}, model_repo_key={self.model_repo_key}, model_repo_id={self.model_repo_id}, session_repo_id={self.session_repo_id})"
@@ -403,9 +654,43 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
403
  GLOBAL_STATE.impact_debug_speed_kmh = []
404
  GLOBAL_STATE.impact_debug_speed_threshold_px = None
405
  GLOBAL_STATE.impact_meters_per_px = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  GLOBAL_STATE.player_obj_id = None
407
  GLOBAL_STATE.player_detection_frame = None
408
  GLOBAL_STATE.player_detection_conf = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
  load_model_if_needed(GLOBAL_STATE)
411
 
@@ -957,6 +1242,137 @@ def _build_kick_plot(state: AppState):
957
  return fig
958
 
959
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
960
  def _format_impact_status(state: AppState) -> str:
961
  if state is None:
962
  return "Impact frame: not computed"
@@ -1018,15 +1434,10 @@ def _player_has_masks(state: AppState) -> bool:
1018
 
1019
 
1020
  def _button_updates(state: AppState) -> tuple[Any, Any, Any]:
1021
- propagate_main_enabled = _ball_has_masks(state)
1022
- detect_player_enabled = False
1023
- propagate_player_enabled = False
1024
- if isinstance(state, AppState):
1025
- kick_candidate = state.kick_frame or getattr(state, "kick_debug_kick_frame", None)
1026
- if kick_candidate is not None:
1027
- detect_player_enabled = True
1028
- if detect_player_enabled and _player_has_masks(state):
1029
- propagate_player_enabled = True
1030
  return (
1031
  gr.update(interactive=propagate_main_enabled),
1032
  gr.update(interactive=detect_player_enabled),
@@ -1468,6 +1879,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1468
  "Load a video first.",
1469
  gr.update(),
1470
  _build_kick_plot(GLOBAL_STATE),
 
1471
  _format_impact_status(GLOBAL_STATE),
1472
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1473
  propagate_main_update,
@@ -1475,6 +1887,8 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1475
  propagate_player_update,
1476
  )
1477
 
 
 
1478
  processor = deepcopy(GLOBAL_STATE.processor)
1479
  model = deepcopy(GLOBAL_STATE.model)
1480
  inference_session = deepcopy(GLOBAL_STATE.inference_session)
@@ -1483,9 +1897,19 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1483
  inference_session.cache.inference_device = "cuda"
1484
  model.to("cuda")
1485
 
1486
- total = max(1, GLOBAL_STATE.num_frames)
 
 
 
 
 
 
 
 
1487
  processed = 0
1488
 
 
 
1489
  # Initial status; no slider change yet
1490
  propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
1491
  yield (
@@ -1493,6 +1917,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1493
  f"Propagating masks: {processed}/{total}",
1494
  gr.update(),
1495
  _build_kick_plot(GLOBAL_STATE),
 
1496
  _format_impact_status(GLOBAL_STATE),
1497
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1498
  propagate_main_update,
@@ -1500,9 +1925,10 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1500
  propagate_player_update,
1501
  )
1502
 
1503
- last_frame_idx = 0
1504
  with torch.inference_mode():
1505
- for frame_idx, frame in enumerate(GLOBAL_STATE.video_frames):
 
1506
  pixel_values = None
1507
  if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
1508
  pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
@@ -1531,6 +1957,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1531
  f"Propagating masks: {processed}/{total}",
1532
  gr.update(value=frame_idx),
1533
  _build_kick_plot(GLOBAL_STATE),
 
1534
  _format_impact_status(GLOBAL_STATE),
1535
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1536
  propagate_main_update,
@@ -1554,6 +1981,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1554
  text,
1555
  gr.update(value=target_frame),
1556
  _build_kick_plot(GLOBAL_STATE),
 
1557
  _format_impact_status(GLOBAL_STATE),
1558
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1559
  propagate_main_update,
@@ -1646,6 +2074,7 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
1646
  status,
1647
  gr.update(visible=False, value=""),
1648
  _build_kick_plot(GLOBAL_STATE),
 
1649
  _format_impact_status(GLOBAL_STATE),
1650
  propagate_main_update,
1651
  detect_btn_update,
@@ -1893,7 +2322,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1893
  """
1894
  **Working with results**
1895
  - **Preview**: Use the slider to navigate frames and see the current masks.
1896
- - **Propagate**: Click “Propagate across video” to track all defined objects through the entire video. The preview follows progress periodically to keep things responsive.
1897
  - **Export**: Render an MP4 for smooth playback using the original video FPS.
1898
  - **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video).
1899
  """
@@ -1951,7 +2380,8 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1951
  )
1952
  with gr.Row():
1953
  detect_ball_btn = gr.Button("Detect Ball", variant="secondary")
1954
- propagate_btn = gr.Button("Propagate across video", variant="primary", interactive=False)
 
1955
  detect_player_btn = gr.Button("Detect Player", variant="secondary", interactive=False)
1956
  propagate_player_btn = gr.Button("Propagate Player", variant="primary", interactive=False)
1957
  ball_status = gr.Markdown(visible=False)
@@ -1963,6 +2393,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1963
  clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
1964
  prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
1965
  kick_plot = gr.Plot(label="Kick & impact diagnostics", show_label=True)
 
1966
 
1967
  # Wire events
1968
  def _on_video_change(GLOBAL_STATE: gr.State, video):
@@ -1975,6 +2406,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1975
  status,
1976
  gr.update(visible=False, value=""),
1977
  _build_kick_plot(GLOBAL_STATE),
 
1978
  _format_impact_status(GLOBAL_STATE),
1979
  propagate_main_update,
1980
  detect_btn_update,
@@ -1984,7 +2416,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1984
  video_in.change(
1985
  _on_video_change,
1986
  inputs=[GLOBAL_STATE, video_in],
1987
- outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
1988
  show_progress=True,
1989
  )
1990
 
@@ -1997,7 +2429,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1997
  examples=examples_list,
1998
  inputs=[GLOBAL_STATE, video_in],
1999
  fn=_on_video_change,
2000
- outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
2001
  label="Examples",
2002
  cache_examples=False,
2003
  examples_per_page=5,
@@ -2187,6 +2619,43 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2187
  outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn],
2188
  )
2189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2190
  def _auto_detect_player(state_in: AppState):
2191
  if state_in is None or state_in.num_frames == 0:
2192
  raise gr.Error("Load a video first, then try auto-detect.")
@@ -2303,6 +2772,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2303
  "Load a video first.",
2304
  gr.update(),
2305
  _build_kick_plot(GLOBAL_STATE),
 
2306
  _format_impact_status(GLOBAL_STATE),
2307
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2308
  propagate_main_update,
@@ -2316,6 +2786,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2316
  "Detect the player before propagating.",
2317
  gr.update(),
2318
  _build_kick_plot(GLOBAL_STATE),
 
2319
  _format_impact_status(GLOBAL_STATE),
2320
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2321
  propagate_main_update,
@@ -2330,8 +2801,17 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2330
  inference_session.cache.inference_device = "cuda"
2331
  model.to("cuda")
2332
 
2333
- total = max(1, GLOBAL_STATE.num_frames)
 
 
 
 
 
 
 
 
2334
  processed = 0
 
2335
 
2336
  propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
2337
  yield (
@@ -2339,6 +2819,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2339
  f"Propagating player: {processed}/{total}",
2340
  gr.update(),
2341
  _build_kick_plot(GLOBAL_STATE),
 
2342
  _format_impact_status(GLOBAL_STATE),
2343
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2344
  propagate_main_update,
@@ -2349,7 +2830,8 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2349
  player_id = GLOBAL_STATE.player_obj_id or PLAYER_OBJECT_ID
2350
 
2351
  with torch.inference_mode():
2352
- for frame_idx, frame in enumerate(GLOBAL_STATE.video_frames):
 
2353
  pixel_values = None
2354
  if (
2355
  inference_session.processed_frames is None
@@ -2375,6 +2857,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2375
  GLOBAL_STATE.composited_frames.pop(frame_idx, None)
2376
 
2377
  processed += 1
 
2378
  if processed % 30 == 0 or processed == total:
2379
  propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
2380
  yield (
@@ -2382,6 +2865,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2382
  f"Propagating player: {processed}/{total}",
2383
  gr.update(value=frame_idx),
2384
  _build_kick_plot(GLOBAL_STATE),
 
2385
  _format_impact_status(GLOBAL_STATE),
2386
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2387
  propagate_main_update,
@@ -2394,7 +2878,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2394
  if target_frame is None:
2395
  target_frame = GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None)
2396
  if target_frame is None:
2397
- target_frame = max(0, processed - 1)
2398
  target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
2399
  GLOBAL_STATE.current_frame_idx = target_frame
2400
 
@@ -2404,6 +2888,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2404
  text,
2405
  gr.update(value=target_frame),
2406
  _build_kick_plot(GLOBAL_STATE),
 
2407
  _format_impact_status(GLOBAL_STATE),
2408
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2409
  propagate_main_update,
@@ -2414,7 +2899,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2414
  propagate_player_btn.click(
2415
  propagate_player_masks,
2416
  inputs=[GLOBAL_STATE],
2417
- outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
2418
  )
2419
 
2420
  # Image click to add a point and run forward on that frame
@@ -2483,13 +2968,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2483
  propagate_btn.click(
2484
  propagate_masks,
2485
  inputs=[GLOBAL_STATE],
2486
- outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
2487
  )
2488
 
2489
  reset_btn.click(
2490
  reset_session,
2491
  inputs=GLOBAL_STATE,
2492
- outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
2493
  )
2494
 
2495
  # ============================================================================
 
1
+ from __future__ import annotations
2
+
3
  import colorsys
4
  import gc
5
  from copy import deepcopy
 
163
  return x_min, y_min, x_max, y_max, conf
164
 
165
 
166
+ def _compute_sam_window_from_kick(state: AppState, kick_frame: int | None) -> tuple[int, int]:
167
+ total_frames = state.num_frames
168
+ if total_frames == 0:
169
+ return 0, 0
170
+ fps = state.video_fps if state.video_fps and state.video_fps > 0 else 25.0
171
+ target_window_frames = max(1, int(round(fps * 4.0)))
172
+ half_window = target_window_frames // 2
173
+ if kick_frame is None:
174
+ start_idx = 0
175
+ else:
176
+ start_idx = max(0, int(kick_frame) - half_window)
177
+ end_idx = min(total_frames, start_idx + target_window_frames)
178
+ if end_idx <= start_idx:
179
+ end_idx = min(total_frames, start_idx + 1)
180
+ state.sam_window = (start_idx, end_idx)
181
+ return start_idx, end_idx
182
+
183
+
184
+ def _perform_yolo_ball_tracking(state: AppState, progress: gr.Progress | None = None) -> None:
185
+ if state is None or state.num_frames == 0:
186
+ raise gr.Error("Load a video first, then track with YOLO.")
187
+
188
+ model = get_yolo_model()
189
+ class_ids = [
190
+ idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
191
+ ]
192
+ if not class_ids:
193
+ raise gr.Error("YOLO model does not contain the sports ball class.")
194
+
195
+ frames = state.video_frames
196
+ total = len(frames)
197
+ centers: dict[int, tuple[float, float]] = {}
198
+ boxes: dict[int, tuple[int, int, int, int]] = {}
199
+ confs: dict[int, float] = {}
200
+ areas: dict[int, float] = {}
201
+ first_detection_frame: int | None = None
202
+
203
+ for idx, frame in enumerate(frames):
204
+ if progress is not None:
205
+ progress((idx + 1) / total)
206
+
207
+ results = model.predict(
208
+ source=frame,
209
+ conf=YOLO_CONF_THRESHOLD,
210
+ iou=YOLO_IOU_THRESHOLD,
211
+ max_det=1,
212
+ classes=class_ids,
213
+ imgsz=640,
214
+ device="cpu",
215
+ verbose=False,
216
+ )
217
+ if not results:
218
+ continue
219
+ boxes_result = results[0].boxes
220
+ if boxes_result is None or len(boxes_result) == 0:
221
+ continue
222
+
223
+ box = boxes_result[0]
224
+ xywh = box.xywh[0].cpu().tolist()
225
+ conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0
226
+ x_center, y_center, width, height = xywh
227
+ x_center = float(x_center)
228
+ y_center = float(y_center)
229
+ width = max(1.0, float(width))
230
+ height = max(1.0, float(height))
231
+
232
+ frame_width, frame_height = frame.size
233
+ x_min = int(round(max(0.0, x_center - width / 2.0)))
234
+ y_min = int(round(max(0.0, y_center - height / 2.0)))
235
+ x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0)))
236
+ y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0)))
237
+ if x_max <= x_min or y_max <= y_min:
238
+ continue
239
+
240
+ centers[idx] = (x_center, y_center)
241
+ boxes[idx] = (x_min, y_min, x_max, y_max)
242
+ confs[idx] = conf
243
+ areas[idx] = float((x_max - x_min) * (y_max - y_min))
244
+ if first_detection_frame is None:
245
+ first_detection_frame = idx
246
+
247
+ state.yolo_ball_centers = centers
248
+ state.yolo_ball_boxes = boxes
249
+ state.yolo_ball_conf = confs
250
+ state.yolo_mask_area_proxy = [areas.get(k, 0.0) for k in sorted(centers.keys())]
251
+ state.yolo_initial_frame = first_detection_frame
252
+
253
+ if len(centers) < 3:
254
+ state.yolo_smoothed_centers = {}
255
+ state.yolo_speeds = {}
256
+ state.yolo_distance_from_start = {}
257
+ state.yolo_threshold = None
258
+ state.yolo_baseline_speed = None
259
+ state.yolo_speed_std = None
260
+ state.yolo_kick_frame = None
261
+ state.yolo_status = "❌ YOLO13: insufficient detections to estimate kick. Please retry or annotate manually."
262
+ state.sam_window = None
263
+ return
264
+
265
+ items = sorted(centers.items())
266
+ dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
267
+ alpha = 0.35
268
+
269
+ smoothed: dict[int, tuple[float, float]] = {}
270
+ speeds: dict[int, float] = {}
271
+
272
+ prev_frame = None
273
+ prev_smooth = None
274
+ for frame_idx, (cx, cy) in items:
275
+ if prev_smooth is None:
276
+ smooth_x, smooth_y = float(cx), float(cy)
277
+ else:
278
+ smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
279
+ smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
280
+ smoothed[frame_idx] = (smooth_x, smooth_y)
281
+ if prev_smooth is None or prev_frame is None:
282
+ speeds[frame_idx] = 0.0
283
+ else:
284
+ frame_delta = max(1, frame_idx - prev_frame)
285
+ time_delta = frame_delta * dt
286
+ dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
287
+ speed = dist / time_delta if time_delta > 0 else dist
288
+ speeds[frame_idx] = speed
289
+ prev_smooth = (smooth_x, smooth_y)
290
+ prev_frame = frame_idx
291
+
292
+ frames_ordered = [frame_idx for frame_idx, _ in items]
293
+ speed_series = [speeds.get(f, 0.0) for f in frames_ordered]
294
+
295
+ baseline_window = min(10, len(frames_ordered) // 3 or 1)
296
+ baseline_speeds = speed_series[:baseline_window]
297
+ baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
298
+ speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0
299
+ base_threshold = baseline_speed + 4.0 * speed_std
300
+ if base_threshold < baseline_speed * 3.0:
301
+ base_threshold = baseline_speed * 3.0
302
+ speed_threshold = max(base_threshold, 15.0)
303
+
304
+ distance_dict: dict[int, float] = {}
305
+ if smoothed:
306
+ first_frame = frames_ordered[0]
307
+ origin = smoothed[first_frame]
308
+ for frame_idx, (sx, sy) in smoothed.items():
309
+ distance_dict[frame_idx] = math.hypot(sx - origin[0], sy - origin[1])
310
+
311
+ areas_dict = {idx: areas.get(idx, 0.0) for idx in frames_ordered}
312
+ initial_area = areas_dict.get(frames_ordered[0], 1.0) or 1.0
313
+ radius_estimate = math.sqrt(initial_area / math.pi)
314
+ adaptive_return_distance = max(8.0, min(radius_estimate * 1.5, 40.0))
315
+
316
+ sustain_frames = 3
317
+ holdout_frames = 8
318
+ area_window = 4
319
+ area_drop_ratio = 0.75
320
+
321
+ kalman_pos, kalman_speed, _ = _run_kalman_filter(items, dt)
322
+ kalman_speed_series = [kalman_speed.get(f, 0.0) for f in frames_ordered]
323
+
324
+ kick_frame: int | None = None
325
+ for idx, frame in enumerate(frames_ordered[baseline_window:], start=baseline_window):
326
+ speed = speed_series[idx]
327
+ if speed < speed_threshold:
328
+ continue
329
+ sustain_ok = True
330
+ for j in range(1, sustain_frames + 1):
331
+ if idx + j >= len(frames_ordered):
332
+ break
333
+ if speed_series[idx + j] < speed_threshold * 0.7:
334
+ sustain_ok = False
335
+ break
336
+ if not sustain_ok:
337
+ continue
338
+
339
+ area_pass = True
340
+ current_area = areas_dict.get(frame)
341
+ if current_area:
342
+ prev_areas = [
343
+ areas_dict.get(f)
344
+ for f in frames_ordered[max(0, idx - area_window):idx]
345
+ if areas_dict.get(f) is not None
346
+ ]
347
+ if prev_areas:
348
+ median_prev = statistics.median(prev_areas)
349
+ if median_prev > 0:
350
+ ratio = current_area / median_prev
351
+ if ratio > area_drop_ratio:
352
+ area_pass = False
353
+ if not area_pass and speed < speed_threshold * 1.2:
354
+ continue
355
+
356
+ future_slice = frames_ordered[idx: min(len(frames_ordered), idx + holdout_frames)]
357
+ max_future_dist = 0.0
358
+ for future_frame in future_slice:
359
+ dist = distance_dict.get(future_frame, 0.0)
360
+ if dist > max_future_dist:
361
+ max_future_dist = dist
362
+ if max_future_dist < adaptive_return_distance:
363
+ continue
364
+
365
+ kick_frame = frame
366
+ break
367
+
368
+ state.yolo_smoothed_centers = smoothed
369
+ state.yolo_speeds = speeds
370
+ state.yolo_distance_from_start = distance_dict
371
+ state.yolo_threshold = speed_threshold
372
+ state.yolo_baseline_speed = baseline_speed
373
+ state.yolo_speed_std = speed_std
374
+ state.yolo_kick_frames = frames_ordered
375
+ state.yolo_kick_speeds = speed_series
376
+ state.yolo_kick_distance = [distance_dict.get(f, 0.0) for f in frames_ordered]
377
+ state.yolo_mask_area_proxy = [areas_dict.get(f, 0.0) for f in frames_ordered]
378
+ state.yolo_kick_frame = kick_frame
379
+ coverage = len(centers) / total if total else 0.0
380
+ if kick_frame is not None:
381
+ state.yolo_status = f"✅ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%})."
382
+ else:
383
+ state.yolo_status = (
384
+ f"⚠️ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%}) but did not find a definitive kick."
385
+ )
386
+ state.kalman_centers[BALL_OBJECT_ID] = kalman_pos
387
+ state.kalman_speeds[BALL_OBJECT_ID] = kalman_speed
388
+
389
+ if kick_frame is not None:
390
+ state.kick_frame = kick_frame
391
+ _compute_sam_window_from_kick(state, kick_frame)
392
+ else:
393
+ state.sam_window = None
394
+
395
+
396
  def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
397
  """Generate a deterministic pastel RGB color for a given object id.
398
 
 
537
  self.player_obj_id: int | None = None
538
  self.player_detection_frame: int | None = None
539
  self.player_detection_conf: float | None = None
540
+ # YOLO tracking caches
541
+ self.yolo_ball_centers: dict[int, tuple[float, float]] = {}
542
+ self.yolo_ball_boxes: dict[int, tuple[int, int, int, int]] = {}
543
+ self.yolo_ball_conf: dict[int, float] = {}
544
+ self.yolo_smoothed_centers: dict[int, tuple[float, float]] = {}
545
+ self.yolo_speeds: dict[int, float] = {}
546
+ self.yolo_distance_from_start: dict[int, float] = {}
547
+ self.yolo_threshold: float | None = None
548
+ self.yolo_baseline_speed: float | None = None
549
+ self.yolo_speed_std: float | None = None
550
+ self.yolo_kick_frame: int | None = None
551
+ self.yolo_status: str = ""
552
+ self.yolo_kick_frames: list[int] = []
553
+ self.yolo_kick_speeds: list[float] = []
554
+ self.yolo_kick_distance: list[float] = []
555
+ self.yolo_mask_area_proxy: list[float] = []
556
+ self.yolo_initial_frame: int | None = None
557
+ # SAM window (start_idx inclusive, end_idx exclusive)
558
+ self.sam_window: tuple[int, int] | None = None
559
 
560
  def __repr__(self):
561
  return f"AppState(video_frames={self.video_frames}, inference_session={self.inference_session is not None}, model={self.model is not None}, processor={self.processor is not None}, device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, clicks_by_frame_obj={self.clicks_by_frame_obj}, boxes_by_frame_obj={self.boxes_by_frame_obj}, composited_frames={self.composited_frames}, current_frame_idx={self.current_frame_idx}, current_obj_id={self.current_obj_id}, current_label={self.current_label}, current_clear_old={self.current_clear_old}, current_prompt_type={self.current_prompt_type}, pending_box_start={self.pending_box_start}, pending_box_start_frame_idx={self.pending_box_start_frame_idx}, pending_box_start_obj_id={self.pending_box_start_obj_id}, is_switching_model={self.is_switching_model}, model_repo_key={self.model_repo_key}, model_repo_id={self.model_repo_id}, session_repo_id={self.session_repo_id})"
 
654
  GLOBAL_STATE.impact_debug_speed_kmh = []
655
  GLOBAL_STATE.impact_debug_speed_threshold_px = None
656
  GLOBAL_STATE.impact_meters_per_px = None
657
+ GLOBAL_STATE.yolo_ball_centers = {}
658
+ GLOBAL_STATE.yolo_ball_boxes = {}
659
+ GLOBAL_STATE.yolo_ball_conf = {}
660
+ GLOBAL_STATE.yolo_smoothed_centers = {}
661
+ GLOBAL_STATE.yolo_speeds = {}
662
+ GLOBAL_STATE.yolo_distance_from_start = {}
663
+ GLOBAL_STATE.yolo_threshold = None
664
+ GLOBAL_STATE.yolo_baseline_speed = None
665
+ GLOBAL_STATE.yolo_speed_std = None
666
+ GLOBAL_STATE.yolo_kick_frame = None
667
+ GLOBAL_STATE.yolo_status = ""
668
+ GLOBAL_STATE.yolo_kick_frames = []
669
+ GLOBAL_STATE.yolo_kick_speeds = []
670
+ GLOBAL_STATE.yolo_kick_distance = []
671
+ GLOBAL_STATE.yolo_mask_area_proxy = []
672
+ GLOBAL_STATE.yolo_initial_frame = None
673
+ GLOBAL_STATE.sam_window = None
674
  GLOBAL_STATE.player_obj_id = None
675
  GLOBAL_STATE.player_detection_frame = None
676
  GLOBAL_STATE.player_detection_conf = None
677
+ GLOBAL_STATE.yolo_ball_centers = {}
678
+ GLOBAL_STATE.yolo_ball_boxes = {}
679
+ GLOBAL_STATE.yolo_ball_conf = {}
680
+ GLOBAL_STATE.yolo_smoothed_centers = {}
681
+ GLOBAL_STATE.yolo_speeds = {}
682
+ GLOBAL_STATE.yolo_distance_from_start = {}
683
+ GLOBAL_STATE.yolo_threshold = None
684
+ GLOBAL_STATE.yolo_baseline_speed = None
685
+ GLOBAL_STATE.yolo_speed_std = None
686
+ GLOBAL_STATE.yolo_kick_frame = None
687
+ GLOBAL_STATE.yolo_status = ""
688
+ GLOBAL_STATE.yolo_kick_frames = []
689
+ GLOBAL_STATE.yolo_kick_speeds = []
690
+ GLOBAL_STATE.yolo_kick_distance = []
691
+ GLOBAL_STATE.yolo_mask_area_proxy = []
692
+ GLOBAL_STATE.yolo_initial_frame = None
693
+ GLOBAL_STATE.sam_window = None
694
 
695
  load_model_if_needed(GLOBAL_STATE)
696
 
 
1242
  return fig
1243
 
1244
 
1245
+ def _ensure_ball_prompt_from_yolo(state: AppState):
1246
+ if (
1247
+ state is None
1248
+ or state.inference_session is None
1249
+ or not state.yolo_ball_centers
1250
+ ):
1251
+ return
1252
+ # Check if we already have clicks for the ball
1253
+ for frame_clicks in state.clicks_by_frame_obj.values():
1254
+ if frame_clicks.get(BALL_OBJECT_ID):
1255
+ return
1256
+ anchor_frame = state.yolo_initial_frame
1257
+ if anchor_frame is None and state.yolo_ball_centers:
1258
+ anchor_frame = min(state.yolo_ball_centers.keys())
1259
+ if anchor_frame is None or anchor_frame >= state.num_frames:
1260
+ return
1261
+ center = state.yolo_ball_centers.get(anchor_frame)
1262
+ if center is None:
1263
+ return
1264
+ x_center, y_center = center
1265
+ frame_width, frame_height = state.video_frames[anchor_frame].size
1266
+ x_center = int(np.clip(round(x_center), 0, frame_width - 1))
1267
+ y_center = int(np.clip(round(y_center), 0, frame_height - 1))
1268
+ event = SimpleNamespace(
1269
+ index=(x_center, y_center),
1270
+ value={"x": x_center, "y": y_center},
1271
+ )
1272
+ state.current_obj_id = BALL_OBJECT_ID
1273
+ state.current_label = "positive"
1274
+ state.current_frame_idx = anchor_frame
1275
+ on_image_click(
1276
+ update_frame_display(state, anchor_frame),
1277
+ state,
1278
+ anchor_frame,
1279
+ BALL_OBJECT_ID,
1280
+ "positive",
1281
+ False,
1282
+ event,
1283
+ )
1284
+
1285
+
1286
+ def _build_yolo_plot(state: AppState):
1287
+ fig = go.Figure()
1288
+ if state is None or not state.yolo_kick_frames or not state.yolo_kick_speeds:
1289
+ fig.update_layout(
1290
+ title="YOLO kick diagnostics",
1291
+ xaxis_title="Frame",
1292
+ yaxis_title="Speed (px/s)",
1293
+ )
1294
+ return fig
1295
+
1296
+ frames = state.yolo_kick_frames
1297
+ speeds = state.yolo_kick_speeds
1298
+ distance = state.yolo_kick_distance if state.yolo_kick_distance else [0.0] * len(frames)
1299
+ areas = state.yolo_mask_area_proxy if state.yolo_mask_area_proxy else [0.0] * len(frames)
1300
+ threshold = state.yolo_threshold or 0.0
1301
+ baseline = state.yolo_baseline_speed or 0.0
1302
+ kick_frame = state.yolo_kick_frame
1303
+
1304
+ fig.add_trace(
1305
+ go.Scatter(
1306
+ x=frames,
1307
+ y=speeds,
1308
+ mode="lines+markers",
1309
+ name="YOLO speed",
1310
+ line=dict(color="#4caf50"),
1311
+ )
1312
+ )
1313
+ fig.add_trace(
1314
+ go.Scatter(
1315
+ x=frames,
1316
+ y=[threshold] * len(frames),
1317
+ mode="lines",
1318
+ name="Adaptive threshold",
1319
+ line=dict(color="#ff9800", dash="dash"),
1320
+ )
1321
+ )
1322
+ fig.add_trace(
1323
+ go.Scatter(
1324
+ x=frames,
1325
+ y=[baseline] * len(frames),
1326
+ mode="lines",
1327
+ name="Baseline speed",
1328
+ line=dict(color="#9e9e9e", dash="dot"),
1329
+ )
1330
+ )
1331
+ fig.add_trace(
1332
+ go.Scatter(
1333
+ x=frames,
1334
+ y=distance,
1335
+ mode="lines",
1336
+ name="Distance from start",
1337
+ line=dict(color="#03a9f4"),
1338
+ yaxis="y2",
1339
+ )
1340
+ )
1341
+ fig.add_trace(
1342
+ go.Scatter(
1343
+ x=frames,
1344
+ y=areas,
1345
+ mode="lines",
1346
+ name="Box area proxy",
1347
+ line=dict(color="#ab47bc", dash="dot"),
1348
+ yaxis="y2",
1349
+ )
1350
+ )
1351
+
1352
+ if kick_frame is not None:
1353
+ fig.add_vline(
1354
+ x=kick_frame,
1355
+ line=dict(color="#e91e63", width=2),
1356
+ annotation_text=f"Kick {kick_frame}",
1357
+ annotation_position="top right",
1358
+ )
1359
+
1360
+ fig.update_layout(
1361
+ title="YOLO kick diagnostics",
1362
+ xaxis=dict(title="Frame"),
1363
+ yaxis=dict(title="Speed (px/s)"),
1364
+ yaxis2=dict(
1365
+ title="Distance / Area",
1366
+ overlaying="y",
1367
+ side="right",
1368
+ showgrid=False,
1369
+ ),
1370
+ legend=dict(orientation="h"),
1371
+ margin=dict(t=40, l=40, r=40, b=40),
1372
+ )
1373
+ return fig
1374
+
1375
+
1376
  def _format_impact_status(state: AppState) -> str:
1377
  if state is None:
1378
  return "Impact frame: not computed"
 
1434
 
1435
 
1436
  def _button_updates(state: AppState) -> tuple[Any, Any, Any]:
1437
+ yolo_ready = isinstance(state, AppState) and state.yolo_kick_frame is not None
1438
+ propagate_main_enabled = _ball_has_masks(state) or yolo_ready
1439
+ detect_player_enabled = yolo_ready
1440
+ propagate_player_enabled = _player_has_masks(state)
 
 
 
 
 
1441
  return (
1442
  gr.update(interactive=propagate_main_enabled),
1443
  gr.update(interactive=detect_player_enabled),
 
1879
  "Load a video first.",
1880
  gr.update(),
1881
  _build_kick_plot(GLOBAL_STATE),
1882
+ _build_yolo_plot(GLOBAL_STATE),
1883
  _format_impact_status(GLOBAL_STATE),
1884
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1885
  propagate_main_update,
 
1887
  propagate_player_update,
1888
  )
1889
 
1890
+ _ensure_ball_prompt_from_yolo(GLOBAL_STATE)
1891
+
1892
  processor = deepcopy(GLOBAL_STATE.processor)
1893
  model = deepcopy(GLOBAL_STATE.model)
1894
  inference_session = deepcopy(GLOBAL_STATE.inference_session)
 
1897
  inference_session.cache.inference_device = "cuda"
1898
  model.to("cuda")
1899
 
1900
+ if not GLOBAL_STATE.sam_window:
1901
+ _compute_sam_window_from_kick(
1902
+ GLOBAL_STATE,
1903
+ GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None),
1904
+ )
1905
+ start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames)
1906
+ start_idx = max(0, int(start_idx))
1907
+ end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx)))
1908
+ total = max(1, end_idx - start_idx)
1909
  processed = 0
1910
 
1911
+ _ensure_ball_prompt_from_yolo(GLOBAL_STATE)
1912
+
1913
  # Initial status; no slider change yet
1914
  propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
1915
  yield (
 
1917
  f"Propagating masks: {processed}/{total}",
1918
  gr.update(),
1919
  _build_kick_plot(GLOBAL_STATE),
1920
+ _build_yolo_plot(GLOBAL_STATE),
1921
  _format_impact_status(GLOBAL_STATE),
1922
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1923
  propagate_main_update,
 
1925
  propagate_player_update,
1926
  )
1927
 
1928
+ last_frame_idx = start_idx
1929
  with torch.inference_mode():
1930
+ for frame_idx in range(start_idx, end_idx):
1931
+ frame = GLOBAL_STATE.video_frames[frame_idx]
1932
  pixel_values = None
1933
  if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
1934
  pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
 
1957
  f"Propagating masks: {processed}/{total}",
1958
  gr.update(value=frame_idx),
1959
  _build_kick_plot(GLOBAL_STATE),
1960
+ _build_yolo_plot(GLOBAL_STATE),
1961
  _format_impact_status(GLOBAL_STATE),
1962
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1963
  propagate_main_update,
 
1981
  text,
1982
  gr.update(value=target_frame),
1983
  _build_kick_plot(GLOBAL_STATE),
1984
+ _build_yolo_plot(GLOBAL_STATE),
1985
  _format_impact_status(GLOBAL_STATE),
1986
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1987
  propagate_main_update,
 
2074
  status,
2075
  gr.update(visible=False, value=""),
2076
  _build_kick_plot(GLOBAL_STATE),
2077
+ _build_yolo_plot(GLOBAL_STATE),
2078
  _format_impact_status(GLOBAL_STATE),
2079
  propagate_main_update,
2080
  detect_btn_update,
 
2322
  """
2323
  **Working with results**
2324
  - **Preview**: Use the slider to navigate frames and see the current masks.
2325
+ - **Track**: Click “Track ball (SAM2)” to track all defined objects across the selected window. The preview follows progress periodically to keep things responsive.
2326
  - **Export**: Render an MP4 for smooth playback using the original video FPS.
2327
  - **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video).
2328
  """
 
2380
  )
2381
  with gr.Row():
2382
  detect_ball_btn = gr.Button("Detect Ball", variant="secondary")
2383
+ track_ball_yolo_btn = gr.Button("Track ball (YOLO13)", variant="secondary")
2384
+ propagate_btn = gr.Button("Track ball (SAM2)", variant="primary", interactive=False)
2385
  detect_player_btn = gr.Button("Detect Player", variant="secondary", interactive=False)
2386
  propagate_player_btn = gr.Button("Propagate Player", variant="primary", interactive=False)
2387
  ball_status = gr.Markdown(visible=False)
 
2393
  clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
2394
  prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
2395
  kick_plot = gr.Plot(label="Kick & impact diagnostics", show_label=True)
2396
+ yolo_plot = gr.Plot(label="YOLO kick diagnostics", show_label=True)
2397
 
2398
  # Wire events
2399
  def _on_video_change(GLOBAL_STATE: gr.State, video):
 
2406
  status,
2407
  gr.update(visible=False, value=""),
2408
  _build_kick_plot(GLOBAL_STATE),
2409
+ _build_yolo_plot(GLOBAL_STATE),
2410
  _format_impact_status(GLOBAL_STATE),
2411
  propagate_main_update,
2412
  detect_btn_update,
 
2416
  video_in.change(
2417
  _on_video_change,
2418
  inputs=[GLOBAL_STATE, video_in],
2419
+ outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
2420
  show_progress=True,
2421
  )
2422
 
 
2429
  examples=examples_list,
2430
  inputs=[GLOBAL_STATE, video_in],
2431
  fn=_on_video_change,
2432
+ outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
2433
  label="Examples",
2434
  cache_examples=False,
2435
  examples_per_page=5,
 
2619
  outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn],
2620
  )
2621
 
2622
+ def _track_ball_yolo(state_in: AppState):
2623
+ if state_in is None or state_in.num_frames == 0:
2624
+ raise gr.Error("Load a video first, then track the ball with YOLO.")
2625
+ progress = gr.Progress(track_tqdm=False)
2626
+ _perform_yolo_ball_tracking(state_in, progress=progress)
2627
+ target_frame = (
2628
+ state_in.yolo_kick_frame
2629
+ if state_in.yolo_kick_frame is not None
2630
+ else state_in.yolo_initial_frame
2631
+ if state_in.yolo_initial_frame is not None
2632
+ else 0
2633
+ )
2634
+ if state_in.num_frames:
2635
+ target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
2636
+ state_in.current_frame_idx = target_frame
2637
+ preview_img = update_frame_display(state_in, target_frame)
2638
+ base_msg = state_in.yolo_status or ""
2639
+ kick_msg = _format_kick_status(state_in)
2640
+ status_text = f"{base_msg} | {kick_msg}" if base_msg else kick_msg
2641
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
2642
+ return (
2643
+ preview_img,
2644
+ gr.update(value=status_text, visible=True),
2645
+ gr.update(value=target_frame),
2646
+ _build_kick_plot(state_in),
2647
+ _build_yolo_plot(state_in),
2648
+ propagate_main_update,
2649
+ detect_btn_update,
2650
+ propagate_player_update,
2651
+ )
2652
+
2653
+ track_ball_yolo_btn.click(
2654
+ _track_ball_yolo,
2655
+ inputs=[GLOBAL_STATE],
2656
+ outputs=[preview, ball_status, frame_slider, kick_plot, yolo_plot, propagate_btn, detect_player_btn, propagate_player_btn],
2657
+ )
2658
+
2659
  def _auto_detect_player(state_in: AppState):
2660
  if state_in is None or state_in.num_frames == 0:
2661
  raise gr.Error("Load a video first, then try auto-detect.")
 
2772
  "Load a video first.",
2773
  gr.update(),
2774
  _build_kick_plot(GLOBAL_STATE),
2775
+ _build_yolo_plot(GLOBAL_STATE),
2776
  _format_impact_status(GLOBAL_STATE),
2777
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2778
  propagate_main_update,
 
2786
  "Detect the player before propagating.",
2787
  gr.update(),
2788
  _build_kick_plot(GLOBAL_STATE),
2789
+ _build_yolo_plot(GLOBAL_STATE),
2790
  _format_impact_status(GLOBAL_STATE),
2791
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2792
  propagate_main_update,
 
2801
  inference_session.cache.inference_device = "cuda"
2802
  model.to("cuda")
2803
 
2804
+ if not GLOBAL_STATE.sam_window:
2805
+ _compute_sam_window_from_kick(
2806
+ GLOBAL_STATE,
2807
+ GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None),
2808
+ )
2809
+ start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames)
2810
+ start_idx = max(0, int(start_idx))
2811
+ end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx)))
2812
+ total = max(1, end_idx - start_idx)
2813
  processed = 0
2814
+ last_frame_idx = start_idx
2815
 
2816
  propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
2817
  yield (
 
2819
  f"Propagating player: {processed}/{total}",
2820
  gr.update(),
2821
  _build_kick_plot(GLOBAL_STATE),
2822
+ _build_yolo_plot(GLOBAL_STATE),
2823
  _format_impact_status(GLOBAL_STATE),
2824
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2825
  propagate_main_update,
 
2830
  player_id = GLOBAL_STATE.player_obj_id or PLAYER_OBJECT_ID
2831
 
2832
  with torch.inference_mode():
2833
+ for frame_idx in range(start_idx, end_idx):
2834
+ frame = GLOBAL_STATE.video_frames[frame_idx]
2835
  pixel_values = None
2836
  if (
2837
  inference_session.processed_frames is None
 
2857
  GLOBAL_STATE.composited_frames.pop(frame_idx, None)
2858
 
2859
  processed += 1
2860
+ last_frame_idx = frame_idx
2861
  if processed % 30 == 0 or processed == total:
2862
  propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
2863
  yield (
 
2865
  f"Propagating player: {processed}/{total}",
2866
  gr.update(value=frame_idx),
2867
  _build_kick_plot(GLOBAL_STATE),
2868
+ _build_yolo_plot(GLOBAL_STATE),
2869
  _format_impact_status(GLOBAL_STATE),
2870
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2871
  propagate_main_update,
 
2878
  if target_frame is None:
2879
  target_frame = GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None)
2880
  if target_frame is None:
2881
+ target_frame = last_frame_idx
2882
  target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
2883
  GLOBAL_STATE.current_frame_idx = target_frame
2884
 
 
2888
  text,
2889
  gr.update(value=target_frame),
2890
  _build_kick_plot(GLOBAL_STATE),
2891
+ _build_yolo_plot(GLOBAL_STATE),
2892
  _format_impact_status(GLOBAL_STATE),
2893
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2894
  propagate_main_update,
 
2899
  propagate_player_btn.click(
2900
  propagate_player_masks,
2901
  inputs=[GLOBAL_STATE],
2902
+ outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, yolo_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
2903
  )
2904
 
2905
  # Image click to add a point and run forward on that frame
 
2968
  propagate_btn.click(
2969
  propagate_masks,
2970
  inputs=[GLOBAL_STATE],
2971
+ outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, yolo_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
2972
  )
2973
 
2974
  reset_btn.click(
2975
  reset_session,
2976
  inputs=GLOBAL_STATE,
2977
+ outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
2978
  )
2979
 
2980
  # ============================================================================