Mirko Trasciatti commited on
Commit
8df8982
·
1 Parent(s): 0653036

Add multi-ball candidate detection and tracking

Browse files

- detect_all_balls(): Detect multiple balls in first frame with ROI filtering
- _track_single_ball_candidate(): Track each candidate using proximity matching
- _detect_and_track_all_ball_candidates(): Score candidates by kick detection, velocity, position
- _build_multi_ball_chart(): Combined speed chart comparing all candidates
- _apply_selected_ball_to_yolo_state(): Copy selected candidate to main YOLO state
- Updated _auto_detect_ball to detect multiple candidates
- Updated _track_ball_yolo to use multi-ball tracking when multiple candidates found
- Added ball_candidates, selected_ball_idx, ball_selection_confirmed to AppState
- Added multi-ball selection UI components (radio buttons, chart, confirm button)

Files changed (1) hide show
  1. app.py +616 -21
app.py CHANGED
@@ -132,6 +132,96 @@ def detect_ball_center(
132
  )
133
 
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def detect_person_box(
136
  frame: Image.Image,
137
  model_filename: str = YOLO_DEFAULT_MODEL,
@@ -683,6 +773,347 @@ def _perform_yolo_ball_tracking(state: AppState, progress: gr.Progress | None =
683
  state.sam_window = None
684
 
685
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
686
  def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
687
  """Generate a deterministic pastel RGB color for a given object id.
688
 
@@ -877,6 +1308,14 @@ class AppState:
877
  self.is_sam_tracked: bool = False
878
  self.is_player_detected: bool = False
879
  self.is_player_propagated: bool = False
 
 
 
 
 
 
 
 
880
  self.goal_mode: str = GOAL_MODE_IDLE
881
  self.goal_points_norm: list[tuple[float, float]] = []
882
  self.goal_confirmed_points_norm: list[tuple[float, float]] = []
@@ -2225,6 +2664,109 @@ def _build_yolo_plot(state: AppState):
2225
  return fig
2226
 
2227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2228
  def _jump_to_frame(state: AppState, target: int | None):
2229
  if state is None or state.num_frames == 0 or target is None:
2230
  return gr.update(), gr.update()
@@ -3680,6 +4222,20 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
3680
  with gr.Row(elem_classes=["model-status"]):
3681
  yolo_kick_btn = gr.Button("⚽: N/A", interactive=False)
3682
  yolo_impact_btn = gr.Button("🚩: N/A", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3683
  yolo_plot = gr.Plot(label="YOLO kick diagnostics", show_label=True)
3684
  with gr.Column(elem_classes=["model-section"]):
3685
  with gr.Row(elem_classes=["model-row"]):
@@ -4216,25 +4772,45 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
4216
  state_in.is_ball_detected = False
4217
  frame_idx = 0
4218
  frame = state_in.video_frames[frame_idx]
4219
- detection = detect_ball_center(frame)
4220
- if detection is None:
4221
- propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
4222
- status_updates = _ui_status_updates(state_in)
4223
- return (
4224
- update_frame_display(state_in, frame_idx),
4225
- gr.update(
4226
- value="❌ Unable to auto-detect the ball. Please add a point manually.",
4227
- visible=True,
4228
- ),
4229
- gr.update(value=frame_idx),
4230
- _build_kick_plot(state_in),
4231
- propagate_main_update,
4232
- detect_btn_update,
4233
- propagate_player_update,
4234
- *status_updates,
4235
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4236
 
4237
- x_center, y_center, _, _, conf = detection
4238
  frame_width, frame_height = frame.size
4239
  x_center = max(0, min(frame_width - 1, int(x_center)))
4240
  y_center = max(0, min(frame_height - 1, int(y_center)))
@@ -4260,7 +4836,11 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
4260
  )
4261
 
4262
  state_in.is_ball_detected = True
4263
- status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
 
 
 
 
4264
  status_text += f" | {_format_kick_status(state_in)}"
4265
  propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
4266
  status_updates = _ui_status_updates(state_in)
@@ -4307,7 +4887,23 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
4307
  raise gr.Error("Load a video first, then track the ball with YOLO.")
4308
  progress = gr.Progress(track_tqdm=False)
4309
  state_in.is_yolo_tracked = False
4310
- _perform_yolo_ball_tracking(state_in, progress=progress)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4311
  target_frame = (
4312
  state_in.yolo_kick_frame
4313
  if state_in.yolo_kick_frame is not None
@@ -4319,7 +4915,6 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
4319
  target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
4320
  state_in.current_frame_idx = target_frame
4321
  preview_img = update_frame_display(state_in, target_frame)
4322
- base_msg = state_in.yolo_status or ""
4323
  kick_msg = _format_kick_status(state_in)
4324
  status_text = f"{base_msg} | {kick_msg}" if base_msg else kick_msg
4325
  state_in.is_yolo_tracked = True
 
132
  )
133
 
134
 
135
+ def detect_all_balls(
136
+ frame: Image.Image,
137
+ model_filename: str = YOLO_DEFAULT_MODEL,
138
+ conf_threshold: float = 0.1, # Lower threshold to catch more candidates
139
+ iou_threshold: float = YOLO_IOU_THRESHOLD,
140
+ max_detections: int = 5,
141
+ roi_x_min: float = 0.20, # Filter: ball must be within 20-80% horizontal
142
+ roi_x_max: float = 0.80,
143
+ ) -> list[dict]:
144
+ """
145
+ Detect all ball candidates in a frame, filtered by ROI.
146
+
147
+ Returns list of dicts with keys:
148
+ - id: int (candidate index)
149
+ - center: (x, y) tuple
150
+ - box: (x_min, y_min, x_max, y_max) tuple
151
+ - width: float
152
+ - height: float
153
+ - conf: float (YOLO confidence)
154
+ - x_ratio: float (horizontal position as fraction of frame width)
155
+ """
156
+ model = get_yolo_model(model_filename)
157
+ class_ids = [
158
+ idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
159
+ ]
160
+ if not class_ids:
161
+ return []
162
+
163
+ results = model.predict(
164
+ source=frame,
165
+ conf=conf_threshold,
166
+ iou=iou_threshold,
167
+ max_det=max_detections,
168
+ classes=class_ids,
169
+ imgsz=640,
170
+ device="cpu",
171
+ verbose=False,
172
+ )
173
+
174
+ if not results:
175
+ return []
176
+
177
+ boxes = results[0].boxes
178
+ if boxes is None or len(boxes) == 0:
179
+ return []
180
+
181
+ frame_width, frame_height = frame.size
182
+ candidates = []
183
+
184
+ for i, box in enumerate(boxes):
185
+ xywh = box.xywh[0].cpu().tolist()
186
+ conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0
187
+ x_center, y_center, width, height = xywh
188
+
189
+ # Compute bounding box
190
+ x_min = int(round(max(0.0, x_center - width / 2.0)))
191
+ y_min = int(round(max(0.0, y_center - height / 2.0)))
192
+ x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0)))
193
+ y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0)))
194
+
195
+ if x_max <= x_min or y_max <= y_min:
196
+ continue
197
+
198
+ # Compute horizontal position ratio
199
+ x_ratio = x_center / frame_width
200
+
201
+ # Filter by ROI (horizontal position)
202
+ if not (roi_x_min <= x_ratio <= roi_x_max):
203
+ continue
204
+
205
+ candidates.append({
206
+ "id": len(candidates),
207
+ "center": (float(x_center), float(y_center)),
208
+ "box": (x_min, y_min, x_max, y_max),
209
+ "width": float(width),
210
+ "height": float(height),
211
+ "conf": conf,
212
+ "x_ratio": x_ratio,
213
+ })
214
+
215
+ # Sort by confidence descending
216
+ candidates.sort(key=lambda c: c["conf"], reverse=True)
217
+
218
+ # Re-assign IDs after sorting
219
+ for i, c in enumerate(candidates):
220
+ c["id"] = i
221
+
222
+ return candidates
223
+
224
+
225
  def detect_person_box(
226
  frame: Image.Image,
227
  model_filename: str = YOLO_DEFAULT_MODEL,
 
773
  state.sam_window = None
774
 
775
 
776
+ def _track_single_ball_candidate(
777
+ state: AppState,
778
+ candidate: dict,
779
+ progress: gr.Progress | None = None,
780
+ ) -> dict:
781
+ """
782
+ Track a single ball candidate across all frames using YOLO.
783
+ Uses proximity matching to follow the same ball.
784
+
785
+ Returns dict with tracking results:
786
+ - centers: dict[frame_idx, (x, y)]
787
+ - speeds: dict[frame_idx, speed]
788
+ - kick_frame: int | None
789
+ - max_velocity: float
790
+ - has_kick: bool
791
+ - coverage: float (fraction of frames with detection)
792
+ """
793
+ model = get_yolo_model()
794
+ class_ids = [
795
+ idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
796
+ ]
797
+
798
+ frames = state.video_frames
799
+ total = len(frames)
800
+
801
+ # Initial position from candidate
802
+ last_center = candidate["center"]
803
+ max_distance_threshold = 100 # Max pixels to consider same ball
804
+
805
+ centers: dict[int, tuple[float, float]] = {}
806
+ boxes: dict[int, tuple[int, int, int, int]] = {}
807
+ confs: dict[int, float] = {}
808
+ areas: dict[int, float] = {}
809
+
810
+ for idx, frame in enumerate(frames):
811
+ if progress is not None:
812
+ progress((idx + 1) / total)
813
+
814
+ results = model.predict(
815
+ source=frame,
816
+ conf=0.05, # Lower threshold to catch more
817
+ iou=YOLO_IOU_THRESHOLD,
818
+ max_det=10, # Allow multiple detections
819
+ classes=class_ids,
820
+ imgsz=640,
821
+ device="cpu",
822
+ verbose=False,
823
+ )
824
+
825
+ if not results:
826
+ continue
827
+
828
+ boxes_result = results[0].boxes
829
+ if boxes_result is None or len(boxes_result) == 0:
830
+ continue
831
+
832
+ # Find the detection closest to last known position
833
+ best_box = None
834
+ best_distance = float("inf")
835
+
836
+ for box in boxes_result:
837
+ xywh = box.xywh[0].cpu().tolist()
838
+ x_center, y_center = xywh[0], xywh[1]
839
+ dist = math.hypot(x_center - last_center[0], y_center - last_center[1])
840
+
841
+ if dist < best_distance and dist < max_distance_threshold:
842
+ best_distance = dist
843
+ best_box = box
844
+
845
+ if best_box is None:
846
+ continue
847
+
848
+ xywh = best_box.xywh[0].cpu().tolist()
849
+ conf = float(best_box.conf[0].cpu().item()) if best_box.conf is not None else 0.0
850
+ x_center, y_center, width, height = xywh
851
+ x_center = float(x_center)
852
+ y_center = float(y_center)
853
+ width = max(1.0, float(width))
854
+ height = max(1.0, float(height))
855
+
856
+ frame_width, frame_height = frame.size
857
+ x_min = int(round(max(0.0, x_center - width / 2.0)))
858
+ y_min = int(round(max(0.0, y_center - height / 2.0)))
859
+ x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0)))
860
+ y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0)))
861
+
862
+ if x_max <= x_min or y_max <= y_min:
863
+ continue
864
+
865
+ centers[idx] = (x_center, y_center)
866
+ boxes[idx] = (x_min, y_min, x_max, y_max)
867
+ confs[idx] = conf
868
+ areas[idx] = float((x_max - x_min) * (y_max - y_min))
869
+ last_center = (x_center, y_center)
870
+
871
+ # Compute speeds
872
+ if len(centers) < 3:
873
+ return {
874
+ "centers": centers,
875
+ "boxes": boxes,
876
+ "confs": confs,
877
+ "areas": areas,
878
+ "speeds": {},
879
+ "smoothed_centers": {},
880
+ "frames_ordered": [],
881
+ "speed_series": [],
882
+ "kick_frame": None,
883
+ "max_velocity": 0.0,
884
+ "has_kick": False,
885
+ "coverage": len(centers) / total if total else 0.0,
886
+ }
887
+
888
+ items = sorted(centers.items())
889
+ dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
890
+ alpha = 0.35
891
+
892
+ smoothed: dict[int, tuple[float, float]] = {}
893
+ speeds: dict[int, float] = {}
894
+
895
+ prev_frame = None
896
+ prev_smooth = None
897
+ for frame_idx, (cx, cy) in items:
898
+ if prev_smooth is None:
899
+ smooth_x, smooth_y = float(cx), float(cy)
900
+ else:
901
+ smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
902
+ smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
903
+ smoothed[frame_idx] = (smooth_x, smooth_y)
904
+ if prev_smooth is None or prev_frame is None:
905
+ speeds[frame_idx] = 0.0
906
+ else:
907
+ frame_delta = max(1, frame_idx - prev_frame)
908
+ time_delta = frame_delta * dt
909
+ dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
910
+ speed = dist / time_delta if time_delta > 0 else dist
911
+ speeds[frame_idx] = speed
912
+ prev_smooth = (smooth_x, smooth_y)
913
+ prev_frame = frame_idx
914
+
915
+ frames_ordered = [frame_idx for frame_idx, _ in items]
916
+ speed_series = [speeds.get(f, 0.0) for f in frames_ordered]
917
+
918
+ # Detect kick (velocity spike)
919
+ baseline_window = min(10, len(frames_ordered) // 3 or 1)
920
+ baseline_speeds = speed_series[:baseline_window]
921
+ baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
922
+ speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0
923
+ base_threshold = baseline_speed + 4.0 * speed_std
924
+ if base_threshold < baseline_speed * 3.0:
925
+ base_threshold = baseline_speed * 3.0
926
+ speed_threshold = max(base_threshold, 15.0)
927
+
928
+ kick_frame: int | None = None
929
+ max_velocity = max(speed_series) if speed_series else 0.0
930
+
931
+ for idx, frame in enumerate(frames_ordered[baseline_window:], start=baseline_window):
932
+ speed = speed_series[idx]
933
+ if speed < speed_threshold:
934
+ continue
935
+ # Check sustain
936
+ sustain_ok = True
937
+ for j in range(1, 4):
938
+ if idx + j >= len(frames_ordered):
939
+ break
940
+ if speed_series[idx + j] < speed_threshold * 0.7:
941
+ sustain_ok = False
942
+ break
943
+ if sustain_ok:
944
+ kick_frame = frame
945
+ break
946
+
947
+ return {
948
+ "centers": centers,
949
+ "boxes": boxes,
950
+ "confs": confs,
951
+ "areas": areas,
952
+ "speeds": speeds,
953
+ "smoothed_centers": smoothed,
954
+ "frames_ordered": frames_ordered,
955
+ "speed_series": speed_series,
956
+ "threshold": speed_threshold,
957
+ "baseline": baseline_speed,
958
+ "kick_frame": kick_frame,
959
+ "max_velocity": max_velocity,
960
+ "has_kick": kick_frame is not None,
961
+ "coverage": len(centers) / total if total else 0.0,
962
+ }
963
+
964
+
965
+ def _detect_and_track_all_ball_candidates(
966
+ state: AppState,
967
+ progress: gr.Progress | None = None,
968
+ ) -> None:
969
+ """
970
+ Detect all ball candidates in first frame, track each with YOLO,
971
+ score them, and auto-select the best candidate.
972
+ """
973
+ if state is None or state.num_frames == 0:
974
+ raise gr.Error("Load a video first.")
975
+
976
+ first_frame = state.video_frames[0]
977
+ frame_width, frame_height = first_frame.size
978
+
979
+ # Step 1: Detect all balls in first frame
980
+ candidates = detect_all_balls(first_frame)
981
+
982
+ if not candidates:
983
+ state.ball_candidates = []
984
+ state.multi_ball_status = "❌ No ball candidates detected in first frame."
985
+ return
986
+
987
+ state.multi_ball_status = f"🔍 Found {len(candidates)} ball candidate(s). Tracking..."
988
+
989
+ # Step 2: Track each candidate
990
+ tracking_results: dict[int, dict] = {}
991
+
992
+ for i, candidate in enumerate(candidates):
993
+ if progress is not None:
994
+ progress((i + 1) / len(candidates), desc=f"Tracking ball {i+1}/{len(candidates)}")
995
+
996
+ result = _track_single_ball_candidate(state, candidate, progress=None)
997
+ tracking_results[candidate["id"]] = result
998
+
999
+ # Add tracking summary to candidate
1000
+ candidate["tracking"] = result
1001
+ candidate["has_kick"] = result["has_kick"]
1002
+ candidate["kick_frame"] = result["kick_frame"]
1003
+ candidate["max_velocity"] = result["max_velocity"]
1004
+ candidate["coverage"] = result["coverage"]
1005
+
1006
+ # Step 3: Score candidates
1007
+ frame_center_x = frame_width / 2
1008
+
1009
+ for candidate in candidates:
1010
+ score = 0.0
1011
+
1012
+ # 1. Has a detected kick (velocity spike) — most important
1013
+ if candidate["has_kick"]:
1014
+ score += 50
1015
+
1016
+ # 2. Higher max velocity — ball that moves most
1017
+ score += min(30, candidate["max_velocity"] / 10)
1018
+
1019
+ # 3. Centered horizontally
1020
+ x_offset = abs(candidate["center"][0] - frame_center_x) / frame_center_x
1021
+ score += 20 * (1 - x_offset)
1022
+
1023
+ # 4. YOLO confidence as tiebreaker
1024
+ score += candidate["conf"] * 10
1025
+
1026
+ # 5. Better coverage
1027
+ score += candidate["coverage"] * 10
1028
+
1029
+ candidate["score"] = score
1030
+
1031
+ # Sort by score descending
1032
+ candidates.sort(key=lambda c: c["score"], reverse=True)
1033
+
1034
+ # Re-assign IDs after sorting
1035
+ for i, c in enumerate(candidates):
1036
+ c["id"] = i
1037
+
1038
+ state.ball_candidates = candidates
1039
+ state.ball_candidates_tracking = tracking_results
1040
+ state.selected_ball_idx = 0 # Auto-select best candidate
1041
+ state.ball_selection_confirmed = False
1042
+
1043
+ # Build status message
1044
+ if len(candidates) == 1:
1045
+ c = candidates[0]
1046
+ kick_info = f"Kick @ frame {c['kick_frame']}" if c["has_kick"] else "No kick detected"
1047
+ state.multi_ball_status = f"✅ 1 ball detected. {kick_info}."
1048
+ else:
1049
+ kicked_count = sum(1 for c in candidates if c["has_kick"])
1050
+ state.multi_ball_status = (
1051
+ f"⚠️ {len(candidates)} balls detected. "
1052
+ f"{kicked_count} show movement. "
1053
+ f"Best candidate auto-selected. Please confirm or change selection."
1054
+ )
1055
+
1056
+
1057
+ def _apply_selected_ball_to_yolo_state(state: AppState) -> None:
1058
+ """
1059
+ Copy the selected ball candidate's tracking data to the main YOLO state.
1060
+ This allows the rest of the pipeline to work unchanged.
1061
+ """
1062
+ if not state.ball_candidates:
1063
+ return
1064
+
1065
+ idx = state.selected_ball_idx
1066
+ if idx < 0 or idx >= len(state.ball_candidates):
1067
+ idx = 0
1068
+
1069
+ candidate = state.ball_candidates[idx]
1070
+ tracking = candidate.get("tracking", {})
1071
+
1072
+ # Copy to main YOLO state
1073
+ state.yolo_ball_centers = tracking.get("centers", {})
1074
+ state.yolo_ball_boxes = tracking.get("boxes", {})
1075
+ state.yolo_ball_conf = tracking.get("confs", {})
1076
+ state.yolo_smoothed_centers = tracking.get("smoothed_centers", {})
1077
+ state.yolo_speeds = tracking.get("speeds", {})
1078
+ state.yolo_kick_frames = tracking.get("frames_ordered", [])
1079
+ state.yolo_kick_speeds = tracking.get("speed_series", [])
1080
+ state.yolo_threshold = tracking.get("threshold")
1081
+ state.yolo_baseline_speed = tracking.get("baseline")
1082
+ state.yolo_kick_frame = tracking.get("kick_frame")
1083
+ state.yolo_initial_frame = tracking.get("frames_ordered", [None])[0] if tracking.get("frames_ordered") else None
1084
+
1085
+ # Compute areas
1086
+ areas = tracking.get("areas", {})
1087
+ frames_ordered = tracking.get("frames_ordered", [])
1088
+ state.yolo_mask_area_proxy = [areas.get(f, 0.0) for f in frames_ordered]
1089
+
1090
+ # Compute distance from start
1091
+ smoothed = tracking.get("smoothed_centers", {})
1092
+ if smoothed and frames_ordered:
1093
+ origin = smoothed.get(frames_ordered[0], (0, 0))
1094
+ distance_dict = {}
1095
+ for f, (sx, sy) in smoothed.items():
1096
+ distance_dict[f] = math.hypot(sx - origin[0], sy - origin[1])
1097
+ state.yolo_distance_from_start = distance_dict
1098
+ state.yolo_kick_distance = [distance_dict.get(f, 0.0) for f in frames_ordered]
1099
+
1100
+ # Update kick frame
1101
+ kick_frame = tracking.get("kick_frame")
1102
+ if kick_frame is not None:
1103
+ state.kick_frame = kick_frame
1104
+ _compute_sam_window_from_kick(state, kick_frame)
1105
+
1106
+ # Mark as tracked
1107
+ state.is_yolo_tracked = True
1108
+ state.ball_selection_confirmed = True
1109
+
1110
+ coverage = tracking.get("coverage", 0.0)
1111
+ if kick_frame is not None:
1112
+ state.yolo_status = f"✅ Ball {idx+1} tracked. Kick @ frame {kick_frame}."
1113
+ else:
1114
+ state.yolo_status = f"⚠️ Ball {idx+1} tracked ({coverage:.0%} coverage) but no kick detected."
1115
+
1116
+
1117
  def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
1118
  """Generate a deterministic pastel RGB color for a given object id.
1119
 
 
1308
  self.is_sam_tracked: bool = False
1309
  self.is_player_detected: bool = False
1310
  self.is_player_propagated: bool = False
1311
+
1312
+ # Multi-ball candidate tracking
1313
+ self.ball_candidates: list[dict] = [] # All detected ball candidates
1314
+ self.ball_candidates_tracking: dict[int, dict] = {} # Per-candidate tracking data
1315
+ self.selected_ball_idx: int = 0 # Currently selected candidate index
1316
+ self.ball_selection_confirmed: bool = False # True after user confirms selection
1317
+ self.multi_ball_status: str = "" # Status message for multi-ball detection
1318
+
1319
  self.goal_mode: str = GOAL_MODE_IDLE
1320
  self.goal_points_norm: list[tuple[float, float]] = []
1321
  self.goal_confirmed_points_norm: list[tuple[float, float]] = []
 
2664
  return fig
2665
 
2666
 
2667
+ def _build_multi_ball_chart(state: AppState):
2668
+ """
2669
+ Build a combined speed chart showing all ball candidates.
2670
+ The selected/kicked ball is highlighted in green, others in gray.
2671
+ """
2672
+ fig = go.Figure()
2673
+
2674
+ if state is None or not state.ball_candidates:
2675
+ fig.update_layout(
2676
+ title="Ball Candidates Speed Comparison",
2677
+ xaxis_title="Frame",
2678
+ yaxis_title="Speed (px/s)",
2679
+ )
2680
+ return fig
2681
+
2682
+ # Color palette for candidates
2683
+ colors = [
2684
+ "#4caf50", # Green (selected/kicked)
2685
+ "#9e9e9e", # Gray
2686
+ "#bdbdbd", # Light gray
2687
+ "#757575", # Dark gray
2688
+ "#e0e0e0", # Very light gray
2689
+ ]
2690
+
2691
+ selected_idx = state.selected_ball_idx
2692
+ max_speed = 0.0
2693
+ kick_frames_to_mark = []
2694
+
2695
+ for i, candidate in enumerate(state.ball_candidates):
2696
+ tracking = candidate.get("tracking", {})
2697
+ frames = tracking.get("frames_ordered", [])
2698
+ speeds = tracking.get("speed_series", [])
2699
+
2700
+ if not frames or not speeds:
2701
+ continue
2702
+
2703
+ max_speed = max(max_speed, max(speeds) if speeds else 0)
2704
+
2705
+ is_selected = (i == selected_idx)
2706
+ is_kicked = candidate.get("has_kick", False)
2707
+
2708
+ # Determine color and style
2709
+ if is_selected:
2710
+ color = "#4caf50" # Green
2711
+ width = 3
2712
+ opacity = 1.0
2713
+ elif is_kicked:
2714
+ color = "#ff9800" # Orange for other kicked balls
2715
+ width = 2
2716
+ opacity = 0.7
2717
+ else:
2718
+ color = "#9e9e9e" # Gray
2719
+ width = 1
2720
+ opacity = 0.5
2721
+
2722
+ # Build label
2723
+ label_parts = [f"Ball {i+1}"]
2724
+ if is_kicked:
2725
+ label_parts.append("⚽")
2726
+ if is_selected:
2727
+ label_parts.append("✓")
2728
+ label = " ".join(label_parts)
2729
+
2730
+ fig.add_trace(
2731
+ go.Scatter(
2732
+ x=frames,
2733
+ y=speeds,
2734
+ mode="lines",
2735
+ name=label,
2736
+ line=dict(color=color, width=width),
2737
+ opacity=opacity,
2738
+ )
2739
+ )
2740
+
2741
+ # Mark kick frame
2742
+ kick_frame = candidate.get("kick_frame")
2743
+ if kick_frame is not None:
2744
+ kick_frames_to_mark.append((kick_frame, i, is_selected))
2745
+
2746
+ # Add vertical lines for kick frames
2747
+ for kick_frame, ball_idx, is_selected in kick_frames_to_mark:
2748
+ color = "#e91e63" if is_selected else "#ffcdd2"
2749
+ width = 3 if is_selected else 1
2750
+
2751
+ fig.add_vline(
2752
+ x=kick_frame,
2753
+ line=dict(color=color, width=width, dash="solid" if is_selected else "dot"),
2754
+ annotation_text=f"Ball {ball_idx+1} kick" if is_selected else "",
2755
+ annotation_position="top right" if is_selected else None,
2756
+ )
2757
+
2758
+ fig.update_layout(
2759
+ title="Ball Candidates Speed Comparison",
2760
+ xaxis=dict(title="Frame"),
2761
+ yaxis=dict(title="Speed (px/s)", range=[0, max_speed * 1.1] if max_speed > 0 else None),
2762
+ legend=dict(orientation="h", yanchor="bottom", y=1.02),
2763
+ margin=dict(t=60, l=40, r=40, b=40),
2764
+ hovermode="x unified",
2765
+ )
2766
+
2767
+ return fig
2768
+
2769
+
2770
  def _jump_to_frame(state: AppState, target: int | None):
2771
  if state is None or state.num_frames == 0 or target is None:
2772
  return gr.update(), gr.update()
 
4222
  with gr.Row(elem_classes=["model-status"]):
4223
  yolo_kick_btn = gr.Button("⚽: N/A", interactive=False)
4224
  yolo_impact_btn = gr.Button("🚩: N/A", interactive=False)
4225
+
4226
+ # Multi-ball candidate selection UI
4227
+ with gr.Column(visible=False) as multi_ball_selection_col:
4228
+ multi_ball_status_md = gr.Markdown("", visible=True)
4229
+ ball_candidate_radio = gr.Radio(
4230
+ choices=[],
4231
+ value=None,
4232
+ label="Select Ball Candidate",
4233
+ interactive=True,
4234
+ )
4235
+ with gr.Row():
4236
+ confirm_ball_btn = gr.Button("Confirm Selection", variant="primary")
4237
+ multi_ball_chart = gr.Plot(label="Ball Candidates Speed Comparison", show_label=True)
4238
+
4239
  yolo_plot = gr.Plot(label="YOLO kick diagnostics", show_label=True)
4240
  with gr.Column(elem_classes=["model-section"]):
4241
  with gr.Row(elem_classes=["model-row"]):
 
4772
  state_in.is_ball_detected = False
4773
  frame_idx = 0
4774
  frame = state_in.video_frames[frame_idx]
4775
+
4776
+ # First, try multi-ball detection
4777
+ candidates = detect_all_balls(frame)
4778
+
4779
+ if not candidates:
4780
+ # Fallback to single-ball detection
4781
+ detection = detect_ball_center(frame)
4782
+ if detection is None:
4783
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
4784
+ status_updates = _ui_status_updates(state_in)
4785
+ return (
4786
+ update_frame_display(state_in, frame_idx),
4787
+ gr.update(
4788
+ value="❌ Unable to auto-detect the ball. Please add a point manually.",
4789
+ visible=True,
4790
+ ),
4791
+ gr.update(value=frame_idx),
4792
+ _build_kick_plot(state_in),
4793
+ propagate_main_update,
4794
+ detect_btn_update,
4795
+ propagate_player_update,
4796
+ *status_updates,
4797
+ )
4798
+ x_center, y_center, _, _, conf = detection
4799
+ else:
4800
+ # Use the best candidate (first one after sorting by confidence)
4801
+ best = candidates[0]
4802
+ x_center, y_center = best["center"]
4803
+ conf = best["conf"]
4804
+
4805
+ # Store candidates for potential multi-ball workflow
4806
+ state_in.ball_candidates = candidates
4807
+ state_in.selected_ball_idx = 0
4808
+
4809
+ if len(candidates) > 1:
4810
+ state_in.multi_ball_status = f"⚠️ {len(candidates)} balls detected. Using best candidate. Run 'Track Ball' to analyze all."
4811
+ else:
4812
+ state_in.multi_ball_status = ""
4813
 
 
4814
  frame_width, frame_height = frame.size
4815
  x_center = max(0, min(frame_width - 1, int(x_center)))
4816
  y_center = max(0, min(frame_height - 1, int(y_center)))
 
4836
  )
4837
 
4838
  state_in.is_ball_detected = True
4839
+ num_candidates = len(getattr(state_in, 'ball_candidates', []))
4840
+ if num_candidates > 1:
4841
+ status_text = f"⚠️ {num_candidates} balls found! Best at ({x_center}, {y_center}) (conf={conf:.2f})"
4842
+ else:
4843
+ status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
4844
  status_text += f" | {_format_kick_status(state_in)}"
4845
  propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
4846
  status_updates = _ui_status_updates(state_in)
 
4887
  raise gr.Error("Load a video first, then track the ball with YOLO.")
4888
  progress = gr.Progress(track_tqdm=False)
4889
  state_in.is_yolo_tracked = False
4890
+
4891
+ # Check if we have multiple ball candidates
4892
+ num_candidates = len(getattr(state_in, 'ball_candidates', []))
4893
+
4894
+ if num_candidates > 1:
4895
+ # Multi-ball mode: track all candidates and show comparison
4896
+ _detect_and_track_all_ball_candidates(state_in, progress=progress)
4897
+
4898
+ # Apply the best candidate to YOLO state
4899
+ _apply_selected_ball_to_yolo_state(state_in)
4900
+
4901
+ base_msg = state_in.multi_ball_status or state_in.yolo_status or ""
4902
+ else:
4903
+ # Single ball mode: use original tracking
4904
+ _perform_yolo_ball_tracking(state_in, progress=progress)
4905
+ base_msg = state_in.yolo_status or ""
4906
+
4907
  target_frame = (
4908
  state_in.yolo_kick_frame
4909
  if state_in.yolo_kick_frame is not None
 
4915
  target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
4916
  state_in.current_frame_idx = target_frame
4917
  preview_img = update_frame_display(state_in, target_frame)
 
4918
  kick_msg = _format_kick_status(state_in)
4919
  status_text = f"{base_msg} | {kick_msg}" if base_msg else kick_msg
4920
  state_in.is_yolo_tracked = True