Mirko Trasciatti
commited on
Commit
·
778b9c7
1
Parent(s):
697126c
Add kick detection from SAM2 trajectories
Browse files
app.py
CHANGED
|
@@ -2,6 +2,8 @@ import colorsys
|
|
| 2 |
import gc
|
| 3 |
from copy import deepcopy
|
| 4 |
import base64
|
|
|
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
|
| 7 |
EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")
|
|
@@ -213,6 +215,10 @@ class AppState:
|
|
| 213 |
self.pending_box_start_obj_id: int | None = None
|
| 214 |
self.is_switching_model: bool = False
|
| 215 |
self.ball_centers: dict[int, dict[int, tuple[int, int]]] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
# Model selection
|
| 217 |
self.model_repo_key: str = "tiny"
|
| 218 |
self.model_repo_id: str | None = None
|
|
@@ -288,6 +294,10 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
|
|
| 288 |
GLOBAL_STATE.masks_by_frame = {}
|
| 289 |
GLOBAL_STATE.color_by_obj = {}
|
| 290 |
GLOBAL_STATE.ball_centers = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
load_model_if_needed(GLOBAL_STATE)
|
| 293 |
|
|
@@ -499,10 +509,129 @@ def _update_centroids_for_frame(state: AppState, frame_idx: int):
|
|
| 499 |
centers.pop(int(frame_idx), None)
|
| 500 |
seen_obj_ids.add(int(obj_id))
|
| 501 |
_ensure_color_for_obj(state, int(obj_id))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
# Remove frames for objects without masks at this frame
|
| 503 |
for obj_id, centers in state.ball_centers.items():
|
| 504 |
if obj_id not in seen_obj_ids:
|
| 505 |
centers.pop(int(frame_idx), None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
|
| 508 |
def on_image_click(
|
|
@@ -708,6 +837,11 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
|
|
| 708 |
GLOBAL_STATE.pending_box_start_frame_idx = None
|
| 709 |
GLOBAL_STATE.pending_box_start_obj_id = None
|
| 710 |
GLOBAL_STATE.ball_centers.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 711 |
|
| 712 |
# Dispose and re-init inference session for current model with existing frames
|
| 713 |
try:
|
|
@@ -1154,6 +1288,8 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
|
|
| 1154 |
)
|
| 1155 |
|
| 1156 |
status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
|
|
|
|
|
|
|
| 1157 |
return preview_img, gr.update(value=status_text, visible=True), gr.update(value=frame_idx)
|
| 1158 |
|
| 1159 |
detect_ball_btn.click(
|
|
|
|
| 2 |
import gc
|
| 3 |
from copy import deepcopy
|
| 4 |
import base64
|
| 5 |
+
import math
|
| 6 |
+
import statistics
|
| 7 |
from pathlib import Path
|
| 8 |
BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
|
| 9 |
EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")
|
|
|
|
| 215 |
self.pending_box_start_obj_id: int | None = None
|
| 216 |
self.is_switching_model: bool = False
|
| 217 |
self.ball_centers: dict[int, dict[int, tuple[int, int]]] = {}
|
| 218 |
+
self.mask_areas: dict[int, dict[int, float]] = {}
|
| 219 |
+
self.smoothed_centers: dict[int, dict[int, tuple[float, float]]] = {}
|
| 220 |
+
self.ball_speeds: dict[int, dict[int, float]] = {}
|
| 221 |
+
self.kick_frame: int | None = None
|
| 222 |
# Model selection
|
| 223 |
self.model_repo_key: str = "tiny"
|
| 224 |
self.model_repo_id: str | None = None
|
|
|
|
| 294 |
GLOBAL_STATE.masks_by_frame = {}
|
| 295 |
GLOBAL_STATE.color_by_obj = {}
|
| 296 |
GLOBAL_STATE.ball_centers = {}
|
| 297 |
+
GLOBAL_STATE.mask_areas = {}
|
| 298 |
+
GLOBAL_STATE.smoothed_centers = {}
|
| 299 |
+
GLOBAL_STATE.ball_speeds = {}
|
| 300 |
+
GLOBAL_STATE.kick_frame = None
|
| 301 |
|
| 302 |
load_model_if_needed(GLOBAL_STATE)
|
| 303 |
|
|
|
|
| 509 |
centers.pop(int(frame_idx), None)
|
| 510 |
seen_obj_ids.add(int(obj_id))
|
| 511 |
_ensure_color_for_obj(state, int(obj_id))
|
| 512 |
+
mask_np = np.array(mask)
|
| 513 |
+
if mask_np.ndim == 3:
|
| 514 |
+
mask_np = mask_np.squeeze()
|
| 515 |
+
mask_np = np.clip(mask_np, 0.0, 1.0)
|
| 516 |
+
area = float(np.count_nonzero(mask_np > 0.3))
|
| 517 |
+
areas = state.mask_areas.setdefault(int(obj_id), {})
|
| 518 |
+
areas[int(frame_idx)] = area
|
| 519 |
# Remove frames for objects without masks at this frame
|
| 520 |
for obj_id, centers in state.ball_centers.items():
|
| 521 |
if obj_id not in seen_obj_ids:
|
| 522 |
centers.pop(int(frame_idx), None)
|
| 523 |
+
for obj_id, areas in state.mask_areas.items():
|
| 524 |
+
if obj_id not in seen_obj_ids:
|
| 525 |
+
areas.pop(int(frame_idx), None)
|
| 526 |
+
_recompute_motion_metrics(state)
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
|
| 530 |
+
centers = state.ball_centers.get(target_obj_id)
|
| 531 |
+
if not centers or len(centers) < 3:
|
| 532 |
+
state.smoothed_centers[target_obj_id] = {}
|
| 533 |
+
state.ball_speeds[target_obj_id] = {}
|
| 534 |
+
state.kick_frame = None
|
| 535 |
+
return
|
| 536 |
+
|
| 537 |
+
items = sorted(centers.items())
|
| 538 |
+
dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
|
| 539 |
+
alpha = 0.35
|
| 540 |
+
|
| 541 |
+
smoothed: dict[int, tuple[float, float]] = {}
|
| 542 |
+
speeds: dict[int, float] = {}
|
| 543 |
+
|
| 544 |
+
prev_frame = None
|
| 545 |
+
prev_smooth = None
|
| 546 |
+
for frame_idx, (cx, cy) in items:
|
| 547 |
+
if prev_smooth is None:
|
| 548 |
+
smooth_x, smooth_y = float(cx), float(cy)
|
| 549 |
+
else:
|
| 550 |
+
smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
|
| 551 |
+
smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
|
| 552 |
+
smoothed[frame_idx] = (smooth_x, smooth_y)
|
| 553 |
+
if prev_smooth is None or prev_frame is None:
|
| 554 |
+
speeds[frame_idx] = 0.0
|
| 555 |
+
else:
|
| 556 |
+
frame_delta = max(1, frame_idx - prev_frame)
|
| 557 |
+
time_delta = frame_delta * dt
|
| 558 |
+
dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
|
| 559 |
+
speed = dist / time_delta if time_delta > 0 else dist
|
| 560 |
+
speeds[frame_idx] = speed
|
| 561 |
+
prev_smooth = (smooth_x, smooth_y)
|
| 562 |
+
prev_frame = frame_idx
|
| 563 |
+
|
| 564 |
+
state.smoothed_centers[target_obj_id] = smoothed
|
| 565 |
+
state.ball_speeds[target_obj_id] = speeds
|
| 566 |
+
state.kick_frame = _detect_kick_frame(state, target_obj_id)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None:
|
| 570 |
+
smoothed = state.smoothed_centers.get(target_obj_id, {})
|
| 571 |
+
speeds = state.ball_speeds.get(target_obj_id, {})
|
| 572 |
+
if len(smoothed) < 5:
|
| 573 |
+
return None
|
| 574 |
+
|
| 575 |
+
frames = sorted(smoothed.keys())
|
| 576 |
+
speed_series = [speeds.get(f, 0.0) for f in frames]
|
| 577 |
+
|
| 578 |
+
baseline_window = min(5, len(frames) // 3 or 1)
|
| 579 |
+
baseline_speeds = speed_series[:baseline_window]
|
| 580 |
+
baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
|
| 581 |
+
|
| 582 |
+
speed_threshold = baseline_speed + 80.0 # pixels/second
|
| 583 |
+
sustain_frames = 3
|
| 584 |
+
holdout_frames = 8
|
| 585 |
+
return_distance = 12.0
|
| 586 |
+
area_window = 4
|
| 587 |
+
area_drop_ratio = 0.75
|
| 588 |
+
areas_dict = state.mask_areas.get(target_obj_id, {})
|
| 589 |
+
initial_center = smoothed[frames[0]]
|
| 590 |
+
|
| 591 |
+
for idx in range(baseline_window, len(frames)):
|
| 592 |
+
frame = frames[idx]
|
| 593 |
+
speed = speed_series[idx]
|
| 594 |
+
if speed < speed_threshold:
|
| 595 |
+
continue
|
| 596 |
+
|
| 597 |
+
sustain_ok = True
|
| 598 |
+
for j in range(1, sustain_frames + 1):
|
| 599 |
+
if idx + j >= len(frames):
|
| 600 |
+
break
|
| 601 |
+
if speed_series[idx + j] < speed_threshold * 0.8:
|
| 602 |
+
sustain_ok = False
|
| 603 |
+
break
|
| 604 |
+
if not sustain_ok:
|
| 605 |
+
continue
|
| 606 |
+
|
| 607 |
+
area_pass = True
|
| 608 |
+
current_area = areas_dict.get(frame)
|
| 609 |
+
if current_area:
|
| 610 |
+
prev_areas = [
|
| 611 |
+
areas_dict.get(f)
|
| 612 |
+
for f in frames[max(0, idx - area_window):idx]
|
| 613 |
+
if areas_dict.get(f) is not None
|
| 614 |
+
]
|
| 615 |
+
if prev_areas:
|
| 616 |
+
median_prev = statistics.median(prev_areas)
|
| 617 |
+
if median_prev > 0 and current_area / median_prev > area_drop_ratio:
|
| 618 |
+
area_pass = False
|
| 619 |
+
if not area_pass:
|
| 620 |
+
continue
|
| 621 |
+
|
| 622 |
+
moved_far = True
|
| 623 |
+
for future_frame in frames[idx:min(len(frames), idx + holdout_frames)]:
|
| 624 |
+
cx, cy = smoothed[future_frame]
|
| 625 |
+
dist = math.hypot(cx - initial_center[0], cy - initial_center[1])
|
| 626 |
+
if dist < return_distance:
|
| 627 |
+
moved_far = False
|
| 628 |
+
break
|
| 629 |
+
if not moved_far:
|
| 630 |
+
continue
|
| 631 |
+
|
| 632 |
+
return frame
|
| 633 |
+
|
| 634 |
+
return None
|
| 635 |
|
| 636 |
|
| 637 |
def on_image_click(
|
|
|
|
| 837 |
GLOBAL_STATE.pending_box_start_frame_idx = None
|
| 838 |
GLOBAL_STATE.pending_box_start_obj_id = None
|
| 839 |
GLOBAL_STATE.ball_centers.clear()
|
| 840 |
+
GLOBAL_STATE.mask_areas.clear()
|
| 841 |
+
GLOBAL_STATE.smoothed_centers.clear()
|
| 842 |
+
GLOBAL_STATE.ball_speeds.clear()
|
| 843 |
+
GLOBAL_STATE.kick_frame = None
|
| 844 |
+
GLOBAL_STATE.ball_centers.clear()
|
| 845 |
|
| 846 |
# Dispose and re-init inference session for current model with existing frames
|
| 847 |
try:
|
|
|
|
| 1288 |
)
|
| 1289 |
|
| 1290 |
status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
|
| 1291 |
+
if state_in.kick_frame is not None:
|
| 1292 |
+
status_text += f" | Kick frame ≈ {state_in.kick_frame}"
|
| 1293 |
return preview_img, gr.update(value=status_text, visible=True), gr.update(value=frame_idx)
|
| 1294 |
|
| 1295 |
detect_ball_btn.click(
|