Mirko Trasciatti
committed on
Commit
·
fb2fd45
1
Parent(s):
2feeac4
Add YOLO-driven kick detection and chart
Browse files
app.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import colorsys
|
| 2 |
import gc
|
| 3 |
from copy import deepcopy
|
|
@@ -161,6 +163,236 @@ def detect_person_box(
|
|
| 161 |
return x_min, y_min, x_max, y_max, conf
|
| 162 |
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
|
| 165 |
"""Generate a deterministic pastel RGB color for a given object id.
|
| 166 |
|
|
@@ -305,6 +537,25 @@ class AppState:
|
|
| 305 |
self.player_obj_id: int | None = None
|
| 306 |
self.player_detection_frame: int | None = None
|
| 307 |
self.player_detection_conf: float | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
|
| 309 |
def __repr__(self):
|
| 310 |
return f"AppState(video_frames={self.video_frames}, inference_session={self.inference_session is not None}, model={self.model is not None}, processor={self.processor is not None}, device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, clicks_by_frame_obj={self.clicks_by_frame_obj}, boxes_by_frame_obj={self.boxes_by_frame_obj}, composited_frames={self.composited_frames}, current_frame_idx={self.current_frame_idx}, current_obj_id={self.current_obj_id}, current_label={self.current_label}, current_clear_old={self.current_clear_old}, current_prompt_type={self.current_prompt_type}, pending_box_start={self.pending_box_start}, pending_box_start_frame_idx={self.pending_box_start_frame_idx}, pending_box_start_obj_id={self.pending_box_start_obj_id}, is_switching_model={self.is_switching_model}, model_repo_key={self.model_repo_key}, model_repo_id={self.model_repo_id}, session_repo_id={self.session_repo_id})"
|
|
@@ -403,9 +654,43 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
|
|
| 403 |
GLOBAL_STATE.impact_debug_speed_kmh = []
|
| 404 |
GLOBAL_STATE.impact_debug_speed_threshold_px = None
|
| 405 |
GLOBAL_STATE.impact_meters_per_px = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
GLOBAL_STATE.player_obj_id = None
|
| 407 |
GLOBAL_STATE.player_detection_frame = None
|
| 408 |
GLOBAL_STATE.player_detection_conf = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
load_model_if_needed(GLOBAL_STATE)
|
| 411 |
|
|
@@ -957,6 +1242,137 @@ def _build_kick_plot(state: AppState):
|
|
| 957 |
return fig
|
| 958 |
|
| 959 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 960 |
def _format_impact_status(state: AppState) -> str:
|
| 961 |
if state is None:
|
| 962 |
return "Impact frame: not computed"
|
|
@@ -1018,15 +1434,10 @@ def _player_has_masks(state: AppState) -> bool:
|
|
| 1018 |
|
| 1019 |
|
| 1020 |
def _button_updates(state: AppState) -> tuple[Any, Any, Any]:
|
| 1021 |
-
|
| 1022 |
-
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
kick_candidate = state.kick_frame or getattr(state, "kick_debug_kick_frame", None)
|
| 1026 |
-
if kick_candidate is not None:
|
| 1027 |
-
detect_player_enabled = True
|
| 1028 |
-
if detect_player_enabled and _player_has_masks(state):
|
| 1029 |
-
propagate_player_enabled = True
|
| 1030 |
return (
|
| 1031 |
gr.update(interactive=propagate_main_enabled),
|
| 1032 |
gr.update(interactive=detect_player_enabled),
|
|
@@ -1468,6 +1879,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1468 |
"Load a video first.",
|
| 1469 |
gr.update(),
|
| 1470 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 1471 |
_format_impact_status(GLOBAL_STATE),
|
| 1472 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1473 |
propagate_main_update,
|
|
@@ -1475,6 +1887,8 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1475 |
propagate_player_update,
|
| 1476 |
)
|
| 1477 |
|
|
|
|
|
|
|
| 1478 |
processor = deepcopy(GLOBAL_STATE.processor)
|
| 1479 |
model = deepcopy(GLOBAL_STATE.model)
|
| 1480 |
inference_session = deepcopy(GLOBAL_STATE.inference_session)
|
|
@@ -1483,9 +1897,19 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1483 |
inference_session.cache.inference_device = "cuda"
|
| 1484 |
model.to("cuda")
|
| 1485 |
|
| 1486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1487 |
processed = 0
|
| 1488 |
|
|
|
|
|
|
|
| 1489 |
# Initial status; no slider change yet
|
| 1490 |
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 1491 |
yield (
|
|
@@ -1493,6 +1917,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1493 |
f"Propagating masks: {processed}/{total}",
|
| 1494 |
gr.update(),
|
| 1495 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 1496 |
_format_impact_status(GLOBAL_STATE),
|
| 1497 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1498 |
propagate_main_update,
|
|
@@ -1500,9 +1925,10 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1500 |
propagate_player_update,
|
| 1501 |
)
|
| 1502 |
|
| 1503 |
-
last_frame_idx =
|
| 1504 |
with torch.inference_mode():
|
| 1505 |
-
for frame_idx
|
|
|
|
| 1506 |
pixel_values = None
|
| 1507 |
if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
|
| 1508 |
pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
|
|
@@ -1531,6 +1957,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1531 |
f"Propagating masks: {processed}/{total}",
|
| 1532 |
gr.update(value=frame_idx),
|
| 1533 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 1534 |
_format_impact_status(GLOBAL_STATE),
|
| 1535 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1536 |
propagate_main_update,
|
|
@@ -1554,6 +1981,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1554 |
text,
|
| 1555 |
gr.update(value=target_frame),
|
| 1556 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 1557 |
_format_impact_status(GLOBAL_STATE),
|
| 1558 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1559 |
propagate_main_update,
|
|
@@ -1646,6 +2074,7 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
|
|
| 1646 |
status,
|
| 1647 |
gr.update(visible=False, value=""),
|
| 1648 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 1649 |
_format_impact_status(GLOBAL_STATE),
|
| 1650 |
propagate_main_update,
|
| 1651 |
detect_btn_update,
|
|
@@ -1893,7 +2322,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1893 |
"""
|
| 1894 |
**Working with results**
|
| 1895 |
- **Preview**: Use the slider to navigate frames and see the current masks.
|
| 1896 |
-
- **
|
| 1897 |
- **Export**: Render an MP4 for smooth playback using the original video FPS.
|
| 1898 |
- **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video).
|
| 1899 |
"""
|
|
@@ -1951,7 +2380,8 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1951 |
)
|
| 1952 |
with gr.Row():
|
| 1953 |
detect_ball_btn = gr.Button("Detect Ball", variant="secondary")
|
| 1954 |
-
|
|
|
|
| 1955 |
detect_player_btn = gr.Button("Detect Player", variant="secondary", interactive=False)
|
| 1956 |
propagate_player_btn = gr.Button("Propagate Player", variant="primary", interactive=False)
|
| 1957 |
ball_status = gr.Markdown(visible=False)
|
|
@@ -1963,6 +2393,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1963 |
clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
|
| 1964 |
prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
|
| 1965 |
kick_plot = gr.Plot(label="Kick & impact diagnostics", show_label=True)
|
|
|
|
| 1966 |
|
| 1967 |
# Wire events
|
| 1968 |
def _on_video_change(GLOBAL_STATE: gr.State, video):
|
|
@@ -1975,6 +2406,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1975 |
status,
|
| 1976 |
gr.update(visible=False, value=""),
|
| 1977 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 1978 |
_format_impact_status(GLOBAL_STATE),
|
| 1979 |
propagate_main_update,
|
| 1980 |
detect_btn_update,
|
|
@@ -1984,7 +2416,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1984 |
video_in.change(
|
| 1985 |
_on_video_change,
|
| 1986 |
inputs=[GLOBAL_STATE, video_in],
|
| 1987 |
-
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 1988 |
show_progress=True,
|
| 1989 |
)
|
| 1990 |
|
|
@@ -1997,7 +2429,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1997 |
examples=examples_list,
|
| 1998 |
inputs=[GLOBAL_STATE, video_in],
|
| 1999 |
fn=_on_video_change,
|
| 2000 |
-
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2001 |
label="Examples",
|
| 2002 |
cache_examples=False,
|
| 2003 |
examples_per_page=5,
|
|
@@ -2187,6 +2619,43 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2187 |
outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2188 |
)
|
| 2189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2190 |
def _auto_detect_player(state_in: AppState):
|
| 2191 |
if state_in is None or state_in.num_frames == 0:
|
| 2192 |
raise gr.Error("Load a video first, then try auto-detect.")
|
|
@@ -2303,6 +2772,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2303 |
"Load a video first.",
|
| 2304 |
gr.update(),
|
| 2305 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 2306 |
_format_impact_status(GLOBAL_STATE),
|
| 2307 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2308 |
propagate_main_update,
|
|
@@ -2316,6 +2786,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2316 |
"Detect the player before propagating.",
|
| 2317 |
gr.update(),
|
| 2318 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 2319 |
_format_impact_status(GLOBAL_STATE),
|
| 2320 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2321 |
propagate_main_update,
|
|
@@ -2330,8 +2801,17 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2330 |
inference_session.cache.inference_device = "cuda"
|
| 2331 |
model.to("cuda")
|
| 2332 |
|
| 2333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2334 |
processed = 0
|
|
|
|
| 2335 |
|
| 2336 |
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 2337 |
yield (
|
|
@@ -2339,6 +2819,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2339 |
f"Propagating player: {processed}/{total}",
|
| 2340 |
gr.update(),
|
| 2341 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 2342 |
_format_impact_status(GLOBAL_STATE),
|
| 2343 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2344 |
propagate_main_update,
|
|
@@ -2349,7 +2830,8 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2349 |
player_id = GLOBAL_STATE.player_obj_id or PLAYER_OBJECT_ID
|
| 2350 |
|
| 2351 |
with torch.inference_mode():
|
| 2352 |
-
for frame_idx
|
|
|
|
| 2353 |
pixel_values = None
|
| 2354 |
if (
|
| 2355 |
inference_session.processed_frames is None
|
|
@@ -2375,6 +2857,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2375 |
GLOBAL_STATE.composited_frames.pop(frame_idx, None)
|
| 2376 |
|
| 2377 |
processed += 1
|
|
|
|
| 2378 |
if processed % 30 == 0 or processed == total:
|
| 2379 |
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 2380 |
yield (
|
|
@@ -2382,6 +2865,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2382 |
f"Propagating player: {processed}/{total}",
|
| 2383 |
gr.update(value=frame_idx),
|
| 2384 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 2385 |
_format_impact_status(GLOBAL_STATE),
|
| 2386 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2387 |
propagate_main_update,
|
|
@@ -2394,7 +2878,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2394 |
if target_frame is None:
|
| 2395 |
target_frame = GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None)
|
| 2396 |
if target_frame is None:
|
| 2397 |
-
target_frame =
|
| 2398 |
target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
|
| 2399 |
GLOBAL_STATE.current_frame_idx = target_frame
|
| 2400 |
|
|
@@ -2404,6 +2888,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2404 |
text,
|
| 2405 |
gr.update(value=target_frame),
|
| 2406 |
_build_kick_plot(GLOBAL_STATE),
|
|
|
|
| 2407 |
_format_impact_status(GLOBAL_STATE),
|
| 2408 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2409 |
propagate_main_update,
|
|
@@ -2414,7 +2899,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2414 |
propagate_player_btn.click(
|
| 2415 |
propagate_player_masks,
|
| 2416 |
inputs=[GLOBAL_STATE],
|
| 2417 |
-
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2418 |
)
|
| 2419 |
|
| 2420 |
# Image click to add a point and run forward on that frame
|
|
@@ -2483,13 +2968,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2483 |
propagate_btn.click(
|
| 2484 |
propagate_masks,
|
| 2485 |
inputs=[GLOBAL_STATE],
|
| 2486 |
-
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2487 |
)
|
| 2488 |
|
| 2489 |
reset_btn.click(
|
| 2490 |
reset_session,
|
| 2491 |
inputs=GLOBAL_STATE,
|
| 2492 |
-
outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2493 |
)
|
| 2494 |
|
| 2495 |
# ============================================================================
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import colorsys
|
| 4 |
import gc
|
| 5 |
from copy import deepcopy
|
|
|
|
| 163 |
return x_min, y_min, x_max, y_max, conf
|
| 164 |
|
| 165 |
|
| 166 |
+
def _compute_sam_window_from_kick(state: AppState, kick_frame: int | None) -> tuple[int, int]:
|
| 167 |
+
total_frames = state.num_frames
|
| 168 |
+
if total_frames == 0:
|
| 169 |
+
return 0, 0
|
| 170 |
+
fps = state.video_fps if state.video_fps and state.video_fps > 0 else 25.0
|
| 171 |
+
target_window_frames = max(1, int(round(fps * 4.0)))
|
| 172 |
+
half_window = target_window_frames // 2
|
| 173 |
+
if kick_frame is None:
|
| 174 |
+
start_idx = 0
|
| 175 |
+
else:
|
| 176 |
+
start_idx = max(0, int(kick_frame) - half_window)
|
| 177 |
+
end_idx = min(total_frames, start_idx + target_window_frames)
|
| 178 |
+
if end_idx <= start_idx:
|
| 179 |
+
end_idx = min(total_frames, start_idx + 1)
|
| 180 |
+
state.sam_window = (start_idx, end_idx)
|
| 181 |
+
return start_idx, end_idx
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _perform_yolo_ball_tracking(state: AppState, progress: gr.Progress | None = None) -> None:
    """Run YOLO on every frame, track the ball and estimate the kick frame.

    For each frame the detector is asked for at most one box of the class
    whose name equals ``YOLO_TARGET_NAME``. Detected centers are smoothed
    exponentially, converted to speeds (pixels per second when
    ``state.video_fps`` is usable, pixels per frame otherwise) and scanned
    for the first frame that looks like a kick. All intermediate series are
    written to the ``state.yolo_*`` fields. When a kick is found,
    ``state.kick_frame`` is set and the SAM window is recomputed; otherwise
    ``state.sam_window`` is cleared.

    Args:
        state: Application state holding ``video_frames``, ``video_fps`` and
            the ``yolo_*`` / ``kalman_*`` caches that are updated in place.
        progress: Optional callable invoked with the processed fraction
            (0..1] once per frame.

    Raises:
        gr.Error: If no video is loaded or the model has no matching class.
    """
    if state is None or state.num_frames == 0:
        raise gr.Error("Load a video first, then track with YOLO.")

    model = get_yolo_model()
    class_ids = [
        idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
    ]
    if not class_ids:
        raise gr.Error("YOLO model does not contain the sports ball class.")

    frames = state.video_frames
    total = len(frames)
    centers: dict[int, tuple[float, float]] = {}
    boxes: dict[int, tuple[int, int, int, int]] = {}
    confs: dict[int, float] = {}
    areas: dict[int, float] = {}
    first_detection_frame: int | None = None

    # ---- Per-frame detection -------------------------------------------
    for idx, frame in enumerate(frames):
        if progress is not None:
            progress((idx + 1) / total)

        results = model.predict(
            source=frame,
            conf=YOLO_CONF_THRESHOLD,
            iou=YOLO_IOU_THRESHOLD,
            max_det=1,
            classes=class_ids,
            imgsz=640,
            device="cpu",
            verbose=False,
        )
        if not results:
            continue
        boxes_result = results[0].boxes
        if boxes_result is None or len(boxes_result) == 0:
            continue

        box = boxes_result[0]
        xywh = box.xywh[0].cpu().tolist()
        conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0
        x_center, y_center, width, height = xywh
        x_center = float(x_center)
        y_center = float(y_center)
        width = max(1.0, float(width))
        height = max(1.0, float(height))

        # NOTE(review): assumes frames expose ``.size`` as (width, height),
        # as PIL images do - confirm against the frame loader.
        frame_width, frame_height = frame.size
        x_min = int(round(max(0.0, x_center - width / 2.0)))
        y_min = int(round(max(0.0, y_center - height / 2.0)))
        x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0)))
        y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0)))
        if x_max <= x_min or y_max <= y_min:
            # Box collapsed after clipping to the image; ignore it.
            continue

        centers[idx] = (x_center, y_center)
        boxes[idx] = (x_min, y_min, x_max, y_max)
        confs[idx] = conf
        areas[idx] = float((x_max - x_min) * (y_max - y_min))
        if first_detection_frame is None:
            first_detection_frame = idx

    state.yolo_ball_centers = centers
    state.yolo_ball_boxes = boxes
    state.yolo_ball_conf = confs
    state.yolo_mask_area_proxy = [areas.get(k, 0.0) for k in sorted(centers.keys())]
    state.yolo_initial_frame = first_detection_frame

    if len(centers) < 3:
        state.yolo_smoothed_centers = {}
        state.yolo_speeds = {}
        state.yolo_distance_from_start = {}
        state.yolo_threshold = None
        state.yolo_baseline_speed = None
        state.yolo_speed_std = None
        state.yolo_kick_frame = None
        # Also clear the plotted series; otherwise a previous run's curves
        # stay on screen next to the "insufficient detections" status.
        state.yolo_kick_frames = []
        state.yolo_kick_speeds = []
        state.yolo_kick_distance = []
        state.yolo_status = "❌ YOLO13: insufficient detections to estimate kick. Please retry or annotate manually."
        state.sam_window = None
        return

    # ---- Exponential smoothing and speed estimation --------------------
    items = sorted(centers.items())
    dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
    alpha = 0.35

    smoothed: dict[int, tuple[float, float]] = {}
    speeds: dict[int, float] = {}

    prev_frame = None
    prev_smooth = None
    for frame_idx, (cx, cy) in items:
        if prev_smooth is None:
            smooth_x, smooth_y = float(cx), float(cy)
        else:
            smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
            smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
        smoothed[frame_idx] = (smooth_x, smooth_y)
        if prev_smooth is None or prev_frame is None:
            speeds[frame_idx] = 0.0
        else:
            # Detections may skip frames, so divide by the real elapsed time.
            frame_delta = max(1, frame_idx - prev_frame)
            time_delta = frame_delta * dt
            dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
            speeds[frame_idx] = dist / time_delta if time_delta > 0 else dist
        prev_smooth = (smooth_x, smooth_y)
        prev_frame = frame_idx

    frames_ordered = [frame_idx for frame_idx, _ in items]
    speed_series = [speeds.get(f, 0.0) for f in frames_ordered]

    # ---- Adaptive speed threshold from the first detections ------------
    baseline_window = min(10, len(frames_ordered) // 3 or 1)
    baseline_speeds = speed_series[:baseline_window]
    baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
    speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0
    base_threshold = baseline_speed + 4.0 * speed_std
    if base_threshold < baseline_speed * 3.0:
        base_threshold = baseline_speed * 3.0
    speed_threshold = max(base_threshold, 15.0)

    # Distance of every smoothed position from the first tracked position.
    # ``smoothed`` has at least three entries here, so the origin exists.
    origin = smoothed[frames_ordered[0]]
    distance_dict: dict[int, float] = {
        frame_idx: math.hypot(sx - origin[0], sy - origin[1])
        for frame_idx, (sx, sy) in smoothed.items()
    }

    areas_dict = {idx: areas.get(idx, 0.0) for idx in frames_ordered}
    initial_area = areas_dict.get(frames_ordered[0], 1.0) or 1.0
    radius_estimate = math.sqrt(initial_area / math.pi)
    adaptive_return_distance = max(8.0, min(radius_estimate * 1.5, 40.0))

    sustain_frames = 3
    holdout_frames = 8
    area_window = 4
    area_drop_ratio = 0.75

    kalman_pos, kalman_speed, _ = _run_kalman_filter(items, dt)

    # ---- Kick search ---------------------------------------------------
    kick_frame: int | None = None
    sustain_floor = speed_threshold * 0.7
    for pos in range(baseline_window, len(frames_ordered)):
        candidate = frames_ordered[pos]
        speed = speed_series[pos]
        if speed < speed_threshold:
            continue

        # The following detections (up to ``sustain_frames``) must stay fast.
        following = speed_series[pos + 1: pos + 1 + sustain_frames]
        if any(s < sustain_floor for s in following):
            continue

        # Box area must have dropped relative to the recent median, unless
        # the speed exceeds the threshold by a clear margin.
        area_pass = True
        current_area = areas_dict.get(candidate)
        if current_area:
            prev_areas = [
                areas_dict[f] for f in frames_ordered[max(0, pos - area_window):pos]
            ]
            if prev_areas:
                median_prev = statistics.median(prev_areas)
                if median_prev > 0 and current_area / median_prev > area_drop_ratio:
                    area_pass = False
        if not area_pass and speed < speed_threshold * 1.2:
            continue

        # The ball must actually move away from its starting position.
        future_slice = frames_ordered[pos: min(len(frames_ordered), pos + holdout_frames)]
        max_future_dist = max(
            (distance_dict.get(f, 0.0) for f in future_slice), default=0.0
        )
        if max_future_dist < adaptive_return_distance:
            continue

        kick_frame = candidate
        break

    # ---- Publish results -----------------------------------------------
    state.yolo_smoothed_centers = smoothed
    state.yolo_speeds = speeds
    state.yolo_distance_from_start = distance_dict
    state.yolo_threshold = speed_threshold
    state.yolo_baseline_speed = baseline_speed
    state.yolo_speed_std = speed_std
    state.yolo_kick_frames = frames_ordered
    state.yolo_kick_speeds = speed_series
    state.yolo_kick_distance = [distance_dict.get(f, 0.0) for f in frames_ordered]
    state.yolo_mask_area_proxy = [areas_dict.get(f, 0.0) for f in frames_ordered]
    state.yolo_kick_frame = kick_frame
    coverage = len(centers) / total if total else 0.0
    if kick_frame is not None:
        state.yolo_status = f"✅ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%})."
    else:
        state.yolo_status = (
            f"⚠️ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%}) but did not find a definitive kick."
        )
    state.kalman_centers[BALL_OBJECT_ID] = kalman_pos
    state.kalman_speeds[BALL_OBJECT_ID] = kalman_speed

    if kick_frame is not None:
        state.kick_frame = kick_frame
        _compute_sam_window_from_kick(state, kick_frame)
    else:
        state.sam_window = None
|
| 394 |
+
|
| 395 |
+
|
| 396 |
def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
|
| 397 |
"""Generate a deterministic pastel RGB color for a given object id.
|
| 398 |
|
|
|
|
| 537 |
self.player_obj_id: int | None = None
|
| 538 |
self.player_detection_frame: int | None = None
|
| 539 |
self.player_detection_conf: float | None = None
|
| 540 |
+
# YOLO tracking caches
|
| 541 |
+
self.yolo_ball_centers: dict[int, tuple[float, float]] = {}
|
| 542 |
+
self.yolo_ball_boxes: dict[int, tuple[int, int, int, int]] = {}
|
| 543 |
+
self.yolo_ball_conf: dict[int, float] = {}
|
| 544 |
+
self.yolo_smoothed_centers: dict[int, tuple[float, float]] = {}
|
| 545 |
+
self.yolo_speeds: dict[int, float] = {}
|
| 546 |
+
self.yolo_distance_from_start: dict[int, float] = {}
|
| 547 |
+
self.yolo_threshold: float | None = None
|
| 548 |
+
self.yolo_baseline_speed: float | None = None
|
| 549 |
+
self.yolo_speed_std: float | None = None
|
| 550 |
+
self.yolo_kick_frame: int | None = None
|
| 551 |
+
self.yolo_status: str = ""
|
| 552 |
+
self.yolo_kick_frames: list[int] = []
|
| 553 |
+
self.yolo_kick_speeds: list[float] = []
|
| 554 |
+
self.yolo_kick_distance: list[float] = []
|
| 555 |
+
self.yolo_mask_area_proxy: list[float] = []
|
| 556 |
+
self.yolo_initial_frame: int | None = None
|
| 557 |
+
# SAM window (start_idx inclusive, end_idx exclusive)
|
| 558 |
+
self.sam_window: tuple[int, int] | None = None
|
| 559 |
|
| 560 |
def __repr__(self):
|
| 561 |
return f"AppState(video_frames={self.video_frames}, inference_session={self.inference_session is not None}, model={self.model is not None}, processor={self.processor is not None}, device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, clicks_by_frame_obj={self.clicks_by_frame_obj}, boxes_by_frame_obj={self.boxes_by_frame_obj}, composited_frames={self.composited_frames}, current_frame_idx={self.current_frame_idx}, current_obj_id={self.current_obj_id}, current_label={self.current_label}, current_clear_old={self.current_clear_old}, current_prompt_type={self.current_prompt_type}, pending_box_start={self.pending_box_start}, pending_box_start_frame_idx={self.pending_box_start_frame_idx}, pending_box_start_obj_id={self.pending_box_start_obj_id}, is_switching_model={self.is_switching_model}, model_repo_key={self.model_repo_key}, model_repo_id={self.model_repo_id}, session_repo_id={self.session_repo_id})"
|
|
|
|
| 654 |
GLOBAL_STATE.impact_debug_speed_kmh = []
|
| 655 |
GLOBAL_STATE.impact_debug_speed_threshold_px = None
|
| 656 |
GLOBAL_STATE.impact_meters_per_px = None
|
| 657 |
+
GLOBAL_STATE.yolo_ball_centers = {}
|
| 658 |
+
GLOBAL_STATE.yolo_ball_boxes = {}
|
| 659 |
+
GLOBAL_STATE.yolo_ball_conf = {}
|
| 660 |
+
GLOBAL_STATE.yolo_smoothed_centers = {}
|
| 661 |
+
GLOBAL_STATE.yolo_speeds = {}
|
| 662 |
+
GLOBAL_STATE.yolo_distance_from_start = {}
|
| 663 |
+
GLOBAL_STATE.yolo_threshold = None
|
| 664 |
+
GLOBAL_STATE.yolo_baseline_speed = None
|
| 665 |
+
GLOBAL_STATE.yolo_speed_std = None
|
| 666 |
+
GLOBAL_STATE.yolo_kick_frame = None
|
| 667 |
+
GLOBAL_STATE.yolo_status = ""
|
| 668 |
+
GLOBAL_STATE.yolo_kick_frames = []
|
| 669 |
+
GLOBAL_STATE.yolo_kick_speeds = []
|
| 670 |
+
GLOBAL_STATE.yolo_kick_distance = []
|
| 671 |
+
GLOBAL_STATE.yolo_mask_area_proxy = []
|
| 672 |
+
GLOBAL_STATE.yolo_initial_frame = None
|
| 673 |
+
GLOBAL_STATE.sam_window = None
|
| 674 |
GLOBAL_STATE.player_obj_id = None
|
| 675 |
GLOBAL_STATE.player_detection_frame = None
|
| 676 |
GLOBAL_STATE.player_detection_conf = None
|
| 677 |
+
GLOBAL_STATE.yolo_ball_centers = {}
|
| 678 |
+
GLOBAL_STATE.yolo_ball_boxes = {}
|
| 679 |
+
GLOBAL_STATE.yolo_ball_conf = {}
|
| 680 |
+
GLOBAL_STATE.yolo_smoothed_centers = {}
|
| 681 |
+
GLOBAL_STATE.yolo_speeds = {}
|
| 682 |
+
GLOBAL_STATE.yolo_distance_from_start = {}
|
| 683 |
+
GLOBAL_STATE.yolo_threshold = None
|
| 684 |
+
GLOBAL_STATE.yolo_baseline_speed = None
|
| 685 |
+
GLOBAL_STATE.yolo_speed_std = None
|
| 686 |
+
GLOBAL_STATE.yolo_kick_frame = None
|
| 687 |
+
GLOBAL_STATE.yolo_status = ""
|
| 688 |
+
GLOBAL_STATE.yolo_kick_frames = []
|
| 689 |
+
GLOBAL_STATE.yolo_kick_speeds = []
|
| 690 |
+
GLOBAL_STATE.yolo_kick_distance = []
|
| 691 |
+
GLOBAL_STATE.yolo_mask_area_proxy = []
|
| 692 |
+
GLOBAL_STATE.yolo_initial_frame = None
|
| 693 |
+
GLOBAL_STATE.sam_window = None
|
| 694 |
|
| 695 |
load_model_if_needed(GLOBAL_STATE)
|
| 696 |
|
|
|
|
| 1242 |
return fig
|
| 1243 |
|
| 1244 |
|
| 1245 |
+
def _ensure_ball_prompt_from_yolo(state: AppState):
    """Seed a positive ball click from the YOLO track if no ball click exists.

    Does nothing when there is no state, no inference session, no YOLO ball
    centers, or when any frame already holds clicks for ``BALL_OBJECT_ID``.
    Otherwise a synthetic click event is built at the YOLO ball center of the
    anchor frame and forwarded to ``on_image_click``.

    Side effects: overwrites ``state.current_obj_id``, ``state.current_label``
    and ``state.current_frame_idx`` before dispatching the click.
    """
    if state is None or state.inference_session is None:
        return
    centers = state.yolo_ball_centers
    if not centers:
        return

    # An existing ball prompt on any frame takes precedence over YOLO.
    if any(per_obj.get(BALL_OBJECT_ID) for per_obj in state.clicks_by_frame_obj.values()):
        return

    # Prefer the recorded initial frame; otherwise use the earliest frame
    # for which YOLO produced a center.
    seed_frame = state.yolo_initial_frame
    if seed_frame is None:
        seed_frame = min(centers)
    if seed_frame >= state.num_frames:
        return

    seed_point = centers.get(seed_frame)
    if seed_point is None:
        return

    raw_x, raw_y = seed_point
    # ``.size`` is unpacked as (width, height) of the anchor frame.
    width, height = state.video_frames[seed_frame].size
    # Round to whole pixels and keep the click inside the frame bounds.
    click_x = int(np.clip(round(raw_x), 0, width - 1))
    click_y = int(np.clip(round(raw_y), 0, height - 1))

    synthetic_event = SimpleNamespace(
        index=(click_x, click_y),
        value={"x": click_x, "y": click_y},
    )

    state.current_obj_id = BALL_OBJECT_ID
    state.current_label = "positive"
    state.current_frame_idx = seed_frame

    # NOTE(review): positional arguments presumably map to
    # (image, state, frame_idx, obj_id, label, clear_old, event) -- confirm
    # against on_image_click's signature.
    rendered = update_frame_display(state, seed_frame)
    on_image_click(
        rendered,
        state,
        seed_frame,
        BALL_OBJECT_ID,
        "positive",
        False,
        synthetic_event,
    )
|
| 1284 |
+
|
| 1285 |
+
|
| 1286 |
+
def _build_yolo_plot(state: AppState):
    """Build the "YOLO kick diagnostics" Plotly figure for ``state``.

    Returns an empty, titled figure when ``state`` is ``None`` or when either
    ``yolo_kick_frames`` or ``yolo_kick_speeds`` is empty. Otherwise the
    figure holds five traces over the kick frames: speed, adaptive threshold
    and baseline speed on the primary axis, and distance-from-start and box
    area proxy on a secondary right-hand axis. A vertical marker is added at
    ``yolo_kick_frame`` when it is set.
    """
    figure = go.Figure()

    has_series = state is not None and state.yolo_kick_frames and state.yolo_kick_speeds
    if not has_series:
        figure.update_layout(
            title="YOLO kick diagnostics",
            xaxis_title="Frame",
            yaxis_title="Speed (px/s)",
        )
        return figure

    frame_axis = state.yolo_kick_frames
    n_points = len(frame_axis)

    # Missing optional series are drawn as flat zero lines so the trace set
    # (and therefore the legend) is the same on every render.
    distance_series = state.yolo_kick_distance or [0.0] * n_points
    area_series = state.yolo_mask_area_proxy or [0.0] * n_points
    # Unset (None) scalar levels fall back to 0.0.
    threshold_level = state.yolo_threshold or 0.0
    baseline_level = state.yolo_baseline_speed or 0.0
    kick_at = state.yolo_kick_frame

    trace_specs = [
        {
            "y": state.yolo_kick_speeds,
            "mode": "lines+markers",
            "name": "YOLO speed",
            "line": {"color": "#4caf50"},
        },
        {
            "y": [threshold_level] * n_points,
            "mode": "lines",
            "name": "Adaptive threshold",
            "line": {"color": "#ff9800", "dash": "dash"},
        },
        {
            "y": [baseline_level] * n_points,
            "mode": "lines",
            "name": "Baseline speed",
            "line": {"color": "#9e9e9e", "dash": "dot"},
        },
        {
            "y": distance_series,
            "mode": "lines",
            "name": "Distance from start",
            "line": {"color": "#03a9f4"},
            "yaxis": "y2",
        },
        {
            "y": area_series,
            "mode": "lines",
            "name": "Box area proxy",
            "line": {"color": "#ab47bc", "dash": "dot"},
            "yaxis": "y2",
        },
    ]
    for spec in trace_specs:
        figure.add_trace(go.Scatter(x=frame_axis, **spec))

    if kick_at is not None:
        figure.add_vline(
            x=kick_at,
            line={"color": "#e91e63", "width": 2},
            annotation_text=f"Kick {kick_at}",
            annotation_position="top right",
        )

    figure.update_layout(
        title="YOLO kick diagnostics",
        xaxis={"title": "Frame"},
        yaxis={"title": "Speed (px/s)"},
        # Secondary axis drawn over the primary one, labelled on the right.
        yaxis2={
            "title": "Distance / Area",
            "overlaying": "y",
            "side": "right",
            "showgrid": False,
        },
        legend={"orientation": "h"},
        margin={"t": 40, "l": 40, "r": 40, "b": 40},
    )
    return figure
|
| 1374 |
+
|
| 1375 |
+
|
| 1376 |
def _format_impact_status(state: AppState) -> str:
|
| 1377 |
if state is None:
|
| 1378 |
return "Impact frame: not computed"
|
|
|
|
| 1434 |
|
| 1435 |
|
| 1436 |
def _button_updates(state: AppState) -> tuple[Any, Any, Any]:
|
| 1437 |
+
yolo_ready = isinstance(state, AppState) and state.yolo_kick_frame is not None
|
| 1438 |
+
propagate_main_enabled = _ball_has_masks(state) or yolo_ready
|
| 1439 |
+
detect_player_enabled = yolo_ready
|
| 1440 |
+
propagate_player_enabled = _player_has_masks(state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1441 |
return (
|
| 1442 |
gr.update(interactive=propagate_main_enabled),
|
| 1443 |
gr.update(interactive=detect_player_enabled),
|
|
|
|
| 1879 |
"Load a video first.",
|
| 1880 |
gr.update(),
|
| 1881 |
_build_kick_plot(GLOBAL_STATE),
|
| 1882 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 1883 |
_format_impact_status(GLOBAL_STATE),
|
| 1884 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1885 |
propagate_main_update,
|
|
|
|
| 1887 |
propagate_player_update,
|
| 1888 |
)
|
| 1889 |
|
| 1890 |
+
_ensure_ball_prompt_from_yolo(GLOBAL_STATE)
|
| 1891 |
+
|
| 1892 |
processor = deepcopy(GLOBAL_STATE.processor)
|
| 1893 |
model = deepcopy(GLOBAL_STATE.model)
|
| 1894 |
inference_session = deepcopy(GLOBAL_STATE.inference_session)
|
|
|
|
| 1897 |
inference_session.cache.inference_device = "cuda"
|
| 1898 |
model.to("cuda")
|
| 1899 |
|
| 1900 |
+
if not GLOBAL_STATE.sam_window:
|
| 1901 |
+
_compute_sam_window_from_kick(
|
| 1902 |
+
GLOBAL_STATE,
|
| 1903 |
+
GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None),
|
| 1904 |
+
)
|
| 1905 |
+
start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames)
|
| 1906 |
+
start_idx = max(0, int(start_idx))
|
| 1907 |
+
end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx)))
|
| 1908 |
+
total = max(1, end_idx - start_idx)
|
| 1909 |
processed = 0
|
| 1910 |
|
| 1911 |
+
_ensure_ball_prompt_from_yolo(GLOBAL_STATE)
|
| 1912 |
+
|
| 1913 |
# Initial status; no slider change yet
|
| 1914 |
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 1915 |
yield (
|
|
|
|
| 1917 |
f"Propagating masks: {processed}/{total}",
|
| 1918 |
gr.update(),
|
| 1919 |
_build_kick_plot(GLOBAL_STATE),
|
| 1920 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 1921 |
_format_impact_status(GLOBAL_STATE),
|
| 1922 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1923 |
propagate_main_update,
|
|
|
|
| 1925 |
propagate_player_update,
|
| 1926 |
)
|
| 1927 |
|
| 1928 |
+
last_frame_idx = start_idx
|
| 1929 |
with torch.inference_mode():
|
| 1930 |
+
for frame_idx in range(start_idx, end_idx):
|
| 1931 |
+
frame = GLOBAL_STATE.video_frames[frame_idx]
|
| 1932 |
pixel_values = None
|
| 1933 |
if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
|
| 1934 |
pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
|
|
|
|
| 1957 |
f"Propagating masks: {processed}/{total}",
|
| 1958 |
gr.update(value=frame_idx),
|
| 1959 |
_build_kick_plot(GLOBAL_STATE),
|
| 1960 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 1961 |
_format_impact_status(GLOBAL_STATE),
|
| 1962 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1963 |
propagate_main_update,
|
|
|
|
| 1981 |
text,
|
| 1982 |
gr.update(value=target_frame),
|
| 1983 |
_build_kick_plot(GLOBAL_STATE),
|
| 1984 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 1985 |
_format_impact_status(GLOBAL_STATE),
|
| 1986 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1987 |
propagate_main_update,
|
|
|
|
| 2074 |
status,
|
| 2075 |
gr.update(visible=False, value=""),
|
| 2076 |
_build_kick_plot(GLOBAL_STATE),
|
| 2077 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 2078 |
_format_impact_status(GLOBAL_STATE),
|
| 2079 |
propagate_main_update,
|
| 2080 |
detect_btn_update,
|
|
|
|
| 2322 |
"""
|
| 2323 |
**Working with results**
|
| 2324 |
- **Preview**: Use the slider to navigate frames and see the current masks.
|
| 2325 |
+
- **Track**: Click “Track ball (SAM2)” to track all defined objects across the selected window. The preview follows progress periodically to keep things responsive.
|
| 2326 |
- **Export**: Render an MP4 for smooth playback using the original video FPS.
|
| 2327 |
- **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video).
|
| 2328 |
"""
|
|
|
|
| 2380 |
)
|
| 2381 |
with gr.Row():
|
| 2382 |
detect_ball_btn = gr.Button("Detect Ball", variant="secondary")
|
| 2383 |
+
track_ball_yolo_btn = gr.Button("Track ball (YOLO13)", variant="secondary")
|
| 2384 |
+
propagate_btn = gr.Button("Track ball (SAM2)", variant="primary", interactive=False)
|
| 2385 |
detect_player_btn = gr.Button("Detect Player", variant="secondary", interactive=False)
|
| 2386 |
propagate_player_btn = gr.Button("Propagate Player", variant="primary", interactive=False)
|
| 2387 |
ball_status = gr.Markdown(visible=False)
|
|
|
|
| 2393 |
clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
|
| 2394 |
prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
|
| 2395 |
kick_plot = gr.Plot(label="Kick & impact diagnostics", show_label=True)
|
| 2396 |
+
yolo_plot = gr.Plot(label="YOLO kick diagnostics", show_label=True)
|
| 2397 |
|
| 2398 |
# Wire events
|
| 2399 |
def _on_video_change(GLOBAL_STATE: gr.State, video):
|
|
|
|
| 2406 |
status,
|
| 2407 |
gr.update(visible=False, value=""),
|
| 2408 |
_build_kick_plot(GLOBAL_STATE),
|
| 2409 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 2410 |
_format_impact_status(GLOBAL_STATE),
|
| 2411 |
propagate_main_update,
|
| 2412 |
detect_btn_update,
|
|
|
|
| 2416 |
video_in.change(
|
| 2417 |
_on_video_change,
|
| 2418 |
inputs=[GLOBAL_STATE, video_in],
|
| 2419 |
+
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2420 |
show_progress=True,
|
| 2421 |
)
|
| 2422 |
|
|
|
|
| 2429 |
examples=examples_list,
|
| 2430 |
inputs=[GLOBAL_STATE, video_in],
|
| 2431 |
fn=_on_video_change,
|
| 2432 |
+
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2433 |
label="Examples",
|
| 2434 |
cache_examples=False,
|
| 2435 |
examples_per_page=5,
|
|
|
|
| 2619 |
outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2620 |
)
|
| 2621 |
|
| 2622 |
+
def _track_ball_yolo(
    state_in: AppState,
    progress: gr.Progress = gr.Progress(track_tqdm=False),
):
    """Run YOLO ball tracking and refresh the preview, status and plots.

    Args:
        state_in: Application state for the loaded video.
        progress: Gradio progress tracker. It is declared as a default-valued
            parameter on purpose: Gradio only binds a tracker to the running
            event when the handler exposes it this way. A ``gr.Progress``
            created inside the body is never bound, so its updates would not
            reach the UI. Callers do not list it in ``inputs``.

    Returns:
        An 8-tuple in the order of the click handler's ``outputs``: preview
        image, ball-status update, frame-slider update, kick plot, YOLO plot,
        and the three button updates from ``_button_updates``.

    Raises:
        gr.Error: If no state is available or the video has zero frames.
    """
    if state_in is None or state_in.num_frames == 0:
        raise gr.Error("Load a video first, then track the ball with YOLO.")

    _perform_yolo_ball_tracking(state_in, progress=progress)

    # Jump to the detected kick if there is one, otherwise to the first frame
    # where YOLO saw the ball, otherwise to the start of the video.
    if state_in.yolo_kick_frame is not None:
        target_frame = state_in.yolo_kick_frame
    elif state_in.yolo_initial_frame is not None:
        target_frame = state_in.yolo_initial_frame
    else:
        target_frame = 0

    if state_in.num_frames:
        # Keep the slider/preview index inside the valid frame range.
        target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
    state_in.current_frame_idx = target_frame

    preview_img = update_frame_display(state_in, target_frame)

    # Prefix the kick summary with the YOLO status line when one was recorded.
    base_msg = state_in.yolo_status or ""
    kick_msg = _format_kick_status(state_in)
    status_text = f"{base_msg} | {kick_msg}" if base_msg else kick_msg

    propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
    return (
        preview_img,
        gr.update(value=status_text, visible=True),
        gr.update(value=target_frame),
        _build_kick_plot(state_in),
        _build_yolo_plot(state_in),
        propagate_main_update,
        detect_btn_update,
        propagate_player_update,
    )
|
| 2652 |
+
|
| 2653 |
+
track_ball_yolo_btn.click(
|
| 2654 |
+
_track_ball_yolo,
|
| 2655 |
+
inputs=[GLOBAL_STATE],
|
| 2656 |
+
outputs=[preview, ball_status, frame_slider, kick_plot, yolo_plot, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2657 |
+
)
|
| 2658 |
+
|
| 2659 |
def _auto_detect_player(state_in: AppState):
|
| 2660 |
if state_in is None or state_in.num_frames == 0:
|
| 2661 |
raise gr.Error("Load a video first, then try auto-detect.")
|
|
|
|
| 2772 |
"Load a video first.",
|
| 2773 |
gr.update(),
|
| 2774 |
_build_kick_plot(GLOBAL_STATE),
|
| 2775 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 2776 |
_format_impact_status(GLOBAL_STATE),
|
| 2777 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2778 |
propagate_main_update,
|
|
|
|
| 2786 |
"Detect the player before propagating.",
|
| 2787 |
gr.update(),
|
| 2788 |
_build_kick_plot(GLOBAL_STATE),
|
| 2789 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 2790 |
_format_impact_status(GLOBAL_STATE),
|
| 2791 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2792 |
propagate_main_update,
|
|
|
|
| 2801 |
inference_session.cache.inference_device = "cuda"
|
| 2802 |
model.to("cuda")
|
| 2803 |
|
| 2804 |
+
if not GLOBAL_STATE.sam_window:
|
| 2805 |
+
_compute_sam_window_from_kick(
|
| 2806 |
+
GLOBAL_STATE,
|
| 2807 |
+
GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None),
|
| 2808 |
+
)
|
| 2809 |
+
start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames)
|
| 2810 |
+
start_idx = max(0, int(start_idx))
|
| 2811 |
+
end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx)))
|
| 2812 |
+
total = max(1, end_idx - start_idx)
|
| 2813 |
processed = 0
|
| 2814 |
+
last_frame_idx = start_idx
|
| 2815 |
|
| 2816 |
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 2817 |
yield (
|
|
|
|
| 2819 |
f"Propagating player: {processed}/{total}",
|
| 2820 |
gr.update(),
|
| 2821 |
_build_kick_plot(GLOBAL_STATE),
|
| 2822 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 2823 |
_format_impact_status(GLOBAL_STATE),
|
| 2824 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2825 |
propagate_main_update,
|
|
|
|
| 2830 |
player_id = GLOBAL_STATE.player_obj_id or PLAYER_OBJECT_ID
|
| 2831 |
|
| 2832 |
with torch.inference_mode():
|
| 2833 |
+
for frame_idx in range(start_idx, end_idx):
|
| 2834 |
+
frame = GLOBAL_STATE.video_frames[frame_idx]
|
| 2835 |
pixel_values = None
|
| 2836 |
if (
|
| 2837 |
inference_session.processed_frames is None
|
|
|
|
| 2857 |
GLOBAL_STATE.composited_frames.pop(frame_idx, None)
|
| 2858 |
|
| 2859 |
processed += 1
|
| 2860 |
+
last_frame_idx = frame_idx
|
| 2861 |
if processed % 30 == 0 or processed == total:
|
| 2862 |
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 2863 |
yield (
|
|
|
|
| 2865 |
f"Propagating player: {processed}/{total}",
|
| 2866 |
gr.update(value=frame_idx),
|
| 2867 |
_build_kick_plot(GLOBAL_STATE),
|
| 2868 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 2869 |
_format_impact_status(GLOBAL_STATE),
|
| 2870 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2871 |
propagate_main_update,
|
|
|
|
| 2878 |
if target_frame is None:
|
| 2879 |
target_frame = GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None)
|
| 2880 |
if target_frame is None:
|
| 2881 |
+
target_frame = last_frame_idx
|
| 2882 |
target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
|
| 2883 |
GLOBAL_STATE.current_frame_idx = target_frame
|
| 2884 |
|
|
|
|
| 2888 |
text,
|
| 2889 |
gr.update(value=target_frame),
|
| 2890 |
_build_kick_plot(GLOBAL_STATE),
|
| 2891 |
+
_build_yolo_plot(GLOBAL_STATE),
|
| 2892 |
_format_impact_status(GLOBAL_STATE),
|
| 2893 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2894 |
propagate_main_update,
|
|
|
|
| 2899 |
propagate_player_btn.click(
|
| 2900 |
propagate_player_masks,
|
| 2901 |
inputs=[GLOBAL_STATE],
|
| 2902 |
+
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, yolo_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2903 |
)
|
| 2904 |
|
| 2905 |
# Image click to add a point and run forward on that frame
|
|
|
|
| 2968 |
propagate_btn.click(
|
| 2969 |
propagate_masks,
|
| 2970 |
inputs=[GLOBAL_STATE],
|
| 2971 |
+
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, yolo_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2972 |
)
|
| 2973 |
|
| 2974 |
reset_btn.click(
|
| 2975 |
reset_session,
|
| 2976 |
inputs=GLOBAL_STATE,
|
| 2977 |
+
outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2978 |
)
|
| 2979 |
|
| 2980 |
# ============================================================================
|