Mirko Trasciatti commited on
Commit
83490b7
·
1 Parent(s): 70dff07

Add visual YOLO bounding boxes for all ball candidates

Browse files

- draw_yolo_detections_on_frame(): Draws boxes with labels on preview
- Selected candidate: Green thick box with checkmark
- Kicked candidates: Orange box with soccer ball emoji
- Other candidates: Yellow thin box
- Each labeled 'Ball N (conf%)' with center crosshair
- Updated _auto_detect_ball to draw boxes immediately after detection
- Updated _track_ball_yolo to draw boxes after tracking (with kick info)
- Updated radio button handler to redraw preview when selection changes

This provides clear visual feedback of what YOLO is detecting.

Files changed (1) hide show
  1. app.py +125 -2
app.py CHANGED
@@ -1114,6 +1114,100 @@ def _apply_selected_ball_to_yolo_state(state: AppState) -> None:
1114
  state.yolo_status = f"⚠️ Ball {idx+1} tracked ({coverage:.0%} coverage) but no kick detected."
1115
 
1116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1117
  def _format_ball_candidates_for_radio(candidates: list[dict]) -> list[str]:
1118
  """Format ball candidates as radio button choices."""
1119
  choices = []
@@ -4893,6 +4987,15 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
4893
 
4894
  state_in.is_ball_detected = True
4895
  num_candidates = len(getattr(state_in, 'ball_candidates', []))
 
 
 
 
 
 
 
 
 
4896
  if num_candidates > 1:
4897
  status_text = f"⚠️ {num_candidates} balls found! Best at ({x_center}, {y_center}) (conf={conf:.2f}). Click 'Track Ball' to analyze all."
4898
  else:
@@ -5003,6 +5106,16 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
5003
  target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
5004
  state_in.current_frame_idx = target_frame
5005
  preview_img = update_frame_display(state_in, target_frame)
 
 
 
 
 
 
 
 
 
 
5006
  kick_msg = _format_kick_status(state_in)
5007
  status_text = f"{base_msg} | {kick_msg}" if base_msg else kick_msg
5008
  state_in.is_yolo_tracked = True
@@ -5072,18 +5185,28 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
5072
 
5073
  state_in.selected_ball_idx = new_idx
5074
 
 
 
 
 
 
 
 
 
 
 
5075
  # Update the chart to highlight the new selection
5076
  chart_update = gr.update(value=_build_multi_ball_chart(state_in))
5077
  status_update = gr.update(
5078
  value=_format_ball_candidates_markdown(state_in.ball_candidates, new_idx)
5079
  )
5080
 
5081
- return chart_update, status_update
5082
 
5083
  ball_candidate_radio.change(
5084
  _on_ball_candidate_change,
5085
  inputs=[GLOBAL_STATE, ball_candidate_radio],
5086
- outputs=[multi_ball_chart, multi_ball_status_md],
5087
  )
5088
 
5089
  def _on_confirm_ball_selection(state_in: AppState):
 
1114
  state.yolo_status = f"⚠️ Ball {idx+1} tracked ({coverage:.0%} coverage) but no kick detected."
1115
 
1116
 
1117
+ def draw_yolo_detections_on_frame(
1118
+ frame: Image.Image,
1119
+ candidates: list[dict],
1120
+ selected_idx: int = 0,
1121
+ show_all: bool = True,
1122
+ ) -> Image.Image:
1123
+ """
1124
+ Draw YOLO bounding boxes for all ball candidates on the frame.
1125
+
1126
+ - Selected candidate: Green box with thick border
1127
+ - Other candidates: Yellow/orange boxes with thinner border
1128
+ - Each box labeled with "Ball N (conf%)"
1129
+ """
1130
+ from PIL import ImageDraw, ImageFont
1131
+
1132
+ result = frame.copy()
1133
+ draw = ImageDraw.Draw(result)
1134
+
1135
+ # Try to get a font, fallback to default
1136
+ try:
1137
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
1138
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
1139
+ except:
1140
+ font = ImageFont.load_default()
1141
+ small_font = font
1142
+
1143
+ for i, candidate in enumerate(candidates):
1144
+ box = candidate.get("box")
1145
+ if not box:
1146
+ continue
1147
+
1148
+ x_min, y_min, x_max, y_max = box
1149
+ conf = candidate.get("conf", 0)
1150
+ is_selected = (i == selected_idx)
1151
+ has_kick = candidate.get("has_kick", False)
1152
+
1153
+ # Colors and styles
1154
+ if is_selected:
1155
+ box_color = (0, 255, 0) # Green for selected
1156
+ text_color = (0, 255, 0)
1157
+ width = 4
1158
+ elif has_kick:
1159
+ box_color = (255, 165, 0) # Orange for kicked but not selected
1160
+ text_color = (255, 165, 0)
1161
+ width = 3
1162
+ else:
1163
+ box_color = (255, 255, 0) # Yellow for others
1164
+ text_color = (255, 255, 0)
1165
+ width = 2
1166
+
1167
+ # Draw bounding box
1168
+ for offset in range(width):
1169
+ draw.rectangle(
1170
+ [x_min - offset, y_min - offset, x_max + offset, y_max + offset],
1171
+ outline=box_color,
1172
+ )
1173
+
1174
+ # Draw dark outline for visibility
1175
+ draw.rectangle(
1176
+ [x_min - width - 1, y_min - width - 1, x_max + width + 1, y_max + width + 1],
1177
+ outline=(0, 0, 0),
1178
+ )
1179
+
1180
+ # Label
1181
+ label = f"Ball {i + 1} ({conf:.0%})"
1182
+ if is_selected:
1183
+ label = f"✓ {label}"
1184
+ if has_kick:
1185
+ label += " ⚽"
1186
+
1187
+ # Draw label background
1188
+ text_bbox = draw.textbbox((x_min, y_min - 22), label, font=font)
1189
+ padding = 3
1190
+ bg_box = [
1191
+ text_bbox[0] - padding,
1192
+ text_bbox[1] - padding,
1193
+ text_bbox[2] + padding,
1194
+ text_bbox[3] + padding,
1195
+ ]
1196
+ draw.rectangle(bg_box, fill=(0, 0, 0, 200))
1197
+
1198
+ # Draw label text
1199
+ draw.text((x_min, y_min - 22), label, fill=text_color, font=font)
1200
+
1201
+ # Draw center crosshair
1202
+ cx, cy = candidate.get("center", (0, 0))
1203
+ cx, cy = int(cx), int(cy)
1204
+ cross_size = 8
1205
+ draw.line([(cx - cross_size, cy), (cx + cross_size, cy)], fill=box_color, width=2)
1206
+ draw.line([(cx, cy - cross_size), (cx, cy + cross_size)], fill=box_color, width=2)
1207
+
1208
+ return result
1209
+
1210
+
1211
  def _format_ball_candidates_for_radio(candidates: list[dict]) -> list[str]:
1212
  """Format ball candidates as radio button choices."""
1213
  choices = []
 
4987
 
4988
  state_in.is_ball_detected = True
4989
  num_candidates = len(getattr(state_in, 'ball_candidates', []))
4990
+
4991
+ # Draw YOLO bounding boxes on preview if we have candidates
4992
+ if num_candidates > 0 and isinstance(preview_img, Image.Image):
4993
+ preview_img = draw_yolo_detections_on_frame(
4994
+ preview_img,
4995
+ state_in.ball_candidates,
4996
+ selected_idx=0,
4997
+ )
4998
+
4999
  if num_candidates > 1:
5000
  status_text = f"⚠️ {num_candidates} balls found! Best at ({x_center}, {y_center}) (conf={conf:.2f}). Click 'Track Ball' to analyze all."
5001
  else:
 
5106
  target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
5107
  state_in.current_frame_idx = target_frame
5108
  preview_img = update_frame_display(state_in, target_frame)
5109
+
5110
+ # Draw YOLO bounding boxes on preview if we have candidates (after tracking, with kick info)
5111
+ candidates = getattr(state_in, 'ball_candidates', [])
5112
+ if len(candidates) > 0 and isinstance(preview_img, Image.Image):
5113
+ preview_img = draw_yolo_detections_on_frame(
5114
+ preview_img,
5115
+ candidates,
5116
+ selected_idx=state_in.selected_ball_idx,
5117
+ )
5118
+
5119
  kick_msg = _format_kick_status(state_in)
5120
  status_text = f"{base_msg} | {kick_msg}" if base_msg else kick_msg
5121
  state_in.is_yolo_tracked = True
 
5185
 
5186
  state_in.selected_ball_idx = new_idx
5187
 
5188
+ # Update the preview to show the new selection highlighted
5189
+ frame_idx = state_in.current_frame_idx
5190
+ preview_img = update_frame_display(state_in, frame_idx)
5191
+ if isinstance(preview_img, Image.Image):
5192
+ preview_img = draw_yolo_detections_on_frame(
5193
+ preview_img,
5194
+ state_in.ball_candidates,
5195
+ selected_idx=new_idx,
5196
+ )
5197
+
5198
  # Update the chart to highlight the new selection
5199
  chart_update = gr.update(value=_build_multi_ball_chart(state_in))
5200
  status_update = gr.update(
5201
  value=_format_ball_candidates_markdown(state_in.ball_candidates, new_idx)
5202
  )
5203
 
5204
+ return preview_img, chart_update, status_update
5205
 
5206
  ball_candidate_radio.change(
5207
  _on_ball_candidate_change,
5208
  inputs=[GLOBAL_STATE, ball_candidate_radio],
5209
+ outputs=[preview, multi_ball_chart, multi_ball_status_md],
5210
  )
5211
 
5212
  def _on_confirm_ball_selection(state_in: AppState):