Mirko Trasciatti
commited on
Commit
·
8df8982
1
Parent(s):
0653036
Add multi-ball candidate detection and tracking
Browse files- detect_all_balls(): Detect multiple balls in first frame with ROI filtering
- _track_single_ball_candidate(): Track each candidate using proximity matching
- _detect_and_track_all_ball_candidates(): Score candidates by kick detection, velocity, position
- _build_multi_ball_chart(): Combined speed chart comparing all candidates
- _apply_selected_ball_to_yolo_state(): Copy selected candidate to main YOLO state
- Updated _auto_detect_ball to detect multiple candidates
- Updated _track_ball_yolo to use multi-ball tracking when multiple candidates found
- Added ball_candidates, selected_ball_idx, ball_selection_confirmed to AppState
- Added multi-ball selection UI components (radio buttons, chart, confirm button)
app.py
CHANGED
|
@@ -132,6 +132,96 @@ def detect_ball_center(
|
|
| 132 |
)
|
| 133 |
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
def detect_person_box(
|
| 136 |
frame: Image.Image,
|
| 137 |
model_filename: str = YOLO_DEFAULT_MODEL,
|
|
@@ -683,6 +773,347 @@ def _perform_yolo_ball_tracking(state: AppState, progress: gr.Progress | None =
|
|
| 683 |
state.sam_window = None
|
| 684 |
|
| 685 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 686 |
def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
|
| 687 |
"""Generate a deterministic pastel RGB color for a given object id.
|
| 688 |
|
|
@@ -877,6 +1308,14 @@ class AppState:
|
|
| 877 |
self.is_sam_tracked: bool = False
|
| 878 |
self.is_player_detected: bool = False
|
| 879 |
self.is_player_propagated: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 880 |
self.goal_mode: str = GOAL_MODE_IDLE
|
| 881 |
self.goal_points_norm: list[tuple[float, float]] = []
|
| 882 |
self.goal_confirmed_points_norm: list[tuple[float, float]] = []
|
|
@@ -2225,6 +2664,109 @@ def _build_yolo_plot(state: AppState):
|
|
| 2225 |
return fig
|
| 2226 |
|
| 2227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2228 |
def _jump_to_frame(state: AppState, target: int | None):
|
| 2229 |
if state is None or state.num_frames == 0 or target is None:
|
| 2230 |
return gr.update(), gr.update()
|
|
@@ -3680,6 +4222,20 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 3680 |
with gr.Row(elem_classes=["model-status"]):
|
| 3681 |
yolo_kick_btn = gr.Button("⚽: N/A", interactive=False)
|
| 3682 |
yolo_impact_btn = gr.Button("🚩: N/A", interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3683 |
yolo_plot = gr.Plot(label="YOLO kick diagnostics", show_label=True)
|
| 3684 |
with gr.Column(elem_classes=["model-section"]):
|
| 3685 |
with gr.Row(elem_classes=["model-row"]):
|
|
@@ -4216,25 +4772,45 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 4216 |
state_in.is_ball_detected = False
|
| 4217 |
frame_idx = 0
|
| 4218 |
frame = state_in.video_frames[frame_idx]
|
| 4219 |
-
|
| 4220 |
-
|
| 4221 |
-
|
| 4222 |
-
|
| 4223 |
-
|
| 4224 |
-
|
| 4225 |
-
|
| 4226 |
-
|
| 4227 |
-
|
| 4228 |
-
)
|
| 4229 |
-
|
| 4230 |
-
|
| 4231 |
-
|
| 4232 |
-
|
| 4233 |
-
|
| 4234 |
-
|
| 4235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4236 |
|
| 4237 |
-
x_center, y_center, _, _, conf = detection
|
| 4238 |
frame_width, frame_height = frame.size
|
| 4239 |
x_center = max(0, min(frame_width - 1, int(x_center)))
|
| 4240 |
y_center = max(0, min(frame_height - 1, int(y_center)))
|
|
@@ -4260,7 +4836,11 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 4260 |
)
|
| 4261 |
|
| 4262 |
state_in.is_ball_detected = True
|
| 4263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4264 |
status_text += f" | {_format_kick_status(state_in)}"
|
| 4265 |
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
|
| 4266 |
status_updates = _ui_status_updates(state_in)
|
|
@@ -4307,7 +4887,23 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 4307 |
raise gr.Error("Load a video first, then track the ball with YOLO.")
|
| 4308 |
progress = gr.Progress(track_tqdm=False)
|
| 4309 |
state_in.is_yolo_tracked = False
|
| 4310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4311 |
target_frame = (
|
| 4312 |
state_in.yolo_kick_frame
|
| 4313 |
if state_in.yolo_kick_frame is not None
|
|
@@ -4319,7 +4915,6 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 4319 |
target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
|
| 4320 |
state_in.current_frame_idx = target_frame
|
| 4321 |
preview_img = update_frame_display(state_in, target_frame)
|
| 4322 |
-
base_msg = state_in.yolo_status or ""
|
| 4323 |
kick_msg = _format_kick_status(state_in)
|
| 4324 |
status_text = f"{base_msg} | {kick_msg}" if base_msg else kick_msg
|
| 4325 |
state_in.is_yolo_tracked = True
|
|
|
|
| 132 |
)
|
| 133 |
|
| 134 |
|
| 135 |
+
def detect_all_balls(
|
| 136 |
+
frame: Image.Image,
|
| 137 |
+
model_filename: str = YOLO_DEFAULT_MODEL,
|
| 138 |
+
conf_threshold: float = 0.1, # Lower threshold to catch more candidates
|
| 139 |
+
iou_threshold: float = YOLO_IOU_THRESHOLD,
|
| 140 |
+
max_detections: int = 5,
|
| 141 |
+
roi_x_min: float = 0.20, # Filter: ball must be within 20-80% horizontal
|
| 142 |
+
roi_x_max: float = 0.80,
|
| 143 |
+
) -> list[dict]:
|
| 144 |
+
"""
|
| 145 |
+
Detect all ball candidates in a frame, filtered by ROI.
|
| 146 |
+
|
| 147 |
+
Returns list of dicts with keys:
|
| 148 |
+
- id: int (candidate index)
|
| 149 |
+
- center: (x, y) tuple
|
| 150 |
+
- box: (x_min, y_min, x_max, y_max) tuple
|
| 151 |
+
- width: float
|
| 152 |
+
- height: float
|
| 153 |
+
- conf: float (YOLO confidence)
|
| 154 |
+
- x_ratio: float (horizontal position as fraction of frame width)
|
| 155 |
+
"""
|
| 156 |
+
model = get_yolo_model(model_filename)
|
| 157 |
+
class_ids = [
|
| 158 |
+
idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
|
| 159 |
+
]
|
| 160 |
+
if not class_ids:
|
| 161 |
+
return []
|
| 162 |
+
|
| 163 |
+
results = model.predict(
|
| 164 |
+
source=frame,
|
| 165 |
+
conf=conf_threshold,
|
| 166 |
+
iou=iou_threshold,
|
| 167 |
+
max_det=max_detections,
|
| 168 |
+
classes=class_ids,
|
| 169 |
+
imgsz=640,
|
| 170 |
+
device="cpu",
|
| 171 |
+
verbose=False,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
if not results:
|
| 175 |
+
return []
|
| 176 |
+
|
| 177 |
+
boxes = results[0].boxes
|
| 178 |
+
if boxes is None or len(boxes) == 0:
|
| 179 |
+
return []
|
| 180 |
+
|
| 181 |
+
frame_width, frame_height = frame.size
|
| 182 |
+
candidates = []
|
| 183 |
+
|
| 184 |
+
for i, box in enumerate(boxes):
|
| 185 |
+
xywh = box.xywh[0].cpu().tolist()
|
| 186 |
+
conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0
|
| 187 |
+
x_center, y_center, width, height = xywh
|
| 188 |
+
|
| 189 |
+
# Compute bounding box
|
| 190 |
+
x_min = int(round(max(0.0, x_center - width / 2.0)))
|
| 191 |
+
y_min = int(round(max(0.0, y_center - height / 2.0)))
|
| 192 |
+
x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0)))
|
| 193 |
+
y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0)))
|
| 194 |
+
|
| 195 |
+
if x_max <= x_min or y_max <= y_min:
|
| 196 |
+
continue
|
| 197 |
+
|
| 198 |
+
# Compute horizontal position ratio
|
| 199 |
+
x_ratio = x_center / frame_width
|
| 200 |
+
|
| 201 |
+
# Filter by ROI (horizontal position)
|
| 202 |
+
if not (roi_x_min <= x_ratio <= roi_x_max):
|
| 203 |
+
continue
|
| 204 |
+
|
| 205 |
+
candidates.append({
|
| 206 |
+
"id": len(candidates),
|
| 207 |
+
"center": (float(x_center), float(y_center)),
|
| 208 |
+
"box": (x_min, y_min, x_max, y_max),
|
| 209 |
+
"width": float(width),
|
| 210 |
+
"height": float(height),
|
| 211 |
+
"conf": conf,
|
| 212 |
+
"x_ratio": x_ratio,
|
| 213 |
+
})
|
| 214 |
+
|
| 215 |
+
# Sort by confidence descending
|
| 216 |
+
candidates.sort(key=lambda c: c["conf"], reverse=True)
|
| 217 |
+
|
| 218 |
+
# Re-assign IDs after sorting
|
| 219 |
+
for i, c in enumerate(candidates):
|
| 220 |
+
c["id"] = i
|
| 221 |
+
|
| 222 |
+
return candidates
|
| 223 |
+
|
| 224 |
+
|
| 225 |
def detect_person_box(
|
| 226 |
frame: Image.Image,
|
| 227 |
model_filename: str = YOLO_DEFAULT_MODEL,
|
|
|
|
| 773 |
state.sam_window = None
|
| 774 |
|
| 775 |
|
| 776 |
+
def _track_single_ball_candidate(
|
| 777 |
+
state: AppState,
|
| 778 |
+
candidate: dict,
|
| 779 |
+
progress: gr.Progress | None = None,
|
| 780 |
+
) -> dict:
|
| 781 |
+
"""
|
| 782 |
+
Track a single ball candidate across all frames using YOLO.
|
| 783 |
+
Uses proximity matching to follow the same ball.
|
| 784 |
+
|
| 785 |
+
Returns dict with tracking results:
|
| 786 |
+
- centers: dict[frame_idx, (x, y)]
|
| 787 |
+
- speeds: dict[frame_idx, speed]
|
| 788 |
+
- kick_frame: int | None
|
| 789 |
+
- max_velocity: float
|
| 790 |
+
- has_kick: bool
|
| 791 |
+
- coverage: float (fraction of frames with detection)
|
| 792 |
+
"""
|
| 793 |
+
model = get_yolo_model()
|
| 794 |
+
class_ids = [
|
| 795 |
+
idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
|
| 796 |
+
]
|
| 797 |
+
|
| 798 |
+
frames = state.video_frames
|
| 799 |
+
total = len(frames)
|
| 800 |
+
|
| 801 |
+
# Initial position from candidate
|
| 802 |
+
last_center = candidate["center"]
|
| 803 |
+
max_distance_threshold = 100 # Max pixels to consider same ball
|
| 804 |
+
|
| 805 |
+
centers: dict[int, tuple[float, float]] = {}
|
| 806 |
+
boxes: dict[int, tuple[int, int, int, int]] = {}
|
| 807 |
+
confs: dict[int, float] = {}
|
| 808 |
+
areas: dict[int, float] = {}
|
| 809 |
+
|
| 810 |
+
for idx, frame in enumerate(frames):
|
| 811 |
+
if progress is not None:
|
| 812 |
+
progress((idx + 1) / total)
|
| 813 |
+
|
| 814 |
+
results = model.predict(
|
| 815 |
+
source=frame,
|
| 816 |
+
conf=0.05, # Lower threshold to catch more
|
| 817 |
+
iou=YOLO_IOU_THRESHOLD,
|
| 818 |
+
max_det=10, # Allow multiple detections
|
| 819 |
+
classes=class_ids,
|
| 820 |
+
imgsz=640,
|
| 821 |
+
device="cpu",
|
| 822 |
+
verbose=False,
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
if not results:
|
| 826 |
+
continue
|
| 827 |
+
|
| 828 |
+
boxes_result = results[0].boxes
|
| 829 |
+
if boxes_result is None or len(boxes_result) == 0:
|
| 830 |
+
continue
|
| 831 |
+
|
| 832 |
+
# Find the detection closest to last known position
|
| 833 |
+
best_box = None
|
| 834 |
+
best_distance = float("inf")
|
| 835 |
+
|
| 836 |
+
for box in boxes_result:
|
| 837 |
+
xywh = box.xywh[0].cpu().tolist()
|
| 838 |
+
x_center, y_center = xywh[0], xywh[1]
|
| 839 |
+
dist = math.hypot(x_center - last_center[0], y_center - last_center[1])
|
| 840 |
+
|
| 841 |
+
if dist < best_distance and dist < max_distance_threshold:
|
| 842 |
+
best_distance = dist
|
| 843 |
+
best_box = box
|
| 844 |
+
|
| 845 |
+
if best_box is None:
|
| 846 |
+
continue
|
| 847 |
+
|
| 848 |
+
xywh = best_box.xywh[0].cpu().tolist()
|
| 849 |
+
conf = float(best_box.conf[0].cpu().item()) if best_box.conf is not None else 0.0
|
| 850 |
+
x_center, y_center, width, height = xywh
|
| 851 |
+
x_center = float(x_center)
|
| 852 |
+
y_center = float(y_center)
|
| 853 |
+
width = max(1.0, float(width))
|
| 854 |
+
height = max(1.0, float(height))
|
| 855 |
+
|
| 856 |
+
frame_width, frame_height = frame.size
|
| 857 |
+
x_min = int(round(max(0.0, x_center - width / 2.0)))
|
| 858 |
+
y_min = int(round(max(0.0, y_center - height / 2.0)))
|
| 859 |
+
x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0)))
|
| 860 |
+
y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0)))
|
| 861 |
+
|
| 862 |
+
if x_max <= x_min or y_max <= y_min:
|
| 863 |
+
continue
|
| 864 |
+
|
| 865 |
+
centers[idx] = (x_center, y_center)
|
| 866 |
+
boxes[idx] = (x_min, y_min, x_max, y_max)
|
| 867 |
+
confs[idx] = conf
|
| 868 |
+
areas[idx] = float((x_max - x_min) * (y_max - y_min))
|
| 869 |
+
last_center = (x_center, y_center)
|
| 870 |
+
|
| 871 |
+
# Compute speeds
|
| 872 |
+
if len(centers) < 3:
|
| 873 |
+
return {
|
| 874 |
+
"centers": centers,
|
| 875 |
+
"boxes": boxes,
|
| 876 |
+
"confs": confs,
|
| 877 |
+
"areas": areas,
|
| 878 |
+
"speeds": {},
|
| 879 |
+
"smoothed_centers": {},
|
| 880 |
+
"frames_ordered": [],
|
| 881 |
+
"speed_series": [],
|
| 882 |
+
"kick_frame": None,
|
| 883 |
+
"max_velocity": 0.0,
|
| 884 |
+
"has_kick": False,
|
| 885 |
+
"coverage": len(centers) / total if total else 0.0,
|
| 886 |
+
}
|
| 887 |
+
|
| 888 |
+
items = sorted(centers.items())
|
| 889 |
+
dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
|
| 890 |
+
alpha = 0.35
|
| 891 |
+
|
| 892 |
+
smoothed: dict[int, tuple[float, float]] = {}
|
| 893 |
+
speeds: dict[int, float] = {}
|
| 894 |
+
|
| 895 |
+
prev_frame = None
|
| 896 |
+
prev_smooth = None
|
| 897 |
+
for frame_idx, (cx, cy) in items:
|
| 898 |
+
if prev_smooth is None:
|
| 899 |
+
smooth_x, smooth_y = float(cx), float(cy)
|
| 900 |
+
else:
|
| 901 |
+
smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
|
| 902 |
+
smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
|
| 903 |
+
smoothed[frame_idx] = (smooth_x, smooth_y)
|
| 904 |
+
if prev_smooth is None or prev_frame is None:
|
| 905 |
+
speeds[frame_idx] = 0.0
|
| 906 |
+
else:
|
| 907 |
+
frame_delta = max(1, frame_idx - prev_frame)
|
| 908 |
+
time_delta = frame_delta * dt
|
| 909 |
+
dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
|
| 910 |
+
speed = dist / time_delta if time_delta > 0 else dist
|
| 911 |
+
speeds[frame_idx] = speed
|
| 912 |
+
prev_smooth = (smooth_x, smooth_y)
|
| 913 |
+
prev_frame = frame_idx
|
| 914 |
+
|
| 915 |
+
frames_ordered = [frame_idx for frame_idx, _ in items]
|
| 916 |
+
speed_series = [speeds.get(f, 0.0) for f in frames_ordered]
|
| 917 |
+
|
| 918 |
+
# Detect kick (velocity spike)
|
| 919 |
+
baseline_window = min(10, len(frames_ordered) // 3 or 1)
|
| 920 |
+
baseline_speeds = speed_series[:baseline_window]
|
| 921 |
+
baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
|
| 922 |
+
speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0
|
| 923 |
+
base_threshold = baseline_speed + 4.0 * speed_std
|
| 924 |
+
if base_threshold < baseline_speed * 3.0:
|
| 925 |
+
base_threshold = baseline_speed * 3.0
|
| 926 |
+
speed_threshold = max(base_threshold, 15.0)
|
| 927 |
+
|
| 928 |
+
kick_frame: int | None = None
|
| 929 |
+
max_velocity = max(speed_series) if speed_series else 0.0
|
| 930 |
+
|
| 931 |
+
for idx, frame in enumerate(frames_ordered[baseline_window:], start=baseline_window):
|
| 932 |
+
speed = speed_series[idx]
|
| 933 |
+
if speed < speed_threshold:
|
| 934 |
+
continue
|
| 935 |
+
# Check sustain
|
| 936 |
+
sustain_ok = True
|
| 937 |
+
for j in range(1, 4):
|
| 938 |
+
if idx + j >= len(frames_ordered):
|
| 939 |
+
break
|
| 940 |
+
if speed_series[idx + j] < speed_threshold * 0.7:
|
| 941 |
+
sustain_ok = False
|
| 942 |
+
break
|
| 943 |
+
if sustain_ok:
|
| 944 |
+
kick_frame = frame
|
| 945 |
+
break
|
| 946 |
+
|
| 947 |
+
return {
|
| 948 |
+
"centers": centers,
|
| 949 |
+
"boxes": boxes,
|
| 950 |
+
"confs": confs,
|
| 951 |
+
"areas": areas,
|
| 952 |
+
"speeds": speeds,
|
| 953 |
+
"smoothed_centers": smoothed,
|
| 954 |
+
"frames_ordered": frames_ordered,
|
| 955 |
+
"speed_series": speed_series,
|
| 956 |
+
"threshold": speed_threshold,
|
| 957 |
+
"baseline": baseline_speed,
|
| 958 |
+
"kick_frame": kick_frame,
|
| 959 |
+
"max_velocity": max_velocity,
|
| 960 |
+
"has_kick": kick_frame is not None,
|
| 961 |
+
"coverage": len(centers) / total if total else 0.0,
|
| 962 |
+
}
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
def _detect_and_track_all_ball_candidates(
|
| 966 |
+
state: AppState,
|
| 967 |
+
progress: gr.Progress | None = None,
|
| 968 |
+
) -> None:
|
| 969 |
+
"""
|
| 970 |
+
Detect all ball candidates in first frame, track each with YOLO,
|
| 971 |
+
score them, and auto-select the best candidate.
|
| 972 |
+
"""
|
| 973 |
+
if state is None or state.num_frames == 0:
|
| 974 |
+
raise gr.Error("Load a video first.")
|
| 975 |
+
|
| 976 |
+
first_frame = state.video_frames[0]
|
| 977 |
+
frame_width, frame_height = first_frame.size
|
| 978 |
+
|
| 979 |
+
# Step 1: Detect all balls in first frame
|
| 980 |
+
candidates = detect_all_balls(first_frame)
|
| 981 |
+
|
| 982 |
+
if not candidates:
|
| 983 |
+
state.ball_candidates = []
|
| 984 |
+
state.multi_ball_status = "❌ No ball candidates detected in first frame."
|
| 985 |
+
return
|
| 986 |
+
|
| 987 |
+
state.multi_ball_status = f"🔍 Found {len(candidates)} ball candidate(s). Tracking..."
|
| 988 |
+
|
| 989 |
+
# Step 2: Track each candidate
|
| 990 |
+
tracking_results: dict[int, dict] = {}
|
| 991 |
+
|
| 992 |
+
for i, candidate in enumerate(candidates):
|
| 993 |
+
if progress is not None:
|
| 994 |
+
progress((i + 1) / len(candidates), desc=f"Tracking ball {i+1}/{len(candidates)}")
|
| 995 |
+
|
| 996 |
+
result = _track_single_ball_candidate(state, candidate, progress=None)
|
| 997 |
+
tracking_results[candidate["id"]] = result
|
| 998 |
+
|
| 999 |
+
# Add tracking summary to candidate
|
| 1000 |
+
candidate["tracking"] = result
|
| 1001 |
+
candidate["has_kick"] = result["has_kick"]
|
| 1002 |
+
candidate["kick_frame"] = result["kick_frame"]
|
| 1003 |
+
candidate["max_velocity"] = result["max_velocity"]
|
| 1004 |
+
candidate["coverage"] = result["coverage"]
|
| 1005 |
+
|
| 1006 |
+
# Step 3: Score candidates
|
| 1007 |
+
frame_center_x = frame_width / 2
|
| 1008 |
+
|
| 1009 |
+
for candidate in candidates:
|
| 1010 |
+
score = 0.0
|
| 1011 |
+
|
| 1012 |
+
# 1. Has a detected kick (velocity spike) — most important
|
| 1013 |
+
if candidate["has_kick"]:
|
| 1014 |
+
score += 50
|
| 1015 |
+
|
| 1016 |
+
# 2. Higher max velocity — ball that moves most
|
| 1017 |
+
score += min(30, candidate["max_velocity"] / 10)
|
| 1018 |
+
|
| 1019 |
+
# 3. Centered horizontally
|
| 1020 |
+
x_offset = abs(candidate["center"][0] - frame_center_x) / frame_center_x
|
| 1021 |
+
score += 20 * (1 - x_offset)
|
| 1022 |
+
|
| 1023 |
+
# 4. YOLO confidence as tiebreaker
|
| 1024 |
+
score += candidate["conf"] * 10
|
| 1025 |
+
|
| 1026 |
+
# 5. Better coverage
|
| 1027 |
+
score += candidate["coverage"] * 10
|
| 1028 |
+
|
| 1029 |
+
candidate["score"] = score
|
| 1030 |
+
|
| 1031 |
+
# Sort by score descending
|
| 1032 |
+
candidates.sort(key=lambda c: c["score"], reverse=True)
|
| 1033 |
+
|
| 1034 |
+
# Re-assign IDs after sorting
|
| 1035 |
+
for i, c in enumerate(candidates):
|
| 1036 |
+
c["id"] = i
|
| 1037 |
+
|
| 1038 |
+
state.ball_candidates = candidates
|
| 1039 |
+
state.ball_candidates_tracking = tracking_results
|
| 1040 |
+
state.selected_ball_idx = 0 # Auto-select best candidate
|
| 1041 |
+
state.ball_selection_confirmed = False
|
| 1042 |
+
|
| 1043 |
+
# Build status message
|
| 1044 |
+
if len(candidates) == 1:
|
| 1045 |
+
c = candidates[0]
|
| 1046 |
+
kick_info = f"Kick @ frame {c['kick_frame']}" if c["has_kick"] else "No kick detected"
|
| 1047 |
+
state.multi_ball_status = f"✅ 1 ball detected. {kick_info}."
|
| 1048 |
+
else:
|
| 1049 |
+
kicked_count = sum(1 for c in candidates if c["has_kick"])
|
| 1050 |
+
state.multi_ball_status = (
|
| 1051 |
+
f"⚠️ {len(candidates)} balls detected. "
|
| 1052 |
+
f"{kicked_count} show movement. "
|
| 1053 |
+
f"Best candidate auto-selected. Please confirm or change selection."
|
| 1054 |
+
)
|
| 1055 |
+
|
| 1056 |
+
|
| 1057 |
+
def _apply_selected_ball_to_yolo_state(state: AppState) -> None:
|
| 1058 |
+
"""
|
| 1059 |
+
Copy the selected ball candidate's tracking data to the main YOLO state.
|
| 1060 |
+
This allows the rest of the pipeline to work unchanged.
|
| 1061 |
+
"""
|
| 1062 |
+
if not state.ball_candidates:
|
| 1063 |
+
return
|
| 1064 |
+
|
| 1065 |
+
idx = state.selected_ball_idx
|
| 1066 |
+
if idx < 0 or idx >= len(state.ball_candidates):
|
| 1067 |
+
idx = 0
|
| 1068 |
+
|
| 1069 |
+
candidate = state.ball_candidates[idx]
|
| 1070 |
+
tracking = candidate.get("tracking", {})
|
| 1071 |
+
|
| 1072 |
+
# Copy to main YOLO state
|
| 1073 |
+
state.yolo_ball_centers = tracking.get("centers", {})
|
| 1074 |
+
state.yolo_ball_boxes = tracking.get("boxes", {})
|
| 1075 |
+
state.yolo_ball_conf = tracking.get("confs", {})
|
| 1076 |
+
state.yolo_smoothed_centers = tracking.get("smoothed_centers", {})
|
| 1077 |
+
state.yolo_speeds = tracking.get("speeds", {})
|
| 1078 |
+
state.yolo_kick_frames = tracking.get("frames_ordered", [])
|
| 1079 |
+
state.yolo_kick_speeds = tracking.get("speed_series", [])
|
| 1080 |
+
state.yolo_threshold = tracking.get("threshold")
|
| 1081 |
+
state.yolo_baseline_speed = tracking.get("baseline")
|
| 1082 |
+
state.yolo_kick_frame = tracking.get("kick_frame")
|
| 1083 |
+
state.yolo_initial_frame = tracking.get("frames_ordered", [None])[0] if tracking.get("frames_ordered") else None
|
| 1084 |
+
|
| 1085 |
+
# Compute areas
|
| 1086 |
+
areas = tracking.get("areas", {})
|
| 1087 |
+
frames_ordered = tracking.get("frames_ordered", [])
|
| 1088 |
+
state.yolo_mask_area_proxy = [areas.get(f, 0.0) for f in frames_ordered]
|
| 1089 |
+
|
| 1090 |
+
# Compute distance from start
|
| 1091 |
+
smoothed = tracking.get("smoothed_centers", {})
|
| 1092 |
+
if smoothed and frames_ordered:
|
| 1093 |
+
origin = smoothed.get(frames_ordered[0], (0, 0))
|
| 1094 |
+
distance_dict = {}
|
| 1095 |
+
for f, (sx, sy) in smoothed.items():
|
| 1096 |
+
distance_dict[f] = math.hypot(sx - origin[0], sy - origin[1])
|
| 1097 |
+
state.yolo_distance_from_start = distance_dict
|
| 1098 |
+
state.yolo_kick_distance = [distance_dict.get(f, 0.0) for f in frames_ordered]
|
| 1099 |
+
|
| 1100 |
+
# Update kick frame
|
| 1101 |
+
kick_frame = tracking.get("kick_frame")
|
| 1102 |
+
if kick_frame is not None:
|
| 1103 |
+
state.kick_frame = kick_frame
|
| 1104 |
+
_compute_sam_window_from_kick(state, kick_frame)
|
| 1105 |
+
|
| 1106 |
+
# Mark as tracked
|
| 1107 |
+
state.is_yolo_tracked = True
|
| 1108 |
+
state.ball_selection_confirmed = True
|
| 1109 |
+
|
| 1110 |
+
coverage = tracking.get("coverage", 0.0)
|
| 1111 |
+
if kick_frame is not None:
|
| 1112 |
+
state.yolo_status = f"✅ Ball {idx+1} tracked. Kick @ frame {kick_frame}."
|
| 1113 |
+
else:
|
| 1114 |
+
state.yolo_status = f"⚠️ Ball {idx+1} tracked ({coverage:.0%} coverage) but no kick detected."
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
|
| 1118 |
"""Generate a deterministic pastel RGB color for a given object id.
|
| 1119 |
|
|
|
|
| 1308 |
self.is_sam_tracked: bool = False
|
| 1309 |
self.is_player_detected: bool = False
|
| 1310 |
self.is_player_propagated: bool = False
|
| 1311 |
+
|
| 1312 |
+
# Multi-ball candidate tracking
|
| 1313 |
+
self.ball_candidates: list[dict] = [] # All detected ball candidates
|
| 1314 |
+
self.ball_candidates_tracking: dict[int, dict] = {} # Per-candidate tracking data
|
| 1315 |
+
self.selected_ball_idx: int = 0 # Currently selected candidate index
|
| 1316 |
+
self.ball_selection_confirmed: bool = False # True after user confirms selection
|
| 1317 |
+
self.multi_ball_status: str = "" # Status message for multi-ball detection
|
| 1318 |
+
|
| 1319 |
self.goal_mode: str = GOAL_MODE_IDLE
|
| 1320 |
self.goal_points_norm: list[tuple[float, float]] = []
|
| 1321 |
self.goal_confirmed_points_norm: list[tuple[float, float]] = []
|
|
|
|
| 2664 |
return fig
|
| 2665 |
|
| 2666 |
|
| 2667 |
+
def _build_multi_ball_chart(state: AppState):
|
| 2668 |
+
"""
|
| 2669 |
+
Build a combined speed chart showing all ball candidates.
|
| 2670 |
+
The selected/kicked ball is highlighted in green, others in gray.
|
| 2671 |
+
"""
|
| 2672 |
+
fig = go.Figure()
|
| 2673 |
+
|
| 2674 |
+
if state is None or not state.ball_candidates:
|
| 2675 |
+
fig.update_layout(
|
| 2676 |
+
title="Ball Candidates Speed Comparison",
|
| 2677 |
+
xaxis_title="Frame",
|
| 2678 |
+
yaxis_title="Speed (px/s)",
|
| 2679 |
+
)
|
| 2680 |
+
return fig
|
| 2681 |
+
|
| 2682 |
+
# Color palette for candidates
|
| 2683 |
+
colors = [
|
| 2684 |
+
"#4caf50", # Green (selected/kicked)
|
| 2685 |
+
"#9e9e9e", # Gray
|
| 2686 |
+
"#bdbdbd", # Light gray
|
| 2687 |
+
"#757575", # Dark gray
|
| 2688 |
+
"#e0e0e0", # Very light gray
|
| 2689 |
+
]
|
| 2690 |
+
|
| 2691 |
+
selected_idx = state.selected_ball_idx
|
| 2692 |
+
max_speed = 0.0
|
| 2693 |
+
kick_frames_to_mark = []
|
| 2694 |
+
|
| 2695 |
+
for i, candidate in enumerate(state.ball_candidates):
|
| 2696 |
+
tracking = candidate.get("tracking", {})
|
| 2697 |
+
frames = tracking.get("frames_ordered", [])
|
| 2698 |
+
speeds = tracking.get("speed_series", [])
|
| 2699 |
+
|
| 2700 |
+
if not frames or not speeds:
|
| 2701 |
+
continue
|
| 2702 |
+
|
| 2703 |
+
max_speed = max(max_speed, max(speeds) if speeds else 0)
|
| 2704 |
+
|
| 2705 |
+
is_selected = (i == selected_idx)
|
| 2706 |
+
is_kicked = candidate.get("has_kick", False)
|
| 2707 |
+
|
| 2708 |
+
# Determine color and style
|
| 2709 |
+
if is_selected:
|
| 2710 |
+
color = "#4caf50" # Green
|
| 2711 |
+
width = 3
|
| 2712 |
+
opacity = 1.0
|
| 2713 |
+
elif is_kicked:
|
| 2714 |
+
color = "#ff9800" # Orange for other kicked balls
|
| 2715 |
+
width = 2
|
| 2716 |
+
opacity = 0.7
|
| 2717 |
+
else:
|
| 2718 |
+
color = "#9e9e9e" # Gray
|
| 2719 |
+
width = 1
|
| 2720 |
+
opacity = 0.5
|
| 2721 |
+
|
| 2722 |
+
# Build label
|
| 2723 |
+
label_parts = [f"Ball {i+1}"]
|
| 2724 |
+
if is_kicked:
|
| 2725 |
+
label_parts.append("⚽")
|
| 2726 |
+
if is_selected:
|
| 2727 |
+
label_parts.append("✓")
|
| 2728 |
+
label = " ".join(label_parts)
|
| 2729 |
+
|
| 2730 |
+
fig.add_trace(
|
| 2731 |
+
go.Scatter(
|
| 2732 |
+
x=frames,
|
| 2733 |
+
y=speeds,
|
| 2734 |
+
mode="lines",
|
| 2735 |
+
name=label,
|
| 2736 |
+
line=dict(color=color, width=width),
|
| 2737 |
+
opacity=opacity,
|
| 2738 |
+
)
|
| 2739 |
+
)
|
| 2740 |
+
|
| 2741 |
+
# Mark kick frame
|
| 2742 |
+
kick_frame = candidate.get("kick_frame")
|
| 2743 |
+
if kick_frame is not None:
|
| 2744 |
+
kick_frames_to_mark.append((kick_frame, i, is_selected))
|
| 2745 |
+
|
| 2746 |
+
# Add vertical lines for kick frames
|
| 2747 |
+
for kick_frame, ball_idx, is_selected in kick_frames_to_mark:
|
| 2748 |
+
color = "#e91e63" if is_selected else "#ffcdd2"
|
| 2749 |
+
width = 3 if is_selected else 1
|
| 2750 |
+
|
| 2751 |
+
fig.add_vline(
|
| 2752 |
+
x=kick_frame,
|
| 2753 |
+
line=dict(color=color, width=width, dash="solid" if is_selected else "dot"),
|
| 2754 |
+
annotation_text=f"Ball {ball_idx+1} kick" if is_selected else "",
|
| 2755 |
+
annotation_position="top right" if is_selected else None,
|
| 2756 |
+
)
|
| 2757 |
+
|
| 2758 |
+
fig.update_layout(
|
| 2759 |
+
title="Ball Candidates Speed Comparison",
|
| 2760 |
+
xaxis=dict(title="Frame"),
|
| 2761 |
+
yaxis=dict(title="Speed (px/s)", range=[0, max_speed * 1.1] if max_speed > 0 else None),
|
| 2762 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02),
|
| 2763 |
+
margin=dict(t=60, l=40, r=40, b=40),
|
| 2764 |
+
hovermode="x unified",
|
| 2765 |
+
)
|
| 2766 |
+
|
| 2767 |
+
return fig
|
| 2768 |
+
|
| 2769 |
+
|
| 2770 |
def _jump_to_frame(state: AppState, target: int | None):
|
| 2771 |
if state is None or state.num_frames == 0 or target is None:
|
| 2772 |
return gr.update(), gr.update()
|
|
|
|
| 4222 |
with gr.Row(elem_classes=["model-status"]):
|
| 4223 |
yolo_kick_btn = gr.Button("⚽: N/A", interactive=False)
|
| 4224 |
yolo_impact_btn = gr.Button("🚩: N/A", interactive=False)
|
| 4225 |
+
|
| 4226 |
+
# Multi-ball candidate selection UI
|
| 4227 |
+
with gr.Column(visible=False) as multi_ball_selection_col:
|
| 4228 |
+
multi_ball_status_md = gr.Markdown("", visible=True)
|
| 4229 |
+
ball_candidate_radio = gr.Radio(
|
| 4230 |
+
choices=[],
|
| 4231 |
+
value=None,
|
| 4232 |
+
label="Select Ball Candidate",
|
| 4233 |
+
interactive=True,
|
| 4234 |
+
)
|
| 4235 |
+
with gr.Row():
|
| 4236 |
+
confirm_ball_btn = gr.Button("Confirm Selection", variant="primary")
|
| 4237 |
+
multi_ball_chart = gr.Plot(label="Ball Candidates Speed Comparison", show_label=True)
|
| 4238 |
+
|
| 4239 |
yolo_plot = gr.Plot(label="YOLO kick diagnostics", show_label=True)
|
| 4240 |
with gr.Column(elem_classes=["model-section"]):
|
| 4241 |
with gr.Row(elem_classes=["model-row"]):
|
|
|
|
| 4772 |
state_in.is_ball_detected = False
|
| 4773 |
frame_idx = 0
|
| 4774 |
frame = state_in.video_frames[frame_idx]
|
| 4775 |
+
|
| 4776 |
+
# First, try multi-ball detection
|
| 4777 |
+
candidates = detect_all_balls(frame)
|
| 4778 |
+
|
| 4779 |
+
if not candidates:
|
| 4780 |
+
# Fallback to single-ball detection
|
| 4781 |
+
detection = detect_ball_center(frame)
|
| 4782 |
+
if detection is None:
|
| 4783 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
|
| 4784 |
+
status_updates = _ui_status_updates(state_in)
|
| 4785 |
+
return (
|
| 4786 |
+
update_frame_display(state_in, frame_idx),
|
| 4787 |
+
gr.update(
|
| 4788 |
+
value="❌ Unable to auto-detect the ball. Please add a point manually.",
|
| 4789 |
+
visible=True,
|
| 4790 |
+
),
|
| 4791 |
+
gr.update(value=frame_idx),
|
| 4792 |
+
_build_kick_plot(state_in),
|
| 4793 |
+
propagate_main_update,
|
| 4794 |
+
detect_btn_update,
|
| 4795 |
+
propagate_player_update,
|
| 4796 |
+
*status_updates,
|
| 4797 |
+
)
|
| 4798 |
+
x_center, y_center, _, _, conf = detection
|
| 4799 |
+
else:
|
| 4800 |
+
# Use the best candidate (first one after sorting by confidence)
|
| 4801 |
+
best = candidates[0]
|
| 4802 |
+
x_center, y_center = best["center"]
|
| 4803 |
+
conf = best["conf"]
|
| 4804 |
+
|
| 4805 |
+
# Store candidates for potential multi-ball workflow
|
| 4806 |
+
state_in.ball_candidates = candidates
|
| 4807 |
+
state_in.selected_ball_idx = 0
|
| 4808 |
+
|
| 4809 |
+
if len(candidates) > 1:
|
| 4810 |
+
state_in.multi_ball_status = f"⚠️ {len(candidates)} balls detected. Using best candidate. Run 'Track Ball' to analyze all."
|
| 4811 |
+
else:
|
| 4812 |
+
state_in.multi_ball_status = ""
|
| 4813 |
|
|
|
|
| 4814 |
frame_width, frame_height = frame.size
|
| 4815 |
x_center = max(0, min(frame_width - 1, int(x_center)))
|
| 4816 |
y_center = max(0, min(frame_height - 1, int(y_center)))
|
|
|
|
| 4836 |
)
|
| 4837 |
|
| 4838 |
state_in.is_ball_detected = True
|
| 4839 |
+
num_candidates = len(getattr(state_in, 'ball_candidates', []))
|
| 4840 |
+
if num_candidates > 1:
|
| 4841 |
+
status_text = f"⚠️ {num_candidates} balls found! Best at ({x_center}, {y_center}) (conf={conf:.2f})"
|
| 4842 |
+
else:
|
| 4843 |
+
status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
|
| 4844 |
status_text += f" | {_format_kick_status(state_in)}"
|
| 4845 |
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
|
| 4846 |
status_updates = _ui_status_updates(state_in)
|
|
|
|
| 4887 |
raise gr.Error("Load a video first, then track the ball with YOLO.")
|
| 4888 |
progress = gr.Progress(track_tqdm=False)
|
| 4889 |
state_in.is_yolo_tracked = False
|
| 4890 |
+
|
| 4891 |
+
# Check if we have multiple ball candidates
|
| 4892 |
+
num_candidates = len(getattr(state_in, 'ball_candidates', []))
|
| 4893 |
+
|
| 4894 |
+
if num_candidates > 1:
|
| 4895 |
+
# Multi-ball mode: track all candidates and show comparison
|
| 4896 |
+
_detect_and_track_all_ball_candidates(state_in, progress=progress)
|
| 4897 |
+
|
| 4898 |
+
# Apply the best candidate to YOLO state
|
| 4899 |
+
_apply_selected_ball_to_yolo_state(state_in)
|
| 4900 |
+
|
| 4901 |
+
base_msg = state_in.multi_ball_status or state_in.yolo_status or ""
|
| 4902 |
+
else:
|
| 4903 |
+
# Single ball mode: use original tracking
|
| 4904 |
+
_perform_yolo_ball_tracking(state_in, progress=progress)
|
| 4905 |
+
base_msg = state_in.yolo_status or ""
|
| 4906 |
+
|
| 4907 |
target_frame = (
|
| 4908 |
state_in.yolo_kick_frame
|
| 4909 |
if state_in.yolo_kick_frame is not None
|
|
|
|
| 4915 |
target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
|
| 4916 |
state_in.current_frame_idx = target_frame
|
| 4917 |
preview_img = update_frame_display(state_in, target_frame)
|
|
|
|
| 4918 |
kick_msg = _format_kick_status(state_in)
|
| 4919 |
status_text = f"{base_msg} | {kick_msg}" if base_msg else kick_msg
|
| 4920 |
state_in.is_yolo_tracked = True
|