Mirko Trasciatti commited on
Commit
778b9c7
·
1 Parent(s): 697126c

Add kick detection from SAM2 trajectories

Browse files
Files changed (1) hide show
  1. app.py +136 -0
app.py CHANGED
@@ -2,6 +2,8 @@ import colorsys
2
  import gc
3
  from copy import deepcopy
4
  import base64
 
 
5
  from pathlib import Path
6
  BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
7
  EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")
@@ -213,6 +215,10 @@ class AppState:
213
  self.pending_box_start_obj_id: int | None = None
214
  self.is_switching_model: bool = False
215
  self.ball_centers: dict[int, dict[int, tuple[int, int]]] = {}
 
 
 
 
216
  # Model selection
217
  self.model_repo_key: str = "tiny"
218
  self.model_repo_id: str | None = None
@@ -288,6 +294,10 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
288
  GLOBAL_STATE.masks_by_frame = {}
289
  GLOBAL_STATE.color_by_obj = {}
290
  GLOBAL_STATE.ball_centers = {}
 
 
 
 
291
 
292
  load_model_if_needed(GLOBAL_STATE)
293
 
@@ -499,10 +509,129 @@ def _update_centroids_for_frame(state: AppState, frame_idx: int):
499
  centers.pop(int(frame_idx), None)
500
  seen_obj_ids.add(int(obj_id))
501
  _ensure_color_for_obj(state, int(obj_id))
 
 
 
 
 
 
 
502
  # Remove frames for objects without masks at this frame
503
  for obj_id, centers in state.ball_centers.items():
504
  if obj_id not in seen_obj_ids:
505
  centers.pop(int(frame_idx), None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
 
507
 
508
  def on_image_click(
@@ -708,6 +837,11 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
708
  GLOBAL_STATE.pending_box_start_frame_idx = None
709
  GLOBAL_STATE.pending_box_start_obj_id = None
710
  GLOBAL_STATE.ball_centers.clear()
 
 
 
 
 
711
 
712
  # Dispose and re-init inference session for current model with existing frames
713
  try:
@@ -1154,6 +1288,8 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1154
  )
1155
 
1156
  status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
 
 
1157
  return preview_img, gr.update(value=status_text, visible=True), gr.update(value=frame_idx)
1158
 
1159
  detect_ball_btn.click(
 
2
  import gc
3
  from copy import deepcopy
4
  import base64
5
+ import math
6
+ import statistics
7
  from pathlib import Path
8
  BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
9
  EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")
 
215
  self.pending_box_start_obj_id: int | None = None
216
  self.is_switching_model: bool = False
217
  self.ball_centers: dict[int, dict[int, tuple[int, int]]] = {}
218
+ self.mask_areas: dict[int, dict[int, float]] = {}
219
+ self.smoothed_centers: dict[int, dict[int, tuple[float, float]]] = {}
220
+ self.ball_speeds: dict[int, dict[int, float]] = {}
221
+ self.kick_frame: int | None = None
222
  # Model selection
223
  self.model_repo_key: str = "tiny"
224
  self.model_repo_id: str | None = None
 
294
  GLOBAL_STATE.masks_by_frame = {}
295
  GLOBAL_STATE.color_by_obj = {}
296
  GLOBAL_STATE.ball_centers = {}
297
+ GLOBAL_STATE.mask_areas = {}
298
+ GLOBAL_STATE.smoothed_centers = {}
299
+ GLOBAL_STATE.ball_speeds = {}
300
+ GLOBAL_STATE.kick_frame = None
301
 
302
  load_model_if_needed(GLOBAL_STATE)
303
 
 
509
  centers.pop(int(frame_idx), None)
510
  seen_obj_ids.add(int(obj_id))
511
  _ensure_color_for_obj(state, int(obj_id))
512
+ mask_np = np.array(mask)
513
+ if mask_np.ndim == 3:
514
+ mask_np = mask_np.squeeze()
515
+ mask_np = np.clip(mask_np, 0.0, 1.0)
516
+ area = float(np.count_nonzero(mask_np > 0.3))
517
+ areas = state.mask_areas.setdefault(int(obj_id), {})
518
+ areas[int(frame_idx)] = area
519
  # Remove frames for objects without masks at this frame
520
  for obj_id, centers in state.ball_centers.items():
521
  if obj_id not in seen_obj_ids:
522
  centers.pop(int(frame_idx), None)
523
+ for obj_id, areas in state.mask_areas.items():
524
+ if obj_id not in seen_obj_ids:
525
+ areas.pop(int(frame_idx), None)
526
+ _recompute_motion_metrics(state)
527
+
528
+
529
+ def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
530
+ centers = state.ball_centers.get(target_obj_id)
531
+ if not centers or len(centers) < 3:
532
+ state.smoothed_centers[target_obj_id] = {}
533
+ state.ball_speeds[target_obj_id] = {}
534
+ state.kick_frame = None
535
+ return
536
+
537
+ items = sorted(centers.items())
538
+ dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
539
+ alpha = 0.35
540
+
541
+ smoothed: dict[int, tuple[float, float]] = {}
542
+ speeds: dict[int, float] = {}
543
+
544
+ prev_frame = None
545
+ prev_smooth = None
546
+ for frame_idx, (cx, cy) in items:
547
+ if prev_smooth is None:
548
+ smooth_x, smooth_y = float(cx), float(cy)
549
+ else:
550
+ smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
551
+ smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
552
+ smoothed[frame_idx] = (smooth_x, smooth_y)
553
+ if prev_smooth is None or prev_frame is None:
554
+ speeds[frame_idx] = 0.0
555
+ else:
556
+ frame_delta = max(1, frame_idx - prev_frame)
557
+ time_delta = frame_delta * dt
558
+ dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
559
+ speed = dist / time_delta if time_delta > 0 else dist
560
+ speeds[frame_idx] = speed
561
+ prev_smooth = (smooth_x, smooth_y)
562
+ prev_frame = frame_idx
563
+
564
+ state.smoothed_centers[target_obj_id] = smoothed
565
+ state.ball_speeds[target_obj_id] = speeds
566
+ state.kick_frame = _detect_kick_frame(state, target_obj_id)
567
+
568
+
569
+ def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None:
570
+ smoothed = state.smoothed_centers.get(target_obj_id, {})
571
+ speeds = state.ball_speeds.get(target_obj_id, {})
572
+ if len(smoothed) < 5:
573
+ return None
574
+
575
+ frames = sorted(smoothed.keys())
576
+ speed_series = [speeds.get(f, 0.0) for f in frames]
577
+
578
+ baseline_window = min(5, len(frames) // 3 or 1)
579
+ baseline_speeds = speed_series[:baseline_window]
580
+ baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
581
+
582
+ speed_threshold = baseline_speed + 80.0 # pixels/second
583
+ sustain_frames = 3
584
+ holdout_frames = 8
585
+ return_distance = 12.0
586
+ area_window = 4
587
+ area_drop_ratio = 0.75
588
+ areas_dict = state.mask_areas.get(target_obj_id, {})
589
+ initial_center = smoothed[frames[0]]
590
+
591
+ for idx in range(baseline_window, len(frames)):
592
+ frame = frames[idx]
593
+ speed = speed_series[idx]
594
+ if speed < speed_threshold:
595
+ continue
596
+
597
+ sustain_ok = True
598
+ for j in range(1, sustain_frames + 1):
599
+ if idx + j >= len(frames):
600
+ break
601
+ if speed_series[idx + j] < speed_threshold * 0.8:
602
+ sustain_ok = False
603
+ break
604
+ if not sustain_ok:
605
+ continue
606
+
607
+ area_pass = True
608
+ current_area = areas_dict.get(frame)
609
+ if current_area:
610
+ prev_areas = [
611
+ areas_dict.get(f)
612
+ for f in frames[max(0, idx - area_window):idx]
613
+ if areas_dict.get(f) is not None
614
+ ]
615
+ if prev_areas:
616
+ median_prev = statistics.median(prev_areas)
617
+ if median_prev > 0 and current_area / median_prev > area_drop_ratio:
618
+ area_pass = False
619
+ if not area_pass:
620
+ continue
621
+
622
+ moved_far = True
623
+ for future_frame in frames[idx:min(len(frames), idx + holdout_frames)]:
624
+ cx, cy = smoothed[future_frame]
625
+ dist = math.hypot(cx - initial_center[0], cy - initial_center[1])
626
+ if dist < return_distance:
627
+ moved_far = False
628
+ break
629
+ if not moved_far:
630
+ continue
631
+
632
+ return frame
633
+
634
+ return None
635
 
636
 
637
  def on_image_click(
 
837
  GLOBAL_STATE.pending_box_start_frame_idx = None
838
  GLOBAL_STATE.pending_box_start_obj_id = None
839
  GLOBAL_STATE.ball_centers.clear()
840
+ GLOBAL_STATE.mask_areas.clear()
841
+ GLOBAL_STATE.smoothed_centers.clear()
842
+ GLOBAL_STATE.ball_speeds.clear()
843
+ GLOBAL_STATE.kick_frame = None
844
+ GLOBAL_STATE.ball_centers.clear()
845
 
846
  # Dispose and re-init inference session for current model with existing frames
847
  try:
 
1288
  )
1289
 
1290
  status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
1291
+ if state_in.kick_frame is not None:
1292
+ status_text += f" | Kick frame ≈ {state_in.kick_frame}"
1293
  return preview_img, gr.update(value=status_text, visible=True), gr.update(value=frame_idx)
1294
 
1295
  detect_ball_btn.click(