Mirko Trasciatti
committed on
Commit
·
4668c44
1
Parent(s):
cccfb86
Add kick diagnostics plot and smoothing overlay
Browse files- app.py +160 -23
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -5,6 +5,8 @@ import base64
|
|
| 5 |
import math
|
| 6 |
import statistics
|
| 7 |
from pathlib import Path
|
|
|
|
|
|
|
| 8 |
BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
|
| 9 |
EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")
|
| 10 |
|
|
@@ -219,6 +221,13 @@ class AppState:
|
|
| 219 |
self.smoothed_centers: dict[int, dict[int, tuple[float, float]]] = {}
|
| 220 |
self.ball_speeds: dict[int, dict[int, float]] = {}
|
| 221 |
self.kick_frame: int | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
# Model selection
|
| 223 |
self.model_repo_key: str = "tiny"
|
| 224 |
self.model_repo_id: str | None = None
|
|
@@ -435,28 +444,35 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
|
|
| 435 |
for obj_id, centers in state.ball_centers.items():
|
| 436 |
if not centers:
|
| 437 |
continue
|
| 438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
distances: list[float] = []
|
| 440 |
prev_center = None
|
| 441 |
-
for _, (
|
| 442 |
if prev_center is None:
|
| 443 |
distances.append(0.0)
|
| 444 |
else:
|
| 445 |
-
dx =
|
| 446 |
-
dy =
|
| 447 |
distances.append(float(np.hypot(dx, dy)))
|
| 448 |
-
prev_center = (
|
| 449 |
max_dist = max(distances[1:], default=0.0)
|
| 450 |
color_by_frame: dict[int, tuple[int, int, int]] = {}
|
| 451 |
-
for (f_idx, _), dist in zip(
|
| 452 |
ratio = dist / max_dist if max_dist > 0 else 0.0
|
| 453 |
color_by_frame[f_idx] = _speed_to_color(ratio)
|
| 454 |
-
for f_idx, (
|
| 455 |
highlight = (f_idx == frame_idx)
|
| 456 |
color = (255, 0, 0) if highlight else color_by_frame.get(f_idx, (255, 255, 0))
|
| 457 |
line_width = 1 if not highlight else 2
|
| 458 |
-
draw.line([(
|
| 459 |
-
draw.line([(
|
| 460 |
# Save to cache and return
|
| 461 |
state.composited_frames[frame_idx] = out_img
|
| 462 |
return out_img
|
|
@@ -526,12 +542,100 @@ def _update_centroids_for_frame(state: AppState, frame_idx: int):
|
|
| 526 |
_recompute_motion_metrics(state)
|
| 527 |
|
| 528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
|
| 530 |
centers = state.ball_centers.get(target_obj_id)
|
| 531 |
if not centers or len(centers) < 3:
|
| 532 |
state.smoothed_centers[target_obj_id] = {}
|
| 533 |
state.ball_speeds[target_obj_id] = {}
|
| 534 |
state.kick_frame = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
return
|
| 536 |
|
| 537 |
items = sorted(centers.items())
|
|
@@ -591,6 +695,14 @@ def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None:
|
|
| 591 |
areas_dict = state.mask_areas.get(target_obj_id, {})
|
| 592 |
initial_center = smoothed[frames[0]]
|
| 593 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
for idx in range(baseline_window, len(frames)):
|
| 595 |
frame = frames[idx]
|
| 596 |
speed = speed_series[idx]
|
|
@@ -634,6 +746,7 @@ def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None:
|
|
| 634 |
if not moved_far:
|
| 635 |
continue
|
| 636 |
|
|
|
|
| 637 |
return frame
|
| 638 |
|
| 639 |
return None
|
|
@@ -778,7 +891,7 @@ def on_image_click(
|
|
| 778 |
def propagate_masks(GLOBAL_STATE: gr.State):
|
| 779 |
if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
|
| 780 |
# yield GLOBAL_STATE, "Load a video first.", gr.update()
|
| 781 |
-
return GLOBAL_STATE, "Load a video first.", gr.update()
|
| 782 |
|
| 783 |
processor = deepcopy(GLOBAL_STATE.processor)
|
| 784 |
model = deepcopy(GLOBAL_STATE.model)
|
|
@@ -792,7 +905,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 792 |
processed = 0
|
| 793 |
|
| 794 |
# Initial status; no slider change yet
|
| 795 |
-
yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
|
| 796 |
|
| 797 |
last_frame_idx = 0
|
| 798 |
with torch.inference_mode():
|
|
@@ -819,19 +932,27 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 819 |
processed += 1
|
| 820 |
# Every 30th frame (or last), move slider to current frame to update preview via slider binding
|
| 821 |
if processed % 30 == 0 or processed == total:
|
| 822 |
-
yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
|
| 823 |
|
| 824 |
text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
|
| 825 |
|
| 826 |
# Final status; ensure slider points to last processed frame
|
| 827 |
-
yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)
|
| 828 |
|
| 829 |
|
| 830 |
-
def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
|
| 831 |
# Reset only session-related state, keep uploaded video and model
|
| 832 |
if not GLOBAL_STATE.video_frames:
|
| 833 |
# Nothing loaded; keep behavior
|
| 834 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 835 |
|
| 836 |
# Clear prompts and caches
|
| 837 |
GLOBAL_STATE.masks_by_frame.clear()
|
|
@@ -866,7 +987,15 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
|
|
| 866 |
slider_value = gr.update(value=current_idx)
|
| 867 |
status = "Session reset. Prompts cleared; video preserved."
|
| 868 |
# clear and reload model and processor
|
| 869 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 870 |
|
| 871 |
|
| 872 |
def create_annotation_preview(video_file, annotations):
|
|
@@ -1135,6 +1264,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1135 |
label_radio = gr.Radio(choices=["positive", "negative"], value="positive", label="Point label")
|
| 1136 |
clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
|
| 1137 |
prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
|
|
|
|
| 1138 |
|
| 1139 |
# Wire events
|
| 1140 |
def _on_video_change(GLOBAL_STATE: gr.State, video):
|
|
@@ -1144,13 +1274,14 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1144 |
gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
|
| 1145 |
first_frame,
|
| 1146 |
status,
|
| 1147 |
-
gr.update(visible=False, value="")
|
|
|
|
| 1148 |
)
|
| 1149 |
|
| 1150 |
video_in.change(
|
| 1151 |
_on_video_change,
|
| 1152 |
inputs=[GLOBAL_STATE, video_in],
|
| 1153 |
-
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status],
|
| 1154 |
show_progress=True,
|
| 1155 |
)
|
| 1156 |
|
|
@@ -1165,7 +1296,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1165 |
examples=examples_list,
|
| 1166 |
inputs=[GLOBAL_STATE, video_in],
|
| 1167 |
fn=_on_video_change,
|
| 1168 |
-
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status],
|
| 1169 |
label="Examples",
|
| 1170 |
cache_examples=False,
|
| 1171 |
examples_per_page=5,
|
|
@@ -1265,6 +1396,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1265 |
visible=True,
|
| 1266 |
),
|
| 1267 |
gr.update(value=frame_idx),
|
|
|
|
| 1268 |
)
|
| 1269 |
|
| 1270 |
x_center, y_center, _, _, conf = detection
|
|
@@ -1297,12 +1429,17 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1297 |
status_text += f" | Kick frame ≈ {state_in.kick_frame}"
|
| 1298 |
else:
|
| 1299 |
status_text += " | Kick frame: not detected"
|
| 1300 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1301 |
|
| 1302 |
detect_ball_btn.click(
|
| 1303 |
_auto_detect_ball,
|
| 1304 |
inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk],
|
| 1305 |
-
outputs=[preview, ball_status, frame_slider],
|
| 1306 |
)
|
| 1307 |
|
| 1308 |
# Image click to add a point and run forward on that frame
|
|
@@ -1352,13 +1489,13 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1352 |
propagate_btn.click(
|
| 1353 |
propagate_masks,
|
| 1354 |
inputs=[GLOBAL_STATE],
|
| 1355 |
-
outputs=[GLOBAL_STATE, propagate_status, frame_slider],
|
| 1356 |
)
|
| 1357 |
|
| 1358 |
reset_btn.click(
|
| 1359 |
reset_session,
|
| 1360 |
inputs=GLOBAL_STATE,
|
| 1361 |
-
outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status],
|
| 1362 |
)
|
| 1363 |
|
| 1364 |
# ============================================================================
|
|
|
|
| 5 |
import math
|
| 6 |
import statistics
|
| 7 |
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import plotly.graph_objects as go
|
| 10 |
BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
|
| 11 |
EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")
|
| 12 |
|
|
|
|
| 221 |
self.smoothed_centers: dict[int, dict[int, tuple[float, float]]] = {}
|
| 222 |
self.ball_speeds: dict[int, dict[int, float]] = {}
|
| 223 |
self.kick_frame: int | None = None
|
| 224 |
+
self.kick_debug_frames: list[int] = []
|
| 225 |
+
self.kick_debug_speeds: list[float] = []
|
| 226 |
+
self.kick_debug_threshold: float | None = None
|
| 227 |
+
self.kick_debug_baseline: float | None = None
|
| 228 |
+
self.kick_debug_speed_std: float | None = None
|
| 229 |
+
self.kick_debug_area: list[float] = []
|
| 230 |
+
self.kick_debug_kick_frame: int | None = None
|
| 231 |
# Model selection
|
| 232 |
self.model_repo_key: str = "tiny"
|
| 233 |
self.model_repo_id: str | None = None
|
|
|
|
| 444 |
for obj_id, centers in state.ball_centers.items():
|
| 445 |
if not centers:
|
| 446 |
continue
|
| 447 |
+
raw_items = sorted(centers.items())
|
| 448 |
+
for _, (rx, ry) in raw_items:
|
| 449 |
+
draw.line([(rx - cross_half, ry), (rx + cross_half, ry)], fill=(160, 160, 160), width=1)
|
| 450 |
+
draw.line([(rx, ry - cross_half), (rx, ry + cross_half)], fill=(160, 160, 160), width=1)
|
| 451 |
+
smooth_dict = state.smoothed_centers.get(obj_id, {})
|
| 452 |
+
if not smooth_dict:
|
| 453 |
+
continue
|
| 454 |
+
smooth_items = sorted(smooth_dict.items())
|
| 455 |
distances: list[float] = []
|
| 456 |
prev_center = None
|
| 457 |
+
for _, (sx, sy) in smooth_items:
|
| 458 |
if prev_center is None:
|
| 459 |
distances.append(0.0)
|
| 460 |
else:
|
| 461 |
+
dx = sx - prev_center[0]
|
| 462 |
+
dy = sy - prev_center[1]
|
| 463 |
distances.append(float(np.hypot(dx, dy)))
|
| 464 |
+
prev_center = (sx, sy)
|
| 465 |
max_dist = max(distances[1:], default=0.0)
|
| 466 |
color_by_frame: dict[int, tuple[int, int, int]] = {}
|
| 467 |
+
for (f_idx, _), dist in zip(smooth_items, distances):
|
| 468 |
ratio = dist / max_dist if max_dist > 0 else 0.0
|
| 469 |
color_by_frame[f_idx] = _speed_to_color(ratio)
|
| 470 |
+
for f_idx, (sx, sy) in reversed(smooth_items):
|
| 471 |
highlight = (f_idx == frame_idx)
|
| 472 |
color = (255, 0, 0) if highlight else color_by_frame.get(f_idx, (255, 255, 0))
|
| 473 |
line_width = 1 if not highlight else 2
|
| 474 |
+
draw.line([(sx - cross_half, sy), (sx + cross_half, sy)], fill=color, width=line_width)
|
| 475 |
+
draw.line([(sx, sy - cross_half), (sx, sy + cross_half)], fill=color, width=line_width)
|
| 476 |
# Save to cache and return
|
| 477 |
state.composited_frames[frame_idx] = out_img
|
| 478 |
return out_img
|
|
|
|
| 542 |
_recompute_motion_metrics(state)
|
| 543 |
|
| 544 |
|
| 545 |
+
def _build_kick_plot(state: AppState):
    """Build a Plotly figure visualizing kick-detection diagnostics.

    Plots the per-frame ball speed, the adaptive detection threshold, the
    baseline speed, and the mask area (on a secondary y-axis), plus a
    vertical marker at the detected kick frame when one was found.

    Returns an empty, labelled figure when no diagnostics have been
    recorded yet (e.g. before mask propagation has run).
    """
    fig = go.Figure()
    # No diagnostics available yet -> return an empty but labelled plot so
    # the gr.Plot component still renders something sensible.
    if state is None or not state.kick_debug_frames or not state.kick_debug_speeds:
        fig.update_layout(
            title="Kick speed diagnostics",
            xaxis_title="Frame",
            yaxis_title="Speed (px/s)",
        )
        return fig

    frames = state.kick_debug_frames
    speeds = state.kick_debug_speeds
    # Fall back to zeros so the area trace always has one point per frame.
    areas = state.kick_debug_area if state.kick_debug_area else [0.0] * len(frames)
    # `or 0.0` maps a missing (None) value to zero; a genuine 0.0 is unchanged.
    threshold = state.kick_debug_threshold or 0.0
    baseline = state.kick_debug_baseline or 0.0
    kick_frame = state.kick_debug_kick_frame

    # Main speed curve.
    fig.add_trace(
        go.Scatter(
            x=frames,
            y=speeds,
            mode="lines+markers",
            name="Speed (px/s)",
            line=dict(color="#1f77b4"),
        )
    )
    # Horizontal reference line: adaptive threshold used by kick detection.
    fig.add_trace(
        go.Scatter(
            x=[frames[0], frames[-1]],
            y=[threshold, threshold],
            mode="lines",
            name="Adaptive threshold",
            line=dict(color="#d62728", dash="dash"),
        )
    )
    # Horizontal reference line: baseline (pre-kick) speed.
    fig.add_trace(
        go.Scatter(
            x=[frames[0], frames[-1]],
            y=[baseline, baseline],
            mode="lines",
            name="Baseline speed",
            line=dict(color="#ff7f0e", dash="dot"),
        )
    )
    # Mask area rides on a secondary y-axis since its scale differs from speed.
    fig.add_trace(
        go.Scatter(
            x=frames,
            y=areas,
            mode="lines",
            name="Mask area",
            line=dict(color="#2ca02c"),
            yaxis="y2",
        )
    )
    if kick_frame is not None:
        # Vertical marker spanning the speed range at the detected kick frame.
        fig.add_trace(
            go.Scatter(
                x=[kick_frame, kick_frame],
                y=[min(speeds), max(max(speeds), threshold)],
                mode="lines",
                name="Detected kick",
                line=dict(color="#9467bd", dash="dashdot"),
            )
        )
    fig.update_layout(
        title="Kick speed diagnostics",
        xaxis_title="Frame",
        yaxis_title="Speed (px/s)",
        yaxis=dict(side="left"),
        yaxis2=dict(
            title="Mask area (px)",
            overlaying="y",
            side="right",
            showgrid=False,
        ),
        legend=dict(orientation="h"),
        margin=dict(t=40, l=40, r=40, b=40),
    )
    return fig
|
| 624 |
+
|
| 625 |
+
|
| 626 |
def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
|
| 627 |
centers = state.ball_centers.get(target_obj_id)
|
| 628 |
if not centers or len(centers) < 3:
|
| 629 |
state.smoothed_centers[target_obj_id] = {}
|
| 630 |
state.ball_speeds[target_obj_id] = {}
|
| 631 |
state.kick_frame = None
|
| 632 |
+
state.kick_debug_frames = []
|
| 633 |
+
state.kick_debug_speeds = []
|
| 634 |
+
state.kick_debug_threshold = None
|
| 635 |
+
state.kick_debug_baseline = None
|
| 636 |
+
state.kick_debug_speed_std = None
|
| 637 |
+
state.kick_debug_area = []
|
| 638 |
+
state.kick_debug_kick_frame = None
|
| 639 |
return
|
| 640 |
|
| 641 |
items = sorted(centers.items())
|
|
|
|
| 695 |
areas_dict = state.mask_areas.get(target_obj_id, {})
|
| 696 |
initial_center = smoothed[frames[0]]
|
| 697 |
|
| 698 |
+
state.kick_debug_frames = frames
|
| 699 |
+
state.kick_debug_speeds = speed_series
|
| 700 |
+
state.kick_debug_threshold = speed_threshold
|
| 701 |
+
state.kick_debug_baseline = baseline_speed
|
| 702 |
+
state.kick_debug_speed_std = speed_std
|
| 703 |
+
state.kick_debug_area = [areas_dict.get(f, 0.0) for f in frames]
|
| 704 |
+
state.kick_debug_kick_frame = None
|
| 705 |
+
|
| 706 |
for idx in range(baseline_window, len(frames)):
|
| 707 |
frame = frames[idx]
|
| 708 |
speed = speed_series[idx]
|
|
|
|
| 746 |
if not moved_far:
|
| 747 |
continue
|
| 748 |
|
| 749 |
+
state.kick_debug_kick_frame = frame
|
| 750 |
return frame
|
| 751 |
|
| 752 |
return None
|
|
|
|
| 891 |
def propagate_masks(GLOBAL_STATE: gr.State):
|
| 892 |
if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
|
| 893 |
# yield GLOBAL_STATE, "Load a video first.", gr.update()
|
| 894 |
+
return GLOBAL_STATE, "Load a video first.", gr.update(), _build_kick_plot(GLOBAL_STATE)
|
| 895 |
|
| 896 |
processor = deepcopy(GLOBAL_STATE.processor)
|
| 897 |
model = deepcopy(GLOBAL_STATE.model)
|
|
|
|
| 905 |
processed = 0
|
| 906 |
|
| 907 |
# Initial status; no slider change yet
|
| 908 |
+
yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(), _build_kick_plot(GLOBAL_STATE)
|
| 909 |
|
| 910 |
last_frame_idx = 0
|
| 911 |
with torch.inference_mode():
|
|
|
|
| 932 |
processed += 1
|
| 933 |
# Every 30th frame (or last), move slider to current frame to update preview via slider binding
|
| 934 |
if processed % 30 == 0 or processed == total:
|
| 935 |
+
yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx), _build_kick_plot(GLOBAL_STATE)
|
| 936 |
|
| 937 |
text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
|
| 938 |
|
| 939 |
# Final status; ensure slider points to last processed frame
|
| 940 |
+
yield GLOBAL_STATE, text, gr.update(value=last_frame_idx), _build_kick_plot(GLOBAL_STATE)
|
| 941 |
|
| 942 |
|
| 943 |
+
def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str, gr.Update, go.Figure]:
|
| 944 |
# Reset only session-related state, keep uploaded video and model
|
| 945 |
if not GLOBAL_STATE.video_frames:
|
| 946 |
# Nothing loaded; keep behavior
|
| 947 |
+
return (
|
| 948 |
+
GLOBAL_STATE,
|
| 949 |
+
None,
|
| 950 |
+
0,
|
| 951 |
+
0,
|
| 952 |
+
"Session reset. Load a new video.",
|
| 953 |
+
gr.update(visible=False, value=""),
|
| 954 |
+
_build_kick_plot(GLOBAL_STATE),
|
| 955 |
+
)
|
| 956 |
|
| 957 |
# Clear prompts and caches
|
| 958 |
GLOBAL_STATE.masks_by_frame.clear()
|
|
|
|
| 987 |
slider_value = gr.update(value=current_idx)
|
| 988 |
status = "Session reset. Prompts cleared; video preserved."
|
| 989 |
# clear and reload model and processor
|
| 990 |
+
return (
|
| 991 |
+
GLOBAL_STATE,
|
| 992 |
+
preview_img,
|
| 993 |
+
slider_minmax,
|
| 994 |
+
slider_value,
|
| 995 |
+
status,
|
| 996 |
+
gr.update(visible=False, value=""),
|
| 997 |
+
_build_kick_plot(GLOBAL_STATE),
|
| 998 |
+
)
|
| 999 |
|
| 1000 |
|
| 1001 |
def create_annotation_preview(video_file, annotations):
|
|
|
|
| 1264 |
label_radio = gr.Radio(choices=["positive", "negative"], value="positive", label="Point label")
|
| 1265 |
clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object")
|
| 1266 |
prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type")
|
| 1267 |
+
kick_plot = gr.Plot(label="Kick diagnostics", show_label=True)
|
| 1268 |
|
| 1269 |
# Wire events
|
| 1270 |
def _on_video_change(GLOBAL_STATE: gr.State, video):
|
|
|
|
| 1274 |
gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
|
| 1275 |
first_frame,
|
| 1276 |
status,
|
| 1277 |
+
gr.update(visible=False, value=""),
|
| 1278 |
+
_build_kick_plot(GLOBAL_STATE)
|
| 1279 |
)
|
| 1280 |
|
| 1281 |
video_in.change(
|
| 1282 |
_on_video_change,
|
| 1283 |
inputs=[GLOBAL_STATE, video_in],
|
| 1284 |
+
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot],
|
| 1285 |
show_progress=True,
|
| 1286 |
)
|
| 1287 |
|
|
|
|
| 1296 |
examples=examples_list,
|
| 1297 |
inputs=[GLOBAL_STATE, video_in],
|
| 1298 |
fn=_on_video_change,
|
| 1299 |
+
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot],
|
| 1300 |
label="Examples",
|
| 1301 |
cache_examples=False,
|
| 1302 |
examples_per_page=5,
|
|
|
|
| 1396 |
visible=True,
|
| 1397 |
),
|
| 1398 |
gr.update(value=frame_idx),
|
| 1399 |
+
_build_kick_plot(state_in),
|
| 1400 |
)
|
| 1401 |
|
| 1402 |
x_center, y_center, _, _, conf = detection
|
|
|
|
| 1429 |
status_text += f" | Kick frame ≈ {state_in.kick_frame}"
|
| 1430 |
else:
|
| 1431 |
status_text += " | Kick frame: not detected"
|
| 1432 |
+
return (
|
| 1433 |
+
preview_img,
|
| 1434 |
+
gr.update(value=status_text, visible=True),
|
| 1435 |
+
gr.update(value=frame_idx),
|
| 1436 |
+
_build_kick_plot(state_in),
|
| 1437 |
+
)
|
| 1438 |
|
| 1439 |
detect_ball_btn.click(
|
| 1440 |
_auto_detect_ball,
|
| 1441 |
inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk],
|
| 1442 |
+
outputs=[preview, ball_status, frame_slider, kick_plot],
|
| 1443 |
)
|
| 1444 |
|
| 1445 |
# Image click to add a point and run forward on that frame
|
|
|
|
| 1489 |
propagate_btn.click(
|
| 1490 |
propagate_masks,
|
| 1491 |
inputs=[GLOBAL_STATE],
|
| 1492 |
+
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot],
|
| 1493 |
)
|
| 1494 |
|
| 1495 |
reset_btn.click(
|
| 1496 |
reset_session,
|
| 1497 |
inputs=GLOBAL_STATE,
|
| 1498 |
+
outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot],
|
| 1499 |
)
|
| 1500 |
|
| 1501 |
# ============================================================================
|
requirements.txt
CHANGED
|
@@ -7,5 +7,6 @@ opencv-python
|
|
| 7 |
imageio[pyav]
|
| 8 |
spaces
|
| 9 |
git+https://github.com/iMoonLab/yolov13
|
|
|
|
| 10 |
|
| 11 |
|
|
|
|
| 7 |
imageio[pyav]
|
| 8 |
spaces
|
| 9 |
git+https://github.com/iMoonLab/yolov13
|
| 10 |
+
plotly
|
| 11 |
|
| 12 |
|