Mirko Trasciatti commited on
Commit
2feeac4
·
1 Parent(s): 707381b

Disable propagate buttons until detections

Browse files
Files changed (1) hide show
  1. app.py +84 -37
app.py CHANGED
@@ -47,6 +47,7 @@ YOLO_CONF_THRESHOLD = 0.0
47
  YOLO_IOU_THRESHOLD = 0.02
48
  PLAYER_TARGET_NAME = "person"
49
  PLAYER_OBJECT_ID = 2
 
50
 
51
 
52
  def get_yolo_model(model_filename: str = YOLO_DEFAULT_MODEL) -> YOLO:
@@ -997,6 +998,15 @@ def _format_kick_status(state: AppState) -> str:
997
  return f"Kick frame ≈ {frame}{time_part}"
998
 
999
 
 
 
 
 
 
 
 
 
 
1000
  def _player_has_masks(state: AppState) -> bool:
1001
  if state is None or state.player_obj_id is None:
1002
  return False
@@ -1007,17 +1017,20 @@ def _player_has_masks(state: AppState) -> bool:
1007
  return False
1008
 
1009
 
1010
- def _player_button_updates(state: AppState) -> tuple[Any, Any]:
1011
- detect_enabled = False
1012
- propagate_enabled = False
 
1013
  if isinstance(state, AppState):
1014
  kick_candidate = state.kick_frame or getattr(state, "kick_debug_kick_frame", None)
1015
  if kick_candidate is not None:
1016
- detect_enabled = True
1017
- propagate_enabled = _player_has_masks(state)
 
1018
  return (
1019
- gr.update(interactive=detect_enabled),
1020
- gr.update(interactive=propagate_enabled),
 
1021
  )
1022
 
1023
 
@@ -1431,11 +1444,25 @@ def on_image_click(
1431
  return update_frame_display(state, int(frame_idx))
1432
 
1433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1434
  @spaces.GPU()
1435
  def propagate_masks(GLOBAL_STATE: gr.State):
1436
  if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
1437
  # yield GLOBAL_STATE, "Load a video first.", gr.update()
1438
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
1439
  return (
1440
  GLOBAL_STATE,
1441
  "Load a video first.",
@@ -1443,6 +1470,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1443
  _build_kick_plot(GLOBAL_STATE),
1444
  _format_impact_status(GLOBAL_STATE),
1445
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
 
1446
  detect_btn_update,
1447
  propagate_player_update,
1448
  )
@@ -1459,7 +1487,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1459
  processed = 0
1460
 
1461
  # Initial status; no slider change yet
1462
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
1463
  yield (
1464
  GLOBAL_STATE,
1465
  f"Propagating masks: {processed}/{total}",
@@ -1467,6 +1495,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1467
  _build_kick_plot(GLOBAL_STATE),
1468
  _format_impact_status(GLOBAL_STATE),
1469
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
 
1470
  detect_btn_update,
1471
  propagate_player_update,
1472
  )
@@ -1496,7 +1525,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1496
  processed += 1
1497
  # Every 15th frame (or last), move slider to current frame to update preview via slider binding
1498
  if processed % 30 == 0 or processed == total:
1499
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
1500
  yield (
1501
  GLOBAL_STATE,
1502
  f"Propagating masks: {processed}/{total}",
@@ -1504,6 +1533,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1504
  _build_kick_plot(GLOBAL_STATE),
1505
  _format_impact_status(GLOBAL_STATE),
1506
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
 
1507
  detect_btn_update,
1508
  propagate_player_update,
1509
  )
@@ -1518,7 +1548,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1518
  GLOBAL_STATE.current_frame_idx = target_frame
1519
 
1520
  # Final status; ensure slider points to the target frame (kick frame when detected)
1521
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
1522
  yield (
1523
  GLOBAL_STATE,
1524
  text,
@@ -1526,16 +1556,17 @@ def propagate_masks(GLOBAL_STATE: gr.State):
1526
  _build_kick_plot(GLOBAL_STATE),
1527
  _format_impact_status(GLOBAL_STATE),
1528
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
 
1529
  detect_btn_update,
1530
  propagate_player_update,
1531
  )
1532
 
1533
 
1534
- def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str, any, go.Figure, Any, Any]:
1535
  # Reset only session-related state, keep uploaded video and model
1536
  if not GLOBAL_STATE.video_frames:
1537
  # Nothing loaded; keep behavior
1538
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
1539
  return (
1540
  GLOBAL_STATE,
1541
  None,
@@ -1545,6 +1576,7 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
1545
  gr.update(visible=False, value=""),
1546
  _build_kick_plot(GLOBAL_STATE),
1547
  _format_impact_status(GLOBAL_STATE),
 
1548
  detect_btn_update,
1549
  propagate_player_update,
1550
  )
@@ -1604,7 +1636,7 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
1604
  slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
1605
  slider_value = gr.update(value=current_idx)
1606
  status = "Session reset. Prompts cleared; video preserved."
1607
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
1608
  # clear and reload model and processor
1609
  return (
1610
  GLOBAL_STATE,
@@ -1615,6 +1647,7 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
1615
  gr.update(visible=False, value=""),
1616
  _build_kick_plot(GLOBAL_STATE),
1617
  _format_impact_status(GLOBAL_STATE),
 
1618
  detect_btn_update,
1619
  propagate_player_update,
1620
  )
@@ -1918,7 +1951,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1918
  )
1919
  with gr.Row():
1920
  detect_ball_btn = gr.Button("Detect Ball", variant="secondary")
1921
- propagate_btn = gr.Button("Propagate across video", variant="primary")
1922
  detect_player_btn = gr.Button("Detect Player", variant="secondary", interactive=False)
1923
  propagate_player_btn = gr.Button("Propagate Player", variant="primary", interactive=False)
1924
  ball_status = gr.Markdown(visible=False)
@@ -1934,7 +1967,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1934
  # Wire events
1935
  def _on_video_change(GLOBAL_STATE: gr.State, video):
1936
  GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video)
1937
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
1938
  return (
1939
  GLOBAL_STATE,
1940
  gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
@@ -1943,6 +1976,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1943
  gr.update(visible=False, value=""),
1944
  _build_kick_plot(GLOBAL_STATE),
1945
  _format_impact_status(GLOBAL_STATE),
 
1946
  detect_btn_update,
1947
  propagate_player_update,
1948
  )
@@ -1950,7 +1984,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1950
  video_in.change(
1951
  _on_video_change,
1952
  inputs=[GLOBAL_STATE, video_in],
1953
- outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, detect_player_btn, propagate_player_btn],
1954
  show_progress=True,
1955
  )
1956
 
@@ -1963,7 +1997,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
1963
  examples=examples_list,
1964
  inputs=[GLOBAL_STATE, video_in],
1965
  fn=_on_video_change,
1966
- outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, detect_player_btn, propagate_player_btn],
1967
  label="Examples",
1968
  cache_examples=False,
1969
  examples_per_page=5,
@@ -2046,11 +2080,12 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2046
  if s is not None and val is not None:
2047
  s.min_impact_speed_kmh = float(val)
2048
  _recompute_motion_metrics(s)
2049
- detect_btn_update, propagate_player_update = _player_button_updates(s)
2050
  return (
2051
  _build_kick_plot(s),
2052
  _format_impact_status(s),
2053
  gr.update(value=_format_kick_status(s), visible=True),
 
2054
  detect_btn_update,
2055
  propagate_player_update,
2056
  )
@@ -2059,11 +2094,12 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2059
  if s is not None and val is not None:
2060
  s.goal_distance_m = float(val)
2061
  _recompute_motion_metrics(s)
2062
- detect_btn_update, propagate_player_update = _player_button_updates(s)
2063
  return (
2064
  _build_kick_plot(s),
2065
  _format_impact_status(s),
2066
  gr.update(value=_format_kick_status(s), visible=True),
 
2067
  detect_btn_update,
2068
  propagate_player_update,
2069
  )
@@ -2071,13 +2107,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2071
  min_impact_speed_slider.change(
2072
  _update_min_impact_speed,
2073
  inputs=[GLOBAL_STATE, min_impact_speed_slider],
2074
- outputs=[kick_plot, impact_status, ball_status, detect_player_btn, propagate_player_btn],
2075
  )
2076
 
2077
  goal_distance_slider.change(
2078
  _update_goal_distance,
2079
  inputs=[GLOBAL_STATE, goal_distance_slider],
2080
- outputs=[kick_plot, impact_status, ball_status, detect_player_btn, propagate_player_btn],
2081
  )
2082
 
2083
  def _auto_detect_ball(
@@ -2093,7 +2129,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2093
  frame = state_in.video_frames[frame_idx]
2094
  detection = detect_ball_center(frame)
2095
  if detection is None:
2096
- detect_btn_update, propagate_player_update = _player_button_updates(state_in)
2097
  return (
2098
  update_frame_display(state_in, frame_idx),
2099
  gr.update(
@@ -2102,6 +2138,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2102
  ),
2103
  gr.update(value=frame_idx),
2104
  _build_kick_plot(state_in),
 
2105
  detect_btn_update,
2106
  propagate_player_update,
2107
  )
@@ -2133,12 +2170,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2133
 
2134
  status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
2135
  status_text += f" | {_format_kick_status(state_in)}"
2136
- detect_btn_update, propagate_player_update = _player_button_updates(state_in)
2137
  return (
2138
  preview_img,
2139
  gr.update(value=status_text, visible=True),
2140
  gr.update(value=frame_idx),
2141
  _build_kick_plot(state_in),
 
2142
  detect_btn_update,
2143
  propagate_player_update,
2144
  )
@@ -2146,7 +2184,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2146
  detect_ball_btn.click(
2147
  _auto_detect_ball,
2148
  inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk],
2149
- outputs=[preview, ball_status, frame_slider, kick_plot, detect_player_btn, propagate_player_btn],
2150
  )
2151
 
2152
  def _auto_detect_player(state_in: AppState):
@@ -2163,7 +2201,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2163
  frame = state_in.video_frames[frame_idx]
2164
  detection = detect_person_box(frame)
2165
  if detection is None:
2166
- detect_btn_update, propagate_player_update = _player_button_updates(state_in)
2167
  status_text = (
2168
  f"{_format_kick_status(state_in)} | ⚠️ Unable to auto-detect the player on frame {frame_idx}. "
2169
  "Please add a box manually."
@@ -2173,6 +2211,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2173
  gr.update(value=status_text, visible=True),
2174
  gr.update(value=frame_idx),
2175
  _build_kick_plot(state_in),
 
2176
  detect_btn_update,
2177
  propagate_player_update,
2178
  gr.update(),
@@ -2234,7 +2273,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2234
  state_in.composited_frames.pop(frame_idx, None)
2235
  state_in.current_frame_idx = frame_idx
2236
 
2237
- detect_btn_update, propagate_player_update = _player_button_updates(state_in)
2238
  status_text = (
2239
  f"{_format_kick_status(state_in)} | ✅ Player auto-detected on frame {frame_idx} (conf={conf:.2f})"
2240
  )
@@ -2243,6 +2282,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2243
  gr.update(value=status_text, visible=True),
2244
  gr.update(value=frame_idx),
2245
  _build_kick_plot(state_in),
 
2246
  detect_btn_update,
2247
  propagate_player_update,
2248
  gr.update(value=PLAYER_OBJECT_ID),
@@ -2251,13 +2291,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2251
  detect_player_btn.click(
2252
  _auto_detect_player,
2253
  inputs=[GLOBAL_STATE],
2254
- outputs=[preview, ball_status, frame_slider, kick_plot, detect_player_btn, propagate_player_btn, obj_id_inp],
2255
  )
2256
 
2257
  @spaces.GPU()
2258
  def propagate_player_masks(GLOBAL_STATE: gr.State):
2259
  if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
2260
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
2261
  return (
2262
  GLOBAL_STATE,
2263
  "Load a video first.",
@@ -2265,11 +2305,12 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2265
  _build_kick_plot(GLOBAL_STATE),
2266
  _format_impact_status(GLOBAL_STATE),
2267
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
 
2268
  detect_btn_update,
2269
  propagate_player_update,
2270
  )
2271
  if GLOBAL_STATE.player_obj_id is None or not _player_has_masks(GLOBAL_STATE):
2272
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
2273
  return (
2274
  GLOBAL_STATE,
2275
  "Detect the player before propagating.",
@@ -2277,6 +2318,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2277
  _build_kick_plot(GLOBAL_STATE),
2278
  _format_impact_status(GLOBAL_STATE),
2279
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
 
2280
  detect_btn_update,
2281
  propagate_player_update,
2282
  )
@@ -2291,7 +2333,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2291
  total = max(1, GLOBAL_STATE.num_frames)
2292
  processed = 0
2293
 
2294
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
2295
  yield (
2296
  GLOBAL_STATE,
2297
  f"Propagating player: {processed}/{total}",
@@ -2299,6 +2341,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2299
  _build_kick_plot(GLOBAL_STATE),
2300
  _format_impact_status(GLOBAL_STATE),
2301
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
 
2302
  detect_btn_update,
2303
  propagate_player_update,
2304
  )
@@ -2333,7 +2376,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2333
 
2334
  processed += 1
2335
  if processed % 30 == 0 or processed == total:
2336
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
2337
  yield (
2338
  GLOBAL_STATE,
2339
  f"Propagating player: {processed}/{total}",
@@ -2341,6 +2384,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2341
  _build_kick_plot(GLOBAL_STATE),
2342
  _format_impact_status(GLOBAL_STATE),
2343
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
 
2344
  detect_btn_update,
2345
  propagate_player_update,
2346
  )
@@ -2354,7 +2398,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2354
  target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
2355
  GLOBAL_STATE.current_frame_idx = target_frame
2356
 
2357
- detect_btn_update, propagate_player_update = _player_button_updates(GLOBAL_STATE)
2358
  yield (
2359
  GLOBAL_STATE,
2360
  text,
@@ -2362,6 +2406,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2362
  _build_kick_plot(GLOBAL_STATE),
2363
  _format_impact_status(GLOBAL_STATE),
2364
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
 
2365
  detect_btn_update,
2366
  propagate_player_update,
2367
  )
@@ -2369,12 +2414,14 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2369
  propagate_player_btn.click(
2370
  propagate_player_masks,
2371
  inputs=[GLOBAL_STATE],
2372
- outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, detect_player_btn, propagate_player_btn],
2373
  )
2374
 
2375
  # Image click to add a point and run forward on that frame
2376
  preview.select(
2377
- on_image_click, [preview, GLOBAL_STATE, frame_slider, obj_id_inp, label_radio, clear_old_chk], preview
 
 
2378
  )
2379
 
2380
  # Playback via MP4 rendering only
@@ -2436,13 +2483,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
2436
  propagate_btn.click(
2437
  propagate_masks,
2438
  inputs=[GLOBAL_STATE],
2439
- outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, detect_player_btn, propagate_player_btn],
2440
  )
2441
 
2442
  reset_btn.click(
2443
  reset_session,
2444
  inputs=GLOBAL_STATE,
2445
- outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot, impact_status, detect_player_btn, propagate_player_btn],
2446
  )
2447
 
2448
  # ============================================================================
 
47
  YOLO_IOU_THRESHOLD = 0.02
48
  PLAYER_TARGET_NAME = "person"
49
  PLAYER_OBJECT_ID = 2
50
+ BALL_OBJECT_ID = 1
51
 
52
 
53
  def get_yolo_model(model_filename: str = YOLO_DEFAULT_MODEL) -> YOLO:
 
998
  return f"Kick frame ≈ {frame}{time_part}"
999
 
1000
 
1001
+ def _ball_has_masks(state: AppState, target_obj_id: int = BALL_OBJECT_ID) -> bool:
1002
+ if state is None:
1003
+ return False
1004
+ for masks in state.masks_by_frame.values():
1005
+ if target_obj_id in masks:
1006
+ return True
1007
+ return False
1008
+
1009
+
1010
  def _player_has_masks(state: AppState) -> bool:
1011
  if state is None or state.player_obj_id is None:
1012
  return False
 
1017
  return False
1018
 
1019
 
1020
+ def _button_updates(state: AppState) -> tuple[Any, Any, Any]:
1021
+ propagate_main_enabled = _ball_has_masks(state)
1022
+ detect_player_enabled = False
1023
+ propagate_player_enabled = False
1024
  if isinstance(state, AppState):
1025
  kick_candidate = state.kick_frame or getattr(state, "kick_debug_kick_frame", None)
1026
  if kick_candidate is not None:
1027
+ detect_player_enabled = True
1028
+ if detect_player_enabled and _player_has_masks(state):
1029
+ propagate_player_enabled = True
1030
  return (
1031
+ gr.update(interactive=propagate_main_enabled),
1032
+ gr.update(interactive=detect_player_enabled),
1033
+ gr.update(interactive=propagate_player_enabled),
1034
  )
1035
 
1036
 
 
1444
  return update_frame_display(state, int(frame_idx))
1445
 
1446
 
1447
+ def _on_image_click_with_updates(
1448
+ img: Image.Image | np.ndarray,
1449
+ state: AppState,
1450
+ frame_idx: int,
1451
+ obj_id: int,
1452
+ label: str,
1453
+ clear_old: bool,
1454
+ evt: gr.SelectData,
1455
+ ):
1456
+ preview_img = on_image_click(img, state, frame_idx, obj_id, label, clear_old, evt)
1457
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state)
1458
+ return preview_img, propagate_main_update, detect_btn_update, propagate_player_update
1459
+
1460
+
1461
  @spaces.GPU()
1462
  def propagate_masks(GLOBAL_STATE: gr.State):
1463
  if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
1464
  # yield GLOBAL_STATE, "Load a video first.", gr.update()
1465
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
1466
  return (
1467
  GLOBAL_STATE,
1468
  "Load a video first.",
 
1470
  _build_kick_plot(GLOBAL_STATE),
1471
  _format_impact_status(GLOBAL_STATE),
1472
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1473
+ propagate_main_update,
1474
  detect_btn_update,
1475
  propagate_player_update,
1476
  )
 
1487
  processed = 0
1488
 
1489
  # Initial status; no slider change yet
1490
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
1491
  yield (
1492
  GLOBAL_STATE,
1493
  f"Propagating masks: {processed}/{total}",
 
1495
  _build_kick_plot(GLOBAL_STATE),
1496
  _format_impact_status(GLOBAL_STATE),
1497
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1498
+ propagate_main_update,
1499
  detect_btn_update,
1500
  propagate_player_update,
1501
  )
 
1525
  processed += 1
1526
  # Every 15th frame (or last), move slider to current frame to update preview via slider binding
1527
  if processed % 30 == 0 or processed == total:
1528
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
1529
  yield (
1530
  GLOBAL_STATE,
1531
  f"Propagating masks: {processed}/{total}",
 
1533
  _build_kick_plot(GLOBAL_STATE),
1534
  _format_impact_status(GLOBAL_STATE),
1535
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1536
+ propagate_main_update,
1537
  detect_btn_update,
1538
  propagate_player_update,
1539
  )
 
1548
  GLOBAL_STATE.current_frame_idx = target_frame
1549
 
1550
  # Final status; ensure slider points to the target frame (kick frame when detected)
1551
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
1552
  yield (
1553
  GLOBAL_STATE,
1554
  text,
 
1556
  _build_kick_plot(GLOBAL_STATE),
1557
  _format_impact_status(GLOBAL_STATE),
1558
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
1559
+ propagate_main_update,
1560
  detect_btn_update,
1561
  propagate_player_update,
1562
  )
1563
 
1564
 
1565
+ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str, any, go.Figure, Any, Any, Any]:
1566
  # Reset only session-related state, keep uploaded video and model
1567
  if not GLOBAL_STATE.video_frames:
1568
  # Nothing loaded; keep behavior
1569
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
1570
  return (
1571
  GLOBAL_STATE,
1572
  None,
 
1576
  gr.update(visible=False, value=""),
1577
  _build_kick_plot(GLOBAL_STATE),
1578
  _format_impact_status(GLOBAL_STATE),
1579
+ propagate_main_update,
1580
  detect_btn_update,
1581
  propagate_player_update,
1582
  )
 
1636
  slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
1637
  slider_value = gr.update(value=current_idx)
1638
  status = "Session reset. Prompts cleared; video preserved."
1639
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
1640
  # clear and reload model and processor
1641
  return (
1642
  GLOBAL_STATE,
 
1647
  gr.update(visible=False, value=""),
1648
  _build_kick_plot(GLOBAL_STATE),
1649
  _format_impact_status(GLOBAL_STATE),
1650
+ propagate_main_update,
1651
  detect_btn_update,
1652
  propagate_player_update,
1653
  )
 
1951
  )
1952
  with gr.Row():
1953
  detect_ball_btn = gr.Button("Detect Ball", variant="secondary")
1954
+ propagate_btn = gr.Button("Propagate across video", variant="primary", interactive=False)
1955
  detect_player_btn = gr.Button("Detect Player", variant="secondary", interactive=False)
1956
  propagate_player_btn = gr.Button("Propagate Player", variant="primary", interactive=False)
1957
  ball_status = gr.Markdown(visible=False)
 
1967
  # Wire events
1968
  def _on_video_change(GLOBAL_STATE: gr.State, video):
1969
  GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video)
1970
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
1971
  return (
1972
  GLOBAL_STATE,
1973
  gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
 
1976
  gr.update(visible=False, value=""),
1977
  _build_kick_plot(GLOBAL_STATE),
1978
  _format_impact_status(GLOBAL_STATE),
1979
+ propagate_main_update,
1980
  detect_btn_update,
1981
  propagate_player_update,
1982
  )
 
1984
  video_in.change(
1985
  _on_video_change,
1986
  inputs=[GLOBAL_STATE, video_in],
1987
+ outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
1988
  show_progress=True,
1989
  )
1990
 
 
1997
  examples=examples_list,
1998
  inputs=[GLOBAL_STATE, video_in],
1999
  fn=_on_video_change,
2000
+ outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
2001
  label="Examples",
2002
  cache_examples=False,
2003
  examples_per_page=5,
 
2080
  if s is not None and val is not None:
2081
  s.min_impact_speed_kmh = float(val)
2082
  _recompute_motion_metrics(s)
2083
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(s)
2084
  return (
2085
  _build_kick_plot(s),
2086
  _format_impact_status(s),
2087
  gr.update(value=_format_kick_status(s), visible=True),
2088
+ propagate_main_update,
2089
  detect_btn_update,
2090
  propagate_player_update,
2091
  )
 
2094
  if s is not None and val is not None:
2095
  s.goal_distance_m = float(val)
2096
  _recompute_motion_metrics(s)
2097
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(s)
2098
  return (
2099
  _build_kick_plot(s),
2100
  _format_impact_status(s),
2101
  gr.update(value=_format_kick_status(s), visible=True),
2102
+ propagate_main_update,
2103
  detect_btn_update,
2104
  propagate_player_update,
2105
  )
 
2107
  min_impact_speed_slider.change(
2108
  _update_min_impact_speed,
2109
  inputs=[GLOBAL_STATE, min_impact_speed_slider],
2110
+ outputs=[kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
2111
  )
2112
 
2113
  goal_distance_slider.change(
2114
  _update_goal_distance,
2115
  inputs=[GLOBAL_STATE, goal_distance_slider],
2116
+ outputs=[kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
2117
  )
2118
 
2119
  def _auto_detect_ball(
 
2129
  frame = state_in.video_frames[frame_idx]
2130
  detection = detect_ball_center(frame)
2131
  if detection is None:
2132
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
2133
  return (
2134
  update_frame_display(state_in, frame_idx),
2135
  gr.update(
 
2138
  ),
2139
  gr.update(value=frame_idx),
2140
  _build_kick_plot(state_in),
2141
+ propagate_main_update,
2142
  detect_btn_update,
2143
  propagate_player_update,
2144
  )
 
2170
 
2171
  status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
2172
  status_text += f" | {_format_kick_status(state_in)}"
2173
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
2174
  return (
2175
  preview_img,
2176
  gr.update(value=status_text, visible=True),
2177
  gr.update(value=frame_idx),
2178
  _build_kick_plot(state_in),
2179
+ propagate_main_update,
2180
  detect_btn_update,
2181
  propagate_player_update,
2182
  )
 
2184
  detect_ball_btn.click(
2185
  _auto_detect_ball,
2186
  inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk],
2187
+ outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn],
2188
  )
2189
 
2190
  def _auto_detect_player(state_in: AppState):
 
2201
  frame = state_in.video_frames[frame_idx]
2202
  detection = detect_person_box(frame)
2203
  if detection is None:
2204
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
2205
  status_text = (
2206
  f"{_format_kick_status(state_in)} | ⚠️ Unable to auto-detect the player on frame {frame_idx}. "
2207
  "Please add a box manually."
 
2211
  gr.update(value=status_text, visible=True),
2212
  gr.update(value=frame_idx),
2213
  _build_kick_plot(state_in),
2214
+ propagate_main_update,
2215
  detect_btn_update,
2216
  propagate_player_update,
2217
  gr.update(),
 
2273
  state_in.composited_frames.pop(frame_idx, None)
2274
  state_in.current_frame_idx = frame_idx
2275
 
2276
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
2277
  status_text = (
2278
  f"{_format_kick_status(state_in)} | ✅ Player auto-detected on frame {frame_idx} (conf={conf:.2f})"
2279
  )
 
2282
  gr.update(value=status_text, visible=True),
2283
  gr.update(value=frame_idx),
2284
  _build_kick_plot(state_in),
2285
+ propagate_main_update,
2286
  detect_btn_update,
2287
  propagate_player_update,
2288
  gr.update(value=PLAYER_OBJECT_ID),
 
2291
  detect_player_btn.click(
2292
  _auto_detect_player,
2293
  inputs=[GLOBAL_STATE],
2294
+ outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn, obj_id_inp],
2295
  )
2296
 
2297
  @spaces.GPU()
2298
  def propagate_player_masks(GLOBAL_STATE: gr.State):
2299
  if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
2300
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
2301
  return (
2302
  GLOBAL_STATE,
2303
  "Load a video first.",
 
2305
  _build_kick_plot(GLOBAL_STATE),
2306
  _format_impact_status(GLOBAL_STATE),
2307
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2308
+ propagate_main_update,
2309
  detect_btn_update,
2310
  propagate_player_update,
2311
  )
2312
  if GLOBAL_STATE.player_obj_id is None or not _player_has_masks(GLOBAL_STATE):
2313
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
2314
  return (
2315
  GLOBAL_STATE,
2316
  "Detect the player before propagating.",
 
2318
  _build_kick_plot(GLOBAL_STATE),
2319
  _format_impact_status(GLOBAL_STATE),
2320
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2321
+ propagate_main_update,
2322
  detect_btn_update,
2323
  propagate_player_update,
2324
  )
 
2333
  total = max(1, GLOBAL_STATE.num_frames)
2334
  processed = 0
2335
 
2336
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
2337
  yield (
2338
  GLOBAL_STATE,
2339
  f"Propagating player: {processed}/{total}",
 
2341
  _build_kick_plot(GLOBAL_STATE),
2342
  _format_impact_status(GLOBAL_STATE),
2343
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2344
+ propagate_main_update,
2345
  detect_btn_update,
2346
  propagate_player_update,
2347
  )
 
2376
 
2377
  processed += 1
2378
  if processed % 30 == 0 or processed == total:
2379
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
2380
  yield (
2381
  GLOBAL_STATE,
2382
  f"Propagating player: {processed}/{total}",
 
2384
  _build_kick_plot(GLOBAL_STATE),
2385
  _format_impact_status(GLOBAL_STATE),
2386
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2387
+ propagate_main_update,
2388
  detect_btn_update,
2389
  propagate_player_update,
2390
  )
 
2398
  target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
2399
  GLOBAL_STATE.current_frame_idx = target_frame
2400
 
2401
+ propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
2402
  yield (
2403
  GLOBAL_STATE,
2404
  text,
 
2406
  _build_kick_plot(GLOBAL_STATE),
2407
  _format_impact_status(GLOBAL_STATE),
2408
  gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
2409
+ propagate_main_update,
2410
  detect_btn_update,
2411
  propagate_player_update,
2412
  )
 
2414
  propagate_player_btn.click(
2415
  propagate_player_masks,
2416
  inputs=[GLOBAL_STATE],
2417
+ outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
2418
  )
2419
 
2420
  # Image click to add a point and run forward on that frame
2421
  preview.select(
2422
+ _on_image_click_with_updates,
2423
+ [preview, GLOBAL_STATE, frame_slider, obj_id_inp, label_radio, clear_old_chk],
2424
+ [preview, propagate_btn, detect_player_btn, propagate_player_btn],
2425
  )
2426
 
2427
  # Playback via MP4 rendering only
 
2483
  propagate_btn.click(
2484
  propagate_masks,
2485
  inputs=[GLOBAL_STATE],
2486
+ outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn],
2487
  )
2488
 
2489
  reset_btn.click(
2490
  reset_session,
2491
  inputs=GLOBAL_STATE,
2492
+ outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn],
2493
  )
2494
 
2495
  # ============================================================================