Mirko Trasciatti committed on
Commit
4668c44
·
1 Parent(s): cccfb86

Add kick diagnostics plot and smoothing overlay

Browse files
Files changed (2) hide show
  1. app.py +160 -23
  2. requirements.txt +1 -0
app.py CHANGED
@@ -5,6 +5,8 @@ import base64
5
  import math
6
  import statistics
7
  from pathlib import Path
 
 
8
  BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
9
  EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")
10
 
@@ -219,6 +221,13 @@ class AppState:
219
  self.smoothed_centers: dict[int, dict[int, tuple[float, float]]] = {}
220
  self.ball_speeds: dict[int, dict[int, float]] = {}
221
  self.kick_frame: int | None = None
 
 
 
 
 
 
 
222
  # Model selection
223
  self.model_repo_key: str = "tiny"
224
  self.model_repo_id: str | None = None
@@ -435,28 +444,35 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
435
  for obj_id, centers in state.ball_centers.items():
436
  if not centers:
437
  continue
438
- items = sorted(centers.items())
 
 
 
 
 
 
 
439
  distances: list[float] = []
440
  prev_center = None
441
- for _, (cx, cy) in items:
442
  if prev_center is None:
443
  distances.append(0.0)
444
  else:
445
- dx = cx - prev_center[0]
446
- dy = cy - prev_center[1]
447
  distances.append(float(np.hypot(dx, dy)))
448
- prev_center = (cx, cy)
449
  max_dist = max(distances[1:], default=0.0)
450
  color_by_frame: dict[int, tuple[int, int, int]] = {}
451
- for (f_idx, _), dist in zip(items, distances):
452
  ratio = dist / max_dist if max_dist > 0 else 0.0
453
  color_by_frame[f_idx] = _speed_to_color(ratio)
454
- for f_idx, (cx, cy) in reversed(items):
455
  highlight = (f_idx == frame_idx)
456
  color = (255, 0, 0) if highlight else color_by_frame.get(f_idx, (255, 255, 0))
457
  line_width = 1 if not highlight else 2
458
- draw.line([(cx - cross_half, cy), (cx + cross_half, cy)], fill=color, width=line_width)
459
- draw.line([(cx, cy - cross_half), (cx, cy + cross_half)], fill=color, width=line_width)
460
  # Save to cache and return
461
  state.composited_frames[frame_idx] = out_img
462
  return out_img
@@ -526,12 +542,100 @@ def _update_centroids_for_frame(state: AppState, frame_idx: int):
526
  _recompute_motion_metrics(state)
527
 
528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
530
  centers = state.ball_centers.get(target_obj_id)
531
  if not centers or len(centers) < 3:
532
  state.smoothed_centers[target_obj_id] = {}
533
  state.ball_speeds[target_obj_id] = {}
534
  state.kick_frame = None
 
 
 
 
 
 
 
535
  return
536
 
537
  items = sorted(centers.items())
@@ -591,6 +695,14 @@ def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None:
591
  areas_dict = state.mask_areas.get(target_obj_id, {})
592
  initial_center = smoothed[frames[0]]
593
 
 
 
 
 
 
 
 
 
594
  for idx in range(baseline_window, len(frames)):
595
  frame = frames[idx]
596
  speed = speed_series[idx]
@@ -634,6 +746,7 @@ def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None:
634
  if not moved_far:
635
  continue
636
 
 
637
  return frame
638
 
639
  return None
@@ -778,7 +891,7 @@ def on_image_click(
778
  def propagate_masks(GLOBAL_STATE: gr.State):
779
  if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
780
  # yield GLOBAL_STATE, "Load a video first.", gr.update()
781
- return GLOBAL_STATE, "Load a video first.", gr.update()
782
 
783
  processor = deepcopy(GLOBAL_STATE.processor)
784
  model = deepcopy(GLOBAL_STATE.model)
@@ -792,7 +905,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
792
  processed = 0
793
 
794
  # Initial status; no slider change yet
795
- yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
796
 
797
  last_frame_idx = 0
798
  with torch.inference_mode():
@@ -819,19 +932,27 @@ def propagate_masks(GLOBAL_STATE: gr.State):
819
  processed += 1
820
  # Every 30th frame (or last), move slider to current frame to update preview via slider binding
821
  if processed % 30 == 0 or processed == total:
822
- yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
823
 
824
  text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
825
 
826
  # Final status; ensure slider points to last processed frame
827
- yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)
828
 
829
 
830
- def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
831
  # Reset only session-related state, keep uploaded video and model
832
  if not GLOBAL_STATE.video_frames:
833
  # Nothing loaded; keep behavior
834
- return GLOBAL_STATE, None, 0, 0, "Session reset. Load a new video."
 
 
 
 
 
 
 
 
835
 
836
  # Clear prompts and caches
837
  GLOBAL_STATE.masks_by_frame.clear()
@@ -866,7 +987,15 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
866
  slider_value = gr.update(value=current_idx)
867
  status = "Session reset. Prompts cleared; video preserved."
868
  # clear and reload model and processor
869
- return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status, gr.update(visible=False, value="")
 
 
 
 
 
 
 
 
870
 
871
 
872
  def create_annotation_preview(video_file, annotations):
@@ -1135,6 +1264,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1135
  label_radio = gr.Radio(choices=["positive", "negative"], value="positive", label="Point label")
1136
  clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
1137
  prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
 
1138
 
1139
  # Wire events
1140
  def _on_video_change(GLOBAL_STATE: gr.State, video):
@@ -1144,13 +1274,14 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1144
  gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
1145
  first_frame,
1146
  status,
1147
- gr.update(visible=False, value="")
 
1148
  )
1149
 
1150
  video_in.change(
1151
  _on_video_change,
1152
  inputs=[GLOBAL_STATE, video_in],
1153
- outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status],
1154
  show_progress=True,
1155
  )
1156
 
@@ -1165,7 +1296,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1165
  examples=examples_list,
1166
  inputs=[GLOBAL_STATE, video_in],
1167
  fn=_on_video_change,
1168
- outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status],
1169
  label="Examples",
1170
  cache_examples=False,
1171
  examples_per_page=5,
@@ -1265,6 +1396,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1265
  visible=True,
1266
  ),
1267
  gr.update(value=frame_idx),
 
1268
  )
1269
 
1270
  x_center, y_center, _, _, conf = detection
@@ -1297,12 +1429,17 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1297
  status_text += f" | Kick frame ≈ {state_in.kick_frame}"
1298
  else:
1299
  status_text += " | Kick frame: not detected"
1300
- return preview_img, gr.update(value=status_text, visible=True), gr.update(value=frame_idx)
 
 
 
 
 
1301
 
1302
  detect_ball_btn.click(
1303
  _auto_detect_ball,
1304
  inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk],
1305
- outputs=[preview, ball_status, frame_slider],
1306
  )
1307
 
1308
  # Image click to add a point and run forward on that frame
@@ -1352,13 +1489,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1352
  propagate_btn.click(
1353
  propagate_masks,
1354
  inputs=[GLOBAL_STATE],
1355
- outputs=[GLOBAL_STATE, propagate_status, frame_slider],
1356
  )
1357
 
1358
  reset_btn.click(
1359
  reset_session,
1360
  inputs=GLOBAL_STATE,
1361
- outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status],
1362
  )
1363
 
1364
  # ============================================================================
 
5
  import math
6
  import statistics
7
  from pathlib import Path
8
+
9
+ import plotly.graph_objects as go
10
  BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
11
  EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")
12
 
 
221
  self.smoothed_centers: dict[int, dict[int, tuple[float, float]]] = {}
222
  self.ball_speeds: dict[int, dict[int, float]] = {}
223
  self.kick_frame: int | None = None
224
+ self.kick_debug_frames: list[int] = []
225
+ self.kick_debug_speeds: list[float] = []
226
+ self.kick_debug_threshold: float | None = None
227
+ self.kick_debug_baseline: float | None = None
228
+ self.kick_debug_speed_std: float | None = None
229
+ self.kick_debug_area: list[float] = []
230
+ self.kick_debug_kick_frame: int | None = None
231
  # Model selection
232
  self.model_repo_key: str = "tiny"
233
  self.model_repo_id: str | None = None
 
444
  for obj_id, centers in state.ball_centers.items():
445
  if not centers:
446
  continue
447
+ raw_items = sorted(centers.items())
448
+ for _, (rx, ry) in raw_items:
449
+ draw.line([(rx - cross_half, ry), (rx + cross_half, ry)], fill=(160, 160, 160), width=1)
450
+ draw.line([(rx, ry - cross_half), (rx, ry + cross_half)], fill=(160, 160, 160), width=1)
451
+ smooth_dict = state.smoothed_centers.get(obj_id, {})
452
+ if not smooth_dict:
453
+ continue
454
+ smooth_items = sorted(smooth_dict.items())
455
  distances: list[float] = []
456
  prev_center = None
457
+ for _, (sx, sy) in smooth_items:
458
  if prev_center is None:
459
  distances.append(0.0)
460
  else:
461
+ dx = sx - prev_center[0]
462
+ dy = sy - prev_center[1]
463
  distances.append(float(np.hypot(dx, dy)))
464
+ prev_center = (sx, sy)
465
  max_dist = max(distances[1:], default=0.0)
466
  color_by_frame: dict[int, tuple[int, int, int]] = {}
467
+ for (f_idx, _), dist in zip(smooth_items, distances):
468
  ratio = dist / max_dist if max_dist > 0 else 0.0
469
  color_by_frame[f_idx] = _speed_to_color(ratio)
470
+ for f_idx, (sx, sy) in reversed(smooth_items):
471
  highlight = (f_idx == frame_idx)
472
  color = (255, 0, 0) if highlight else color_by_frame.get(f_idx, (255, 255, 0))
473
  line_width = 1 if not highlight else 2
474
+ draw.line([(sx - cross_half, sy), (sx + cross_half, sy)], fill=color, width=line_width)
475
+ draw.line([(sx, sy - cross_half), (sx, sy + cross_half)], fill=color, width=line_width)
476
  # Save to cache and return
477
  state.composited_frames[frame_idx] = out_img
478
  return out_img
 
542
  _recompute_motion_metrics(state)
543
 
544
 
545
+ def _build_kick_plot(state: AppState):
546
+ fig = go.Figure()
547
+ if state is None or not state.kick_debug_frames or not state.kick_debug_speeds:
548
+ fig.update_layout(
549
+ title="Kick speed diagnostics",
550
+ xaxis_title="Frame",
551
+ yaxis_title="Speed (px/s)",
552
+ )
553
+ return fig
554
+
555
+ frames = state.kick_debug_frames
556
+ speeds = state.kick_debug_speeds
557
+ areas = state.kick_debug_area if state.kick_debug_area else [0.0] * len(frames)
558
+ threshold = state.kick_debug_threshold or 0.0
559
+ baseline = state.kick_debug_baseline or 0.0
560
+ kick_frame = state.kick_debug_kick_frame
561
+
562
+ fig.add_trace(
563
+ go.Scatter(
564
+ x=frames,
565
+ y=speeds,
566
+ mode="lines+markers",
567
+ name="Speed (px/s)",
568
+ line=dict(color="#1f77b4"),
569
+ )
570
+ )
571
+ fig.add_trace(
572
+ go.Scatter(
573
+ x=[frames[0], frames[-1]],
574
+ y=[threshold, threshold],
575
+ mode="lines",
576
+ name="Adaptive threshold",
577
+ line=dict(color="#d62728", dash="dash"),
578
+ )
579
+ )
580
+ fig.add_trace(
581
+ go.Scatter(
582
+ x=[frames[0], frames[-1]],
583
+ y=[baseline, baseline],
584
+ mode="lines",
585
+ name="Baseline speed",
586
+ line=dict(color="#ff7f0e", dash="dot"),
587
+ )
588
+ )
589
+ fig.add_trace(
590
+ go.Scatter(
591
+ x=frames,
592
+ y=areas,
593
+ mode="lines",
594
+ name="Mask area",
595
+ line=dict(color="#2ca02c"),
596
+ yaxis="y2",
597
+ )
598
+ )
599
+ if kick_frame is not None:
600
+ fig.add_trace(
601
+ go.Scatter(
602
+ x=[kick_frame, kick_frame],
603
+ y=[min(speeds), max(max(speeds), threshold)],
604
+ mode="lines",
605
+ name="Detected kick",
606
+ line=dict(color="#9467bd", dash="dashdot"),
607
+ )
608
+ )
609
+ fig.update_layout(
610
+ title="Kick speed diagnostics",
611
+ xaxis_title="Frame",
612
+ yaxis_title="Speed (px/s)",
613
+ yaxis=dict(side="left"),
614
+ yaxis2=dict(
615
+ title="Mask area (px)",
616
+ overlaying="y",
617
+ side="right",
618
+ showgrid=False,
619
+ ),
620
+ legend=dict(orientation="h"),
621
+ margin=dict(t=40, l=40, r=40, b=40),
622
+ )
623
+ return fig
624
+
625
+
626
  def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
627
  centers = state.ball_centers.get(target_obj_id)
628
  if not centers or len(centers) < 3:
629
  state.smoothed_centers[target_obj_id] = {}
630
  state.ball_speeds[target_obj_id] = {}
631
  state.kick_frame = None
632
+ state.kick_debug_frames = []
633
+ state.kick_debug_speeds = []
634
+ state.kick_debug_threshold = None
635
+ state.kick_debug_baseline = None
636
+ state.kick_debug_speed_std = None
637
+ state.kick_debug_area = []
638
+ state.kick_debug_kick_frame = None
639
  return
640
 
641
  items = sorted(centers.items())
 
695
  areas_dict = state.mask_areas.get(target_obj_id, {})
696
  initial_center = smoothed[frames[0]]
697
 
698
+ state.kick_debug_frames = frames
699
+ state.kick_debug_speeds = speed_series
700
+ state.kick_debug_threshold = speed_threshold
701
+ state.kick_debug_baseline = baseline_speed
702
+ state.kick_debug_speed_std = speed_std
703
+ state.kick_debug_area = [areas_dict.get(f, 0.0) for f in frames]
704
+ state.kick_debug_kick_frame = None
705
+
706
  for idx in range(baseline_window, len(frames)):
707
  frame = frames[idx]
708
  speed = speed_series[idx]
 
746
  if not moved_far:
747
  continue
748
 
749
+ state.kick_debug_kick_frame = frame
750
  return frame
751
 
752
  return None
 
891
  def propagate_masks(GLOBAL_STATE: gr.State):
892
  if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
893
  # yield GLOBAL_STATE, "Load a video first.", gr.update()
894
+ return GLOBAL_STATE, "Load a video first.", gr.update(), _build_kick_plot(GLOBAL_STATE)
895
 
896
  processor = deepcopy(GLOBAL_STATE.processor)
897
  model = deepcopy(GLOBAL_STATE.model)
 
905
  processed = 0
906
 
907
  # Initial status; no slider change yet
908
+ yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(), _build_kick_plot(GLOBAL_STATE)
909
 
910
  last_frame_idx = 0
911
  with torch.inference_mode():
 
932
  processed += 1
933
  # Every 30th frame (or last), move slider to current frame to update preview via slider binding
934
  if processed % 30 == 0 or processed == total:
935
+ yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx), _build_kick_plot(GLOBAL_STATE)
936
 
937
  text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
938
 
939
  # Final status; ensure slider points to last processed frame
940
+ yield GLOBAL_STATE, text, gr.update(value=last_frame_idx), _build_kick_plot(GLOBAL_STATE)
941
 
942
 
943
+ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str, gr.Update, go.Figure]:
944
  # Reset only session-related state, keep uploaded video and model
945
  if not GLOBAL_STATE.video_frames:
946
  # Nothing loaded; keep behavior
947
+ return (
948
+ GLOBAL_STATE,
949
+ None,
950
+ 0,
951
+ 0,
952
+ "Session reset. Load a new video.",
953
+ gr.update(visible=False, value=""),
954
+ _build_kick_plot(GLOBAL_STATE),
955
+ )
956
 
957
  # Clear prompts and caches
958
  GLOBAL_STATE.masks_by_frame.clear()
 
987
  slider_value = gr.update(value=current_idx)
988
  status = "Session reset. Prompts cleared; video preserved."
989
  # clear and reload model and processor
990
+ return (
991
+ GLOBAL_STATE,
992
+ preview_img,
993
+ slider_minmax,
994
+ slider_value,
995
+ status,
996
+ gr.update(visible=False, value=""),
997
+ _build_kick_plot(GLOBAL_STATE),
998
+ )
999
 
1000
 
1001
  def create_annotation_preview(video_file, annotations):
 
1264
  label_radio = gr.Radio(choices=["positive", "negative"], value="positive", label="Point label")
1265
  clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
1266
  prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
1267
+ kick_plot = gr.Plot(label="Kick diagnostics", show_label=True)
1268
 
1269
  # Wire events
1270
  def _on_video_change(GLOBAL_STATE: gr.State, video):
 
1274
  gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
1275
  first_frame,
1276
  status,
1277
+ gr.update(visible=False, value=""),
1278
+ _build_kick_plot(GLOBAL_STATE)
1279
  )
1280
 
1281
  video_in.change(
1282
  _on_video_change,
1283
  inputs=[GLOBAL_STATE, video_in],
1284
+ outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot],
1285
  show_progress=True,
1286
  )
1287
 
 
1296
  examples=examples_list,
1297
  inputs=[GLOBAL_STATE, video_in],
1298
  fn=_on_video_change,
1299
+ outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot],
1300
  label="Examples",
1301
  cache_examples=False,
1302
  examples_per_page=5,
 
1396
  visible=True,
1397
  ),
1398
  gr.update(value=frame_idx),
1399
+ _build_kick_plot(state_in),
1400
  )
1401
 
1402
  x_center, y_center, _, _, conf = detection
 
1429
  status_text += f" | Kick frame ≈ {state_in.kick_frame}"
1430
  else:
1431
  status_text += " | Kick frame: not detected"
1432
+ return (
1433
+ preview_img,
1434
+ gr.update(value=status_text, visible=True),
1435
+ gr.update(value=frame_idx),
1436
+ _build_kick_plot(state_in),
1437
+ )
1438
 
1439
  detect_ball_btn.click(
1440
  _auto_detect_ball,
1441
  inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk],
1442
+ outputs=[preview, ball_status, frame_slider, kick_plot],
1443
  )
1444
 
1445
  # Image click to add a point and run forward on that frame
 
1489
  propagate_btn.click(
1490
  propagate_masks,
1491
  inputs=[GLOBAL_STATE],
1492
+ outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot],
1493
  )
1494
 
1495
  reset_btn.click(
1496
  reset_session,
1497
  inputs=GLOBAL_STATE,
1498
+ outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot],
1499
  )
1500
 
1501
  # ============================================================================
requirements.txt CHANGED
@@ -7,5 +7,6 @@ opencv-python
7
  imageio[pyav]
8
  spaces
9
  git+https://github.com/iMoonLab/yolov13
 
10
 
11
 
 
7
  imageio[pyav]
8
  spaces
9
  git+https://github.com/iMoonLab/yolov13
10
+ plotly
11
 
12