Mirko Trasciatti commited on
Commit
b95892a
·
1 Parent(s): fe5cf63

Add Kalman tracker and distance visualization

Browse files
Files changed (1) hide show
  1. app.py +115 -1
app.py CHANGED
@@ -229,6 +229,14 @@ class AppState:
229
  self.kick_debug_area: list[float] = []
230
  self.kick_debug_kick_frame: int | None = None
231
  self.kick_debug_distance: list[float] = []
 
 
 
 
 
 
 
 
232
  # Model selection
233
  self.model_repo_key: str = "tiny"
234
  self.model_repo_id: str | None = None
@@ -308,6 +316,18 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
308
  GLOBAL_STATE.smoothed_centers = {}
309
  GLOBAL_STATE.ball_speeds = {}
310
  GLOBAL_STATE.kick_frame = None
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
  load_model_if_needed(GLOBAL_STATE)
313
 
@@ -543,6 +563,69 @@ def _update_centroids_for_frame(state: AppState, frame_idx: int):
543
  _recompute_motion_metrics(state)
544
 
545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
  def _build_kick_plot(state: AppState):
547
  fig = go.Figure()
548
  if state is None or not state.kick_debug_frames or not state.kick_debug_speeds:
@@ -605,9 +688,18 @@ def _build_kick_plot(state: AppState):
605
  mode="lines",
606
  name="Distance from start",
607
  line=dict(color="#9467bd"),
608
- yaxis="y2",
609
  )
610
  )
 
 
 
 
 
 
 
 
 
 
611
  if kick_frame is not None:
612
  fig.add_trace(
613
  go.Scatter(
@@ -649,6 +741,10 @@ def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
649
  state.kick_debug_area = []
650
  state.kick_debug_kick_frame = None
651
  state.kick_debug_distance = []
 
 
 
 
652
  return
653
 
654
  items = sorted(centers.items())
@@ -680,6 +776,10 @@ def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
680
 
681
  state.smoothed_centers[target_obj_id] = smoothed
682
  state.ball_speeds[target_obj_id] = speeds
 
 
 
 
683
  state.kick_frame = _detect_kick_frame(state, target_obj_id)
684
 
685
 
@@ -720,6 +820,8 @@ def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None:
720
  math.hypot(smoothed[f][0] - initial_center[0], smoothed[f][1] - initial_center[1])
721
  for f in frames
722
  ]
 
 
723
  state.kick_debug_kick_frame = None
724
 
725
  for idx in range(baseline_window, len(frames)):
@@ -988,6 +1090,18 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
988
  GLOBAL_STATE.ball_speeds.clear()
989
  GLOBAL_STATE.kick_frame = None
990
  GLOBAL_STATE.ball_centers.clear()
 
 
 
 
 
 
 
 
 
 
 
 
991
 
992
  # Dispose and re-init inference session for current model with existing frames
993
  try:
 
229
  self.kick_debug_area: list[float] = []
230
  self.kick_debug_kick_frame: int | None = None
231
  self.kick_debug_distance: list[float] = []
232
+ self.kalman_centers: dict[int, dict[int, tuple[float, float]]] = {}
233
+ self.kalman_speeds: dict[int, dict[int, float]] = {}
234
+ self.kalman_residuals: dict[int, dict[int, float]] = {}
235
+ self.kick_debug_kalman_speeds: list[float] = []
236
+ self.kalman_centers: dict[int, dict[int, tuple[float, float]]] = {}
237
+ self.kalman_speeds: dict[int, dict[int, float]] = {}
238
+ self.kalman_residuals: dict[int, dict[int, float]] = {}
239
+ self.kick_debug_kalman_speeds: list[float] = []
240
  # Model selection
241
  self.model_repo_key: str = "tiny"
242
  self.model_repo_id: str | None = None
 
316
  GLOBAL_STATE.smoothed_centers = {}
317
  GLOBAL_STATE.ball_speeds = {}
318
  GLOBAL_STATE.kick_frame = None
319
+ GLOBAL_STATE.kalman_centers = {}
320
+ GLOBAL_STATE.kalman_speeds = {}
321
+ GLOBAL_STATE.kalman_residuals = {}
322
+ GLOBAL_STATE.kick_debug_kalman_speeds = []
323
+ GLOBAL_STATE.kick_debug_frames = []
324
+ GLOBAL_STATE.kick_debug_speeds = []
325
+ GLOBAL_STATE.kick_debug_threshold = None
326
+ GLOBAL_STATE.kick_debug_baseline = None
327
+ GLOBAL_STATE.kick_debug_speed_std = None
328
+ GLOBAL_STATE.kick_debug_area = []
329
+ GLOBAL_STATE.kick_debug_kick_frame = None
330
+ GLOBAL_STATE.kick_debug_distance = []
331
 
332
  load_model_if_needed(GLOBAL_STATE)
333
 
 
563
  _recompute_motion_metrics(state)
564
 
565
 
566
+ def _run_kalman_filter(
567
+ ordered_items: list[tuple[int, tuple[float, float]]],
568
+ base_dt: float,
569
+ ) -> tuple[dict[int, tuple[float, float]], dict[int, float], dict[int, float]]:
570
+ if not ordered_items:
571
+ return {}, {}, {}
572
+
573
+ H = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=float)
574
+ R = np.eye(2, dtype=float) * 25.0
575
+
576
+ state_vec = np.array(
577
+ [ordered_items[0][1][0], ordered_items[0][1][1], 0.0, 0.0], dtype=float
578
+ )
579
+ P = np.eye(4, dtype=float) * 50.0
580
+
581
+ positions: dict[int, tuple[float, float]] = {}
582
+ speeds: dict[int, float] = {}
583
+ residuals: dict[int, float] = {}
584
+
585
+ prev_frame = ordered_items[0][0]
586
+
587
+ for frame_idx, (cx, cy) in ordered_items:
588
+ frame_delta = max(1, frame_idx - prev_frame) if frame_idx != prev_frame else 1
589
+ dt = frame_delta * base_dt
590
+ F = np.array(
591
+ [
592
+ [1, 0, dt, 0],
593
+ [0, 1, 0, dt],
594
+ [0, 0, 1, 0],
595
+ [0, 0, 0, 1],
596
+ ],
597
+ dtype=float,
598
+ )
599
+ q = 0.5 * dt**2
600
+ Q = np.array(
601
+ [
602
+ [q, 0, dt, 0],
603
+ [0, q, 0, dt],
604
+ [dt, 0, 1, 0],
605
+ [0, dt, 0, 1],
606
+ ],
607
+ dtype=float,
608
+ ) * 0.05
609
+
610
+ state_vec = F @ state_vec
611
+ P = F @ P @ F.T + Q
612
+
613
+ z = np.array([cx, cy], dtype=float)
614
+ innovation = z - H @ state_vec
615
+ S = H @ P @ H.T + R
616
+ K = P @ H.T @ np.linalg.inv(S)
617
+ state_vec = state_vec + K @ innovation
618
+ P = (np.eye(4) - K @ H) @ P
619
+
620
+ positions[frame_idx] = (state_vec[0], state_vec[1])
621
+ speeds[frame_idx] = float(math.hypot(state_vec[2], state_vec[3]))
622
+ residuals[frame_idx] = float(math.hypot(innovation[0], innovation[1]))
623
+
624
+ prev_frame = frame_idx
625
+
626
+ return positions, speeds, residuals
627
+
628
+
629
  def _build_kick_plot(state: AppState):
630
  fig = go.Figure()
631
  if state is None or not state.kick_debug_frames or not state.kick_debug_speeds:
 
688
  mode="lines",
689
  name="Distance from start",
690
  line=dict(color="#9467bd"),
 
691
  )
692
  )
693
+ if state.kick_debug_kalman_speeds:
694
+ fig.add_trace(
695
+ go.Scatter(
696
+ x=frames,
697
+ y=state.kick_debug_kalman_speeds,
698
+ mode="lines",
699
+ name="Kalman speed",
700
+ line=dict(color="#8c564b"),
701
+ )
702
+ )
703
  if kick_frame is not None:
704
  fig.add_trace(
705
  go.Scatter(
 
741
  state.kick_debug_area = []
742
  state.kick_debug_kick_frame = None
743
  state.kick_debug_distance = []
744
+ state.kalman_centers[target_obj_id] = {}
745
+ state.kalman_speeds[target_obj_id] = {}
746
+ state.kalman_residuals[target_obj_id] = {}
747
+ state.kick_debug_kalman_speeds = []
748
  return
749
 
750
  items = sorted(centers.items())
 
776
 
777
  state.smoothed_centers[target_obj_id] = smoothed
778
  state.ball_speeds[target_obj_id] = speeds
779
+ kalman_pos, kalman_speed, kalman_res = _run_kalman_filter(items, dt)
780
+ state.kalman_centers[target_obj_id] = kalman_pos
781
+ state.kalman_speeds[target_obj_id] = kalman_speed
782
+ state.kalman_residuals[target_obj_id] = kalman_res
783
  state.kick_frame = _detect_kick_frame(state, target_obj_id)
784
 
785
 
 
820
  math.hypot(smoothed[f][0] - initial_center[0], smoothed[f][1] - initial_center[1])
821
  for f in frames
822
  ]
823
+ kalman_speed_dict = state.kalman_speeds.get(target_obj_id, {})
824
+ state.kick_debug_kalman_speeds = [kalman_speed_dict.get(f, 0.0) for f in frames]
825
  state.kick_debug_kick_frame = None
826
 
827
  for idx in range(baseline_window, len(frames)):
 
1090
  GLOBAL_STATE.ball_speeds.clear()
1091
  GLOBAL_STATE.kick_frame = None
1092
  GLOBAL_STATE.ball_centers.clear()
1093
+ GLOBAL_STATE.kalman_centers.clear()
1094
+ GLOBAL_STATE.kalman_speeds.clear()
1095
+ GLOBAL_STATE.kalman_residuals.clear()
1096
+ GLOBAL_STATE.kick_debug_frames = []
1097
+ GLOBAL_STATE.kick_debug_speeds = []
1098
+ GLOBAL_STATE.kick_debug_threshold = None
1099
+ GLOBAL_STATE.kick_debug_baseline = None
1100
+ GLOBAL_STATE.kick_debug_speed_std = None
1101
+ GLOBAL_STATE.kick_debug_area = []
1102
+ GLOBAL_STATE.kick_debug_kick_frame = None
1103
+ GLOBAL_STATE.kick_debug_distance = []
1104
+ GLOBAL_STATE.kick_debug_kalman_speeds = []
1105
 
1106
  # Dispose and re-init inference session for current model with existing frames
1107
  try: