Mirko Trasciatti
commited on
Commit
·
2feeac4
1
Parent(s):
707381b
Disable propagate buttons until detections
Browse files
app.py
CHANGED
|
@@ -47,6 +47,7 @@ YOLO_CONF_THRESHOLD = 0.0
|
|
| 47 |
YOLO_IOU_THRESHOLD = 0.02
|
| 48 |
PLAYER_TARGET_NAME = "person"
|
| 49 |
PLAYER_OBJECT_ID = 2
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
def get_yolo_model(model_filename: str = YOLO_DEFAULT_MODEL) -> YOLO:
|
|
@@ -997,6 +998,15 @@ def _format_kick_status(state: AppState) -> str:
|
|
| 997 |
return f"Kick frame ≈ {frame}{time_part}"
|
| 998 |
|
| 999 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1000 |
def _player_has_masks(state: AppState) -> bool:
|
| 1001 |
if state is None or state.player_obj_id is None:
|
| 1002 |
return False
|
|
@@ -1007,17 +1017,20 @@ def _player_has_masks(state: AppState) -> bool:
|
|
| 1007 |
return False
|
| 1008 |
|
| 1009 |
|
| 1010 |
-
def
|
| 1011 |
-
|
| 1012 |
-
|
|
|
|
| 1013 |
if isinstance(state, AppState):
|
| 1014 |
kick_candidate = state.kick_frame or getattr(state, "kick_debug_kick_frame", None)
|
| 1015 |
if kick_candidate is not None:
|
| 1016 |
-
|
| 1017 |
-
|
|
|
|
| 1018 |
return (
|
| 1019 |
-
gr.update(interactive=
|
| 1020 |
-
gr.update(interactive=
|
|
|
|
| 1021 |
)
|
| 1022 |
|
| 1023 |
|
|
@@ -1431,11 +1444,25 @@ def on_image_click(
|
|
| 1431 |
return update_frame_display(state, int(frame_idx))
|
| 1432 |
|
| 1433 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1434 |
@spaces.GPU()
|
| 1435 |
def propagate_masks(GLOBAL_STATE: gr.State):
|
| 1436 |
if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
|
| 1437 |
# yield GLOBAL_STATE, "Load a video first.", gr.update()
|
| 1438 |
-
detect_btn_update, propagate_player_update =
|
| 1439 |
return (
|
| 1440 |
GLOBAL_STATE,
|
| 1441 |
"Load a video first.",
|
|
@@ -1443,6 +1470,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1443 |
_build_kick_plot(GLOBAL_STATE),
|
| 1444 |
_format_impact_status(GLOBAL_STATE),
|
| 1445 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
|
|
|
| 1446 |
detect_btn_update,
|
| 1447 |
propagate_player_update,
|
| 1448 |
)
|
|
@@ -1459,7 +1487,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1459 |
processed = 0
|
| 1460 |
|
| 1461 |
# Initial status; no slider change yet
|
| 1462 |
-
detect_btn_update, propagate_player_update =
|
| 1463 |
yield (
|
| 1464 |
GLOBAL_STATE,
|
| 1465 |
f"Propagating masks: {processed}/{total}",
|
|
@@ -1467,6 +1495,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1467 |
_build_kick_plot(GLOBAL_STATE),
|
| 1468 |
_format_impact_status(GLOBAL_STATE),
|
| 1469 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
|
|
|
| 1470 |
detect_btn_update,
|
| 1471 |
propagate_player_update,
|
| 1472 |
)
|
|
@@ -1496,7 +1525,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1496 |
processed += 1
|
| 1497 |
# Every 15th frame (or last), move slider to current frame to update preview via slider binding
|
| 1498 |
if processed % 30 == 0 or processed == total:
|
| 1499 |
-
detect_btn_update, propagate_player_update =
|
| 1500 |
yield (
|
| 1501 |
GLOBAL_STATE,
|
| 1502 |
f"Propagating masks: {processed}/{total}",
|
|
@@ -1504,6 +1533,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1504 |
_build_kick_plot(GLOBAL_STATE),
|
| 1505 |
_format_impact_status(GLOBAL_STATE),
|
| 1506 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
|
|
|
| 1507 |
detect_btn_update,
|
| 1508 |
propagate_player_update,
|
| 1509 |
)
|
|
@@ -1518,7 +1548,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1518 |
GLOBAL_STATE.current_frame_idx = target_frame
|
| 1519 |
|
| 1520 |
# Final status; ensure slider points to the target frame (kick frame when detected)
|
| 1521 |
-
detect_btn_update, propagate_player_update =
|
| 1522 |
yield (
|
| 1523 |
GLOBAL_STATE,
|
| 1524 |
text,
|
|
@@ -1526,16 +1556,17 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 1526 |
_build_kick_plot(GLOBAL_STATE),
|
| 1527 |
_format_impact_status(GLOBAL_STATE),
|
| 1528 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
|
|
|
| 1529 |
detect_btn_update,
|
| 1530 |
propagate_player_update,
|
| 1531 |
)
|
| 1532 |
|
| 1533 |
|
| 1534 |
-
def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str, any, go.Figure, Any, Any]:
|
| 1535 |
# Reset only session-related state, keep uploaded video and model
|
| 1536 |
if not GLOBAL_STATE.video_frames:
|
| 1537 |
# Nothing loaded; keep behavior
|
| 1538 |
-
detect_btn_update, propagate_player_update =
|
| 1539 |
return (
|
| 1540 |
GLOBAL_STATE,
|
| 1541 |
None,
|
|
@@ -1545,6 +1576,7 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
|
|
| 1545 |
gr.update(visible=False, value=""),
|
| 1546 |
_build_kick_plot(GLOBAL_STATE),
|
| 1547 |
_format_impact_status(GLOBAL_STATE),
|
|
|
|
| 1548 |
detect_btn_update,
|
| 1549 |
propagate_player_update,
|
| 1550 |
)
|
|
@@ -1604,7 +1636,7 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
|
|
| 1604 |
slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
|
| 1605 |
slider_value = gr.update(value=current_idx)
|
| 1606 |
status = "Session reset. Prompts cleared; video preserved."
|
| 1607 |
-
detect_btn_update, propagate_player_update =
|
| 1608 |
# clear and reload model and processor
|
| 1609 |
return (
|
| 1610 |
GLOBAL_STATE,
|
|
@@ -1615,6 +1647,7 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
|
|
| 1615 |
gr.update(visible=False, value=""),
|
| 1616 |
_build_kick_plot(GLOBAL_STATE),
|
| 1617 |
_format_impact_status(GLOBAL_STATE),
|
|
|
|
| 1618 |
detect_btn_update,
|
| 1619 |
propagate_player_update,
|
| 1620 |
)
|
|
@@ -1918,7 +1951,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1918 |
)
|
| 1919 |
with gr.Row():
|
| 1920 |
detect_ball_btn = gr.Button("Detect Ball", variant="secondary")
|
| 1921 |
-
propagate_btn = gr.Button("Propagate across video", variant="primary")
|
| 1922 |
detect_player_btn = gr.Button("Detect Player", variant="secondary", interactive=False)
|
| 1923 |
propagate_player_btn = gr.Button("Propagate Player", variant="primary", interactive=False)
|
| 1924 |
ball_status = gr.Markdown(visible=False)
|
|
@@ -1934,7 +1967,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1934 |
# Wire events
|
| 1935 |
def _on_video_change(GLOBAL_STATE: gr.State, video):
|
| 1936 |
GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video)
|
| 1937 |
-
detect_btn_update, propagate_player_update =
|
| 1938 |
return (
|
| 1939 |
GLOBAL_STATE,
|
| 1940 |
gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
|
|
@@ -1943,6 +1976,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1943 |
gr.update(visible=False, value=""),
|
| 1944 |
_build_kick_plot(GLOBAL_STATE),
|
| 1945 |
_format_impact_status(GLOBAL_STATE),
|
|
|
|
| 1946 |
detect_btn_update,
|
| 1947 |
propagate_player_update,
|
| 1948 |
)
|
|
@@ -1950,7 +1984,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1950 |
video_in.change(
|
| 1951 |
_on_video_change,
|
| 1952 |
inputs=[GLOBAL_STATE, video_in],
|
| 1953 |
-
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, detect_player_btn, propagate_player_btn],
|
| 1954 |
show_progress=True,
|
| 1955 |
)
|
| 1956 |
|
|
@@ -1963,7 +1997,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1963 |
examples=examples_list,
|
| 1964 |
inputs=[GLOBAL_STATE, video_in],
|
| 1965 |
fn=_on_video_change,
|
| 1966 |
-
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, detect_player_btn, propagate_player_btn],
|
| 1967 |
label="Examples",
|
| 1968 |
cache_examples=False,
|
| 1969 |
examples_per_page=5,
|
|
@@ -2046,11 +2080,12 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2046 |
if s is not None and val is not None:
|
| 2047 |
s.min_impact_speed_kmh = float(val)
|
| 2048 |
_recompute_motion_metrics(s)
|
| 2049 |
-
detect_btn_update, propagate_player_update =
|
| 2050 |
return (
|
| 2051 |
_build_kick_plot(s),
|
| 2052 |
_format_impact_status(s),
|
| 2053 |
gr.update(value=_format_kick_status(s), visible=True),
|
|
|
|
| 2054 |
detect_btn_update,
|
| 2055 |
propagate_player_update,
|
| 2056 |
)
|
|
@@ -2059,11 +2094,12 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2059 |
if s is not None and val is not None:
|
| 2060 |
s.goal_distance_m = float(val)
|
| 2061 |
_recompute_motion_metrics(s)
|
| 2062 |
-
detect_btn_update, propagate_player_update =
|
| 2063 |
return (
|
| 2064 |
_build_kick_plot(s),
|
| 2065 |
_format_impact_status(s),
|
| 2066 |
gr.update(value=_format_kick_status(s), visible=True),
|
|
|
|
| 2067 |
detect_btn_update,
|
| 2068 |
propagate_player_update,
|
| 2069 |
)
|
|
@@ -2071,13 +2107,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2071 |
min_impact_speed_slider.change(
|
| 2072 |
_update_min_impact_speed,
|
| 2073 |
inputs=[GLOBAL_STATE, min_impact_speed_slider],
|
| 2074 |
-
outputs=[kick_plot, impact_status, ball_status, detect_player_btn, propagate_player_btn],
|
| 2075 |
)
|
| 2076 |
|
| 2077 |
goal_distance_slider.change(
|
| 2078 |
_update_goal_distance,
|
| 2079 |
inputs=[GLOBAL_STATE, goal_distance_slider],
|
| 2080 |
-
outputs=[kick_plot, impact_status, ball_status, detect_player_btn, propagate_player_btn],
|
| 2081 |
)
|
| 2082 |
|
| 2083 |
def _auto_detect_ball(
|
|
@@ -2093,7 +2129,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2093 |
frame = state_in.video_frames[frame_idx]
|
| 2094 |
detection = detect_ball_center(frame)
|
| 2095 |
if detection is None:
|
| 2096 |
-
detect_btn_update, propagate_player_update =
|
| 2097 |
return (
|
| 2098 |
update_frame_display(state_in, frame_idx),
|
| 2099 |
gr.update(
|
|
@@ -2102,6 +2138,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2102 |
),
|
| 2103 |
gr.update(value=frame_idx),
|
| 2104 |
_build_kick_plot(state_in),
|
|
|
|
| 2105 |
detect_btn_update,
|
| 2106 |
propagate_player_update,
|
| 2107 |
)
|
|
@@ -2133,12 +2170,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2133 |
|
| 2134 |
status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
|
| 2135 |
status_text += f" | {_format_kick_status(state_in)}"
|
| 2136 |
-
detect_btn_update, propagate_player_update =
|
| 2137 |
return (
|
| 2138 |
preview_img,
|
| 2139 |
gr.update(value=status_text, visible=True),
|
| 2140 |
gr.update(value=frame_idx),
|
| 2141 |
_build_kick_plot(state_in),
|
|
|
|
| 2142 |
detect_btn_update,
|
| 2143 |
propagate_player_update,
|
| 2144 |
)
|
|
@@ -2146,7 +2184,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2146 |
detect_ball_btn.click(
|
| 2147 |
_auto_detect_ball,
|
| 2148 |
inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk],
|
| 2149 |
-
outputs=[preview, ball_status, frame_slider, kick_plot, detect_player_btn, propagate_player_btn],
|
| 2150 |
)
|
| 2151 |
|
| 2152 |
def _auto_detect_player(state_in: AppState):
|
|
@@ -2163,7 +2201,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2163 |
frame = state_in.video_frames[frame_idx]
|
| 2164 |
detection = detect_person_box(frame)
|
| 2165 |
if detection is None:
|
| 2166 |
-
detect_btn_update, propagate_player_update =
|
| 2167 |
status_text = (
|
| 2168 |
f"{_format_kick_status(state_in)} | ⚠️ Unable to auto-detect the player on frame {frame_idx}. "
|
| 2169 |
"Please add a box manually."
|
|
@@ -2173,6 +2211,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2173 |
gr.update(value=status_text, visible=True),
|
| 2174 |
gr.update(value=frame_idx),
|
| 2175 |
_build_kick_plot(state_in),
|
|
|
|
| 2176 |
detect_btn_update,
|
| 2177 |
propagate_player_update,
|
| 2178 |
gr.update(),
|
|
@@ -2234,7 +2273,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2234 |
state_in.composited_frames.pop(frame_idx, None)
|
| 2235 |
state_in.current_frame_idx = frame_idx
|
| 2236 |
|
| 2237 |
-
detect_btn_update, propagate_player_update =
|
| 2238 |
status_text = (
|
| 2239 |
f"{_format_kick_status(state_in)} | ✅ Player auto-detected on frame {frame_idx} (conf={conf:.2f})"
|
| 2240 |
)
|
|
@@ -2243,6 +2282,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2243 |
gr.update(value=status_text, visible=True),
|
| 2244 |
gr.update(value=frame_idx),
|
| 2245 |
_build_kick_plot(state_in),
|
|
|
|
| 2246 |
detect_btn_update,
|
| 2247 |
propagate_player_update,
|
| 2248 |
gr.update(value=PLAYER_OBJECT_ID),
|
|
@@ -2251,13 +2291,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2251 |
detect_player_btn.click(
|
| 2252 |
_auto_detect_player,
|
| 2253 |
inputs=[GLOBAL_STATE],
|
| 2254 |
-
outputs=[preview, ball_status, frame_slider, kick_plot, detect_player_btn, propagate_player_btn, obj_id_inp],
|
| 2255 |
)
|
| 2256 |
|
| 2257 |
@spaces.GPU()
|
| 2258 |
def propagate_player_masks(GLOBAL_STATE: gr.State):
|
| 2259 |
if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
|
| 2260 |
-
detect_btn_update, propagate_player_update =
|
| 2261 |
return (
|
| 2262 |
GLOBAL_STATE,
|
| 2263 |
"Load a video first.",
|
|
@@ -2265,11 +2305,12 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2265 |
_build_kick_plot(GLOBAL_STATE),
|
| 2266 |
_format_impact_status(GLOBAL_STATE),
|
| 2267 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
|
|
|
| 2268 |
detect_btn_update,
|
| 2269 |
propagate_player_update,
|
| 2270 |
)
|
| 2271 |
if GLOBAL_STATE.player_obj_id is None or not _player_has_masks(GLOBAL_STATE):
|
| 2272 |
-
detect_btn_update, propagate_player_update =
|
| 2273 |
return (
|
| 2274 |
GLOBAL_STATE,
|
| 2275 |
"Detect the player before propagating.",
|
|
@@ -2277,6 +2318,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2277 |
_build_kick_plot(GLOBAL_STATE),
|
| 2278 |
_format_impact_status(GLOBAL_STATE),
|
| 2279 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
|
|
|
| 2280 |
detect_btn_update,
|
| 2281 |
propagate_player_update,
|
| 2282 |
)
|
|
@@ -2291,7 +2333,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2291 |
total = max(1, GLOBAL_STATE.num_frames)
|
| 2292 |
processed = 0
|
| 2293 |
|
| 2294 |
-
detect_btn_update, propagate_player_update =
|
| 2295 |
yield (
|
| 2296 |
GLOBAL_STATE,
|
| 2297 |
f"Propagating player: {processed}/{total}",
|
|
@@ -2299,6 +2341,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2299 |
_build_kick_plot(GLOBAL_STATE),
|
| 2300 |
_format_impact_status(GLOBAL_STATE),
|
| 2301 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
|
|
|
| 2302 |
detect_btn_update,
|
| 2303 |
propagate_player_update,
|
| 2304 |
)
|
|
@@ -2333,7 +2376,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2333 |
|
| 2334 |
processed += 1
|
| 2335 |
if processed % 30 == 0 or processed == total:
|
| 2336 |
-
detect_btn_update, propagate_player_update =
|
| 2337 |
yield (
|
| 2338 |
GLOBAL_STATE,
|
| 2339 |
f"Propagating player: {processed}/{total}",
|
|
@@ -2341,6 +2384,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2341 |
_build_kick_plot(GLOBAL_STATE),
|
| 2342 |
_format_impact_status(GLOBAL_STATE),
|
| 2343 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
|
|
|
| 2344 |
detect_btn_update,
|
| 2345 |
propagate_player_update,
|
| 2346 |
)
|
|
@@ -2354,7 +2398,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2354 |
target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
|
| 2355 |
GLOBAL_STATE.current_frame_idx = target_frame
|
| 2356 |
|
| 2357 |
-
detect_btn_update, propagate_player_update =
|
| 2358 |
yield (
|
| 2359 |
GLOBAL_STATE,
|
| 2360 |
text,
|
|
@@ -2362,6 +2406,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2362 |
_build_kick_plot(GLOBAL_STATE),
|
| 2363 |
_format_impact_status(GLOBAL_STATE),
|
| 2364 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
|
|
|
| 2365 |
detect_btn_update,
|
| 2366 |
propagate_player_update,
|
| 2367 |
)
|
|
@@ -2369,12 +2414,14 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2369 |
propagate_player_btn.click(
|
| 2370 |
propagate_player_masks,
|
| 2371 |
inputs=[GLOBAL_STATE],
|
| 2372 |
-
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, detect_player_btn, propagate_player_btn],
|
| 2373 |
)
|
| 2374 |
|
| 2375 |
# Image click to add a point and run forward on that frame
|
| 2376 |
preview.select(
|
| 2377 |
-
|
|
|
|
|
|
|
| 2378 |
)
|
| 2379 |
|
| 2380 |
# Playback via MP4 rendering only
|
|
@@ -2436,13 +2483,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 2436 |
propagate_btn.click(
|
| 2437 |
propagate_masks,
|
| 2438 |
inputs=[GLOBAL_STATE],
|
| 2439 |
-
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, detect_player_btn, propagate_player_btn],
|
| 2440 |
)
|
| 2441 |
|
| 2442 |
reset_btn.click(
|
| 2443 |
reset_session,
|
| 2444 |
inputs=GLOBAL_STATE,
|
| 2445 |
-
outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot, impact_status, detect_player_btn, propagate_player_btn],
|
| 2446 |
)
|
| 2447 |
|
| 2448 |
# ============================================================================
|
|
|
|
| 47 |
YOLO_IOU_THRESHOLD = 0.02
|
| 48 |
PLAYER_TARGET_NAME = "person"
|
| 49 |
PLAYER_OBJECT_ID = 2
|
| 50 |
+
BALL_OBJECT_ID = 1
|
| 51 |
|
| 52 |
|
| 53 |
def get_yolo_model(model_filename: str = YOLO_DEFAULT_MODEL) -> YOLO:
|
|
|
|
| 998 |
return f"Kick frame ≈ {frame}{time_part}"
|
| 999 |
|
| 1000 |
|
| 1001 |
+
def _ball_has_masks(state: AppState, target_obj_id: int = BALL_OBJECT_ID) -> bool:
|
| 1002 |
+
if state is None:
|
| 1003 |
+
return False
|
| 1004 |
+
for masks in state.masks_by_frame.values():
|
| 1005 |
+
if target_obj_id in masks:
|
| 1006 |
+
return True
|
| 1007 |
+
return False
|
| 1008 |
+
|
| 1009 |
+
|
| 1010 |
def _player_has_masks(state: AppState) -> bool:
|
| 1011 |
if state is None or state.player_obj_id is None:
|
| 1012 |
return False
|
|
|
|
| 1017 |
return False
|
| 1018 |
|
| 1019 |
|
| 1020 |
+
def _button_updates(state: AppState) -> tuple[Any, Any, Any]:
|
| 1021 |
+
propagate_main_enabled = _ball_has_masks(state)
|
| 1022 |
+
detect_player_enabled = False
|
| 1023 |
+
propagate_player_enabled = False
|
| 1024 |
if isinstance(state, AppState):
|
| 1025 |
kick_candidate = state.kick_frame or getattr(state, "kick_debug_kick_frame", None)
|
| 1026 |
if kick_candidate is not None:
|
| 1027 |
+
detect_player_enabled = True
|
| 1028 |
+
if detect_player_enabled and _player_has_masks(state):
|
| 1029 |
+
propagate_player_enabled = True
|
| 1030 |
return (
|
| 1031 |
+
gr.update(interactive=propagate_main_enabled),
|
| 1032 |
+
gr.update(interactive=detect_player_enabled),
|
| 1033 |
+
gr.update(interactive=propagate_player_enabled),
|
| 1034 |
)
|
| 1035 |
|
| 1036 |
|
|
|
|
| 1444 |
return update_frame_display(state, int(frame_idx))
|
| 1445 |
|
| 1446 |
|
| 1447 |
+
def _on_image_click_with_updates(
|
| 1448 |
+
img: Image.Image | np.ndarray,
|
| 1449 |
+
state: AppState,
|
| 1450 |
+
frame_idx: int,
|
| 1451 |
+
obj_id: int,
|
| 1452 |
+
label: str,
|
| 1453 |
+
clear_old: bool,
|
| 1454 |
+
evt: gr.SelectData,
|
| 1455 |
+
):
|
| 1456 |
+
preview_img = on_image_click(img, state, frame_idx, obj_id, label, clear_old, evt)
|
| 1457 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state)
|
| 1458 |
+
return preview_img, propagate_main_update, detect_btn_update, propagate_player_update
|
| 1459 |
+
|
| 1460 |
+
|
| 1461 |
@spaces.GPU()
|
| 1462 |
def propagate_masks(GLOBAL_STATE: gr.State):
|
| 1463 |
if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
|
| 1464 |
# yield GLOBAL_STATE, "Load a video first.", gr.update()
|
| 1465 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 1466 |
return (
|
| 1467 |
GLOBAL_STATE,
|
| 1468 |
"Load a video first.",
|
|
|
|
| 1470 |
_build_kick_plot(GLOBAL_STATE),
|
| 1471 |
_format_impact_status(GLOBAL_STATE),
|
| 1472 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1473 |
+
propagate_main_update,
|
| 1474 |
detect_btn_update,
|
| 1475 |
propagate_player_update,
|
| 1476 |
)
|
|
|
|
| 1487 |
processed = 0
|
| 1488 |
|
| 1489 |
# Initial status; no slider change yet
|
| 1490 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 1491 |
yield (
|
| 1492 |
GLOBAL_STATE,
|
| 1493 |
f"Propagating masks: {processed}/{total}",
|
|
|
|
| 1495 |
_build_kick_plot(GLOBAL_STATE),
|
| 1496 |
_format_impact_status(GLOBAL_STATE),
|
| 1497 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1498 |
+
propagate_main_update,
|
| 1499 |
detect_btn_update,
|
| 1500 |
propagate_player_update,
|
| 1501 |
)
|
|
|
|
| 1525 |
processed += 1
|
| 1526 |
# Every 15th frame (or last), move slider to current frame to update preview via slider binding
|
| 1527 |
if processed % 30 == 0 or processed == total:
|
| 1528 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 1529 |
yield (
|
| 1530 |
GLOBAL_STATE,
|
| 1531 |
f"Propagating masks: {processed}/{total}",
|
|
|
|
| 1533 |
_build_kick_plot(GLOBAL_STATE),
|
| 1534 |
_format_impact_status(GLOBAL_STATE),
|
| 1535 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1536 |
+
propagate_main_update,
|
| 1537 |
detect_btn_update,
|
| 1538 |
propagate_player_update,
|
| 1539 |
)
|
|
|
|
| 1548 |
GLOBAL_STATE.current_frame_idx = target_frame
|
| 1549 |
|
| 1550 |
# Final status; ensure slider points to the target frame (kick frame when detected)
|
| 1551 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 1552 |
yield (
|
| 1553 |
GLOBAL_STATE,
|
| 1554 |
text,
|
|
|
|
| 1556 |
_build_kick_plot(GLOBAL_STATE),
|
| 1557 |
_format_impact_status(GLOBAL_STATE),
|
| 1558 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 1559 |
+
propagate_main_update,
|
| 1560 |
detect_btn_update,
|
| 1561 |
propagate_player_update,
|
| 1562 |
)
|
| 1563 |
|
| 1564 |
|
| 1565 |
+
def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str, any, go.Figure, Any, Any, Any]:
|
| 1566 |
# Reset only session-related state, keep uploaded video and model
|
| 1567 |
if not GLOBAL_STATE.video_frames:
|
| 1568 |
# Nothing loaded; keep behavior
|
| 1569 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 1570 |
return (
|
| 1571 |
GLOBAL_STATE,
|
| 1572 |
None,
|
|
|
|
| 1576 |
gr.update(visible=False, value=""),
|
| 1577 |
_build_kick_plot(GLOBAL_STATE),
|
| 1578 |
_format_impact_status(GLOBAL_STATE),
|
| 1579 |
+
propagate_main_update,
|
| 1580 |
detect_btn_update,
|
| 1581 |
propagate_player_update,
|
| 1582 |
)
|
|
|
|
| 1636 |
slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
|
| 1637 |
slider_value = gr.update(value=current_idx)
|
| 1638 |
status = "Session reset. Prompts cleared; video preserved."
|
| 1639 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 1640 |
# clear and reload model and processor
|
| 1641 |
return (
|
| 1642 |
GLOBAL_STATE,
|
|
|
|
| 1647 |
gr.update(visible=False, value=""),
|
| 1648 |
_build_kick_plot(GLOBAL_STATE),
|
| 1649 |
_format_impact_status(GLOBAL_STATE),
|
| 1650 |
+
propagate_main_update,
|
| 1651 |
detect_btn_update,
|
| 1652 |
propagate_player_update,
|
| 1653 |
)
|
|
|
|
| 1951 |
)
|
| 1952 |
with gr.Row():
|
| 1953 |
detect_ball_btn = gr.Button("Detect Ball", variant="secondary")
|
| 1954 |
+
propagate_btn = gr.Button("Propagate across video", variant="primary", interactive=False)
|
| 1955 |
detect_player_btn = gr.Button("Detect Player", variant="secondary", interactive=False)
|
| 1956 |
propagate_player_btn = gr.Button("Propagate Player", variant="primary", interactive=False)
|
| 1957 |
ball_status = gr.Markdown(visible=False)
|
|
|
|
| 1967 |
# Wire events
|
| 1968 |
def _on_video_change(GLOBAL_STATE: gr.State, video):
|
| 1969 |
GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video)
|
| 1970 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 1971 |
return (
|
| 1972 |
GLOBAL_STATE,
|
| 1973 |
gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
|
|
|
|
| 1976 |
gr.update(visible=False, value=""),
|
| 1977 |
_build_kick_plot(GLOBAL_STATE),
|
| 1978 |
_format_impact_status(GLOBAL_STATE),
|
| 1979 |
+
propagate_main_update,
|
| 1980 |
detect_btn_update,
|
| 1981 |
propagate_player_update,
|
| 1982 |
)
|
|
|
|
| 1984 |
video_in.change(
|
| 1985 |
_on_video_change,
|
| 1986 |
inputs=[GLOBAL_STATE, video_in],
|
| 1987 |
+
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 1988 |
show_progress=True,
|
| 1989 |
)
|
| 1990 |
|
|
|
|
| 1997 |
examples=examples_list,
|
| 1998 |
inputs=[GLOBAL_STATE, video_in],
|
| 1999 |
fn=_on_video_change,
|
| 2000 |
+
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2001 |
label="Examples",
|
| 2002 |
cache_examples=False,
|
| 2003 |
examples_per_page=5,
|
|
|
|
| 2080 |
if s is not None and val is not None:
|
| 2081 |
s.min_impact_speed_kmh = float(val)
|
| 2082 |
_recompute_motion_metrics(s)
|
| 2083 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(s)
|
| 2084 |
return (
|
| 2085 |
_build_kick_plot(s),
|
| 2086 |
_format_impact_status(s),
|
| 2087 |
gr.update(value=_format_kick_status(s), visible=True),
|
| 2088 |
+
propagate_main_update,
|
| 2089 |
detect_btn_update,
|
| 2090 |
propagate_player_update,
|
| 2091 |
)
|
|
|
|
| 2094 |
if s is not None and val is not None:
|
| 2095 |
s.goal_distance_m = float(val)
|
| 2096 |
_recompute_motion_metrics(s)
|
| 2097 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(s)
|
| 2098 |
return (
|
| 2099 |
_build_kick_plot(s),
|
| 2100 |
_format_impact_status(s),
|
| 2101 |
gr.update(value=_format_kick_status(s), visible=True),
|
| 2102 |
+
propagate_main_update,
|
| 2103 |
detect_btn_update,
|
| 2104 |
propagate_player_update,
|
| 2105 |
)
|
|
|
|
| 2107 |
min_impact_speed_slider.change(
|
| 2108 |
_update_min_impact_speed,
|
| 2109 |
inputs=[GLOBAL_STATE, min_impact_speed_slider],
|
| 2110 |
+
outputs=[kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2111 |
)
|
| 2112 |
|
| 2113 |
goal_distance_slider.change(
|
| 2114 |
_update_goal_distance,
|
| 2115 |
inputs=[GLOBAL_STATE, goal_distance_slider],
|
| 2116 |
+
outputs=[kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2117 |
)
|
| 2118 |
|
| 2119 |
def _auto_detect_ball(
|
|
|
|
| 2129 |
frame = state_in.video_frames[frame_idx]
|
| 2130 |
detection = detect_ball_center(frame)
|
| 2131 |
if detection is None:
|
| 2132 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
|
| 2133 |
return (
|
| 2134 |
update_frame_display(state_in, frame_idx),
|
| 2135 |
gr.update(
|
|
|
|
| 2138 |
),
|
| 2139 |
gr.update(value=frame_idx),
|
| 2140 |
_build_kick_plot(state_in),
|
| 2141 |
+
propagate_main_update,
|
| 2142 |
detect_btn_update,
|
| 2143 |
propagate_player_update,
|
| 2144 |
)
|
|
|
|
| 2170 |
|
| 2171 |
status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
|
| 2172 |
status_text += f" | {_format_kick_status(state_in)}"
|
| 2173 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
|
| 2174 |
return (
|
| 2175 |
preview_img,
|
| 2176 |
gr.update(value=status_text, visible=True),
|
| 2177 |
gr.update(value=frame_idx),
|
| 2178 |
_build_kick_plot(state_in),
|
| 2179 |
+
propagate_main_update,
|
| 2180 |
detect_btn_update,
|
| 2181 |
propagate_player_update,
|
| 2182 |
)
|
|
|
|
| 2184 |
detect_ball_btn.click(
|
| 2185 |
_auto_detect_ball,
|
| 2186 |
inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk],
|
| 2187 |
+
outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2188 |
)
|
| 2189 |
|
| 2190 |
def _auto_detect_player(state_in: AppState):
|
|
|
|
| 2201 |
frame = state_in.video_frames[frame_idx]
|
| 2202 |
detection = detect_person_box(frame)
|
| 2203 |
if detection is None:
|
| 2204 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
|
| 2205 |
status_text = (
|
| 2206 |
f"{_format_kick_status(state_in)} | ⚠️ Unable to auto-detect the player on frame {frame_idx}. "
|
| 2207 |
"Please add a box manually."
|
|
|
|
| 2211 |
gr.update(value=status_text, visible=True),
|
| 2212 |
gr.update(value=frame_idx),
|
| 2213 |
_build_kick_plot(state_in),
|
| 2214 |
+
propagate_main_update,
|
| 2215 |
detect_btn_update,
|
| 2216 |
propagate_player_update,
|
| 2217 |
gr.update(),
|
|
|
|
| 2273 |
state_in.composited_frames.pop(frame_idx, None)
|
| 2274 |
state_in.current_frame_idx = frame_idx
|
| 2275 |
|
| 2276 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
|
| 2277 |
status_text = (
|
| 2278 |
f"{_format_kick_status(state_in)} | ✅ Player auto-detected on frame {frame_idx} (conf={conf:.2f})"
|
| 2279 |
)
|
|
|
|
| 2282 |
gr.update(value=status_text, visible=True),
|
| 2283 |
gr.update(value=frame_idx),
|
| 2284 |
_build_kick_plot(state_in),
|
| 2285 |
+
propagate_main_update,
|
| 2286 |
detect_btn_update,
|
| 2287 |
propagate_player_update,
|
| 2288 |
gr.update(value=PLAYER_OBJECT_ID),
|
|
|
|
| 2291 |
detect_player_btn.click(
|
| 2292 |
_auto_detect_player,
|
| 2293 |
inputs=[GLOBAL_STATE],
|
| 2294 |
+
outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn, obj_id_inp],
|
| 2295 |
)
|
| 2296 |
|
| 2297 |
@spaces.GPU()
|
| 2298 |
def propagate_player_masks(GLOBAL_STATE: gr.State):
|
| 2299 |
if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
|
| 2300 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 2301 |
return (
|
| 2302 |
GLOBAL_STATE,
|
| 2303 |
"Load a video first.",
|
|
|
|
| 2305 |
_build_kick_plot(GLOBAL_STATE),
|
| 2306 |
_format_impact_status(GLOBAL_STATE),
|
| 2307 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2308 |
+
propagate_main_update,
|
| 2309 |
detect_btn_update,
|
| 2310 |
propagate_player_update,
|
| 2311 |
)
|
| 2312 |
if GLOBAL_STATE.player_obj_id is None or not _player_has_masks(GLOBAL_STATE):
|
| 2313 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 2314 |
return (
|
| 2315 |
GLOBAL_STATE,
|
| 2316 |
"Detect the player before propagating.",
|
|
|
|
| 2318 |
_build_kick_plot(GLOBAL_STATE),
|
| 2319 |
_format_impact_status(GLOBAL_STATE),
|
| 2320 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2321 |
+
propagate_main_update,
|
| 2322 |
detect_btn_update,
|
| 2323 |
propagate_player_update,
|
| 2324 |
)
|
|
|
|
| 2333 |
total = max(1, GLOBAL_STATE.num_frames)
|
| 2334 |
processed = 0
|
| 2335 |
|
| 2336 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 2337 |
yield (
|
| 2338 |
GLOBAL_STATE,
|
| 2339 |
f"Propagating player: {processed}/{total}",
|
|
|
|
| 2341 |
_build_kick_plot(GLOBAL_STATE),
|
| 2342 |
_format_impact_status(GLOBAL_STATE),
|
| 2343 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2344 |
+
propagate_main_update,
|
| 2345 |
detect_btn_update,
|
| 2346 |
propagate_player_update,
|
| 2347 |
)
|
|
|
|
| 2376 |
|
| 2377 |
processed += 1
|
| 2378 |
if processed % 30 == 0 or processed == total:
|
| 2379 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 2380 |
yield (
|
| 2381 |
GLOBAL_STATE,
|
| 2382 |
f"Propagating player: {processed}/{total}",
|
|
|
|
| 2384 |
_build_kick_plot(GLOBAL_STATE),
|
| 2385 |
_format_impact_status(GLOBAL_STATE),
|
| 2386 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2387 |
+
propagate_main_update,
|
| 2388 |
detect_btn_update,
|
| 2389 |
propagate_player_update,
|
| 2390 |
)
|
|
|
|
| 2398 |
target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
|
| 2399 |
GLOBAL_STATE.current_frame_idx = target_frame
|
| 2400 |
|
| 2401 |
+
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
|
| 2402 |
yield (
|
| 2403 |
GLOBAL_STATE,
|
| 2404 |
text,
|
|
|
|
| 2406 |
_build_kick_plot(GLOBAL_STATE),
|
| 2407 |
_format_impact_status(GLOBAL_STATE),
|
| 2408 |
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
|
| 2409 |
+
propagate_main_update,
|
| 2410 |
detect_btn_update,
|
| 2411 |
propagate_player_update,
|
| 2412 |
)
|
|
|
|
| 2414 |
propagate_player_btn.click(
|
| 2415 |
propagate_player_masks,
|
| 2416 |
inputs=[GLOBAL_STATE],
|
| 2417 |
+
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2418 |
)
|
| 2419 |
|
| 2420 |
# Image click to add a point and run forward on that frame
|
| 2421 |
preview.select(
|
| 2422 |
+
_on_image_click_with_updates,
|
| 2423 |
+
[preview, GLOBAL_STATE, frame_slider, obj_id_inp, label_radio, clear_old_chk],
|
| 2424 |
+
[preview, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2425 |
)
|
| 2426 |
|
| 2427 |
# Playback via MP4 rendering only
|
|
|
|
| 2483 |
propagate_btn.click(
|
| 2484 |
propagate_masks,
|
| 2485 |
inputs=[GLOBAL_STATE],
|
| 2486 |
+
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2487 |
)
|
| 2488 |
|
| 2489 |
reset_btn.click(
|
| 2490 |
reset_session,
|
| 2491 |
inputs=GLOBAL_STATE,
|
| 2492 |
+
outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
|
| 2493 |
)
|
| 2494 |
|
| 2495 |
# ============================================================================
|