from __future__ import annotations

import base64
import colorsys
import gc
import math
import statistics
from copy import deepcopy
from pathlib import Path

import plotly.graph_objects as go

BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")

def ensure_example_video() -> str:
    """
    Ensure the Kickit example video exists locally by decoding the base64 text file.
    Returns the path to the decoded MP4.
    """
    if EXAMPLE_VIDEO_PATH.exists():
        return str(EXAMPLE_VIDEO_PATH)
    if not BASE64_VIDEO_PATH.exists():
        raise FileNotFoundError("Base64 video asset not found.")
    data = BASE64_VIDEO_PATH.read_text()
    EXAMPLE_VIDEO_PATH.write_bytes(base64.b64decode(data))
    return str(EXAMPLE_VIDEO_PATH)

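# Minimal usage sketch (assumes the .b64 asset ships alongside this file):
#
#     video_path = ensure_example_video()
#     # -> "Kickit-Video-2025-07-09-13-47-18-389.mp4", decoded on the first
#     #    call, returned straight from disk on subsequent calls.
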
from types import SimpleNamespace
from typing import Any, Optional

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from gradio.themes import Soft
from PIL import Image, ImageDraw

from transformers import AutoModel, Sam2VideoProcessor
from ultralytics import YOLO
from huggingface_hub import hf_hub_download

YOLO_MODEL_CACHE: dict[str, YOLO] = {}
YOLO_DEFAULT_MODEL = "yolov13n.pt"
YOLO_REPO_ID = "atalaydenknalbant/Yolov13"
YOLO_TARGET_NAME = "sports ball"
YOLO_CONF_THRESHOLD = 0.0
YOLO_IOU_THRESHOLD = 0.02
PLAYER_TARGET_NAME = "person"
PLAYER_OBJECT_ID = 2
BALL_OBJECT_ID = 1

def get_yolo_model(model_filename: str = YOLO_DEFAULT_MODEL) -> YOLO:
    """
    Lazily download and load a YOLOv13 model, caching it for reuse.
    """
    if model_filename in YOLO_MODEL_CACHE:
        return YOLO_MODEL_CACHE[model_filename]

    model_path = hf_hub_download(repo_id=YOLO_REPO_ID, filename=model_filename)
    model = YOLO(model_path)
    YOLO_MODEL_CACHE[model_filename] = model
    return model

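# Usage sketch: the first call downloads the weights from the Hub; later
# calls hit YOLO_MODEL_CACHE and return the same instance.
#
#     model = get_yolo_model()               # yolov13n.pt
#     same = get_yolo_model("yolov13n.pt")   # cached
#     assert model is same
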
def detect_ball_center(
    frame: Image.Image,
    model_filename: str = YOLO_DEFAULT_MODEL,
    conf_threshold: float = YOLO_CONF_THRESHOLD,
    iou_threshold: float = YOLO_IOU_THRESHOLD,
) -> Optional[tuple[int, int, int, int, float]]:
    """
    Run YOLO on a single frame and return (x_center, y_center, width, height, confidence)
    for the highest-confidence sports ball detection.
    """
    model = get_yolo_model(model_filename)
    class_ids = [
        idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
    ]
    if not class_ids:
        return None

    results = model.predict(
        source=frame,
        conf=conf_threshold,
        iou=iou_threshold,
        max_det=1,
        classes=class_ids,
        imgsz=640,
        device="cpu",
        verbose=False,
    )

    if not results:
        return None

    boxes = results[0].boxes
    if boxes is None or len(boxes) == 0:
        return None

    box = boxes[0]

    xywh = box.xywh[0].cpu().tolist()
    conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0
    x_center, y_center, width, height = xywh
    return (
        int(round(x_center)),
        int(round(y_center)),
        int(round(width)),
        int(round(height)),
        conf,
    )

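# Illustrative call (the frame below is a blank synthetic stand-in, so a
# None result is the expected outcome in practice):
#
#     frame = Image.new("RGB", (640, 360))
#     hit = detect_ball_center(frame)
#     if hit is not None:
#         cx, cy, w, h, conf = hit
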
def detect_person_box(
    frame: Image.Image,
    model_filename: str = YOLO_DEFAULT_MODEL,
    conf_threshold: float = YOLO_CONF_THRESHOLD,
    iou_threshold: float = YOLO_IOU_THRESHOLD,
) -> Optional[tuple[int, int, int, int, float]]:
    """
    Run YOLO on a single frame and return (x_min, y_min, x_max, y_max, confidence)
    for the highest-confidence person detection.
    """
    model = get_yolo_model(model_filename)
    class_ids = [
        idx for idx, name in model.names.items() if name.lower() == PLAYER_TARGET_NAME
    ]
    if not class_ids:
        return None

    results = model.predict(
        source=frame,
        conf=conf_threshold,
        iou=iou_threshold,
        max_det=5,
        classes=class_ids,
        imgsz=640,
        device="cpu",
        verbose=False,
    )

    if not results:
        return None

    boxes = results[0].boxes
    if boxes is None or len(boxes) == 0:
        return None

    box = boxes[0]
    xyxy = box.xyxy[0].cpu().tolist()
    conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0
    x_min, y_min, x_max, y_max = xyxy
    frame_width, frame_height = frame.size
    x_min = max(0, min(frame_width - 1, int(round(x_min))))
    y_min = max(0, min(frame_height - 1, int(round(y_min))))
    x_max = max(0, min(frame_width - 1, int(round(x_max))))
    y_max = max(0, min(frame_height - 1, int(round(y_max))))
    if x_max <= x_min or y_max <= y_min:
        return None
    return x_min, y_min, x_max, y_max, conf

def _compute_sam_window_from_kick(state: AppState, kick_frame: int | None) -> tuple[int, int]:
    total_frames = state.num_frames
    if total_frames == 0:
        return 0, 0
    fps = state.video_fps if state.video_fps and state.video_fps > 0 else 25.0
    target_window_frames = max(1, int(round(fps * 4.0)))
    half_window = target_window_frames // 2
    if kick_frame is None:
        start_idx = 0
    else:
        start_idx = max(0, int(kick_frame) - half_window)
    end_idx = min(total_frames, start_idx + target_window_frames)
    if end_idx <= start_idx:
        end_idx = min(total_frames, start_idx + 1)
    state.sam_window = (start_idx, end_idx)
    return start_idx, end_idx

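# Worked example of the window math above: at 25 fps the target window is
# round(25 * 4.0) = 100 frames, so half_window = 50. A kick at frame 120 in
# a 300-frame clip yields (start_idx, end_idx) = (70, 170), i.e. a ~4 s
# window centred on the kick and clamped at the clip boundaries.
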
def _perform_yolo_ball_tracking(state: AppState, progress: gr.Progress | None = None) -> None:
    if state is None or state.num_frames == 0:
        raise gr.Error("Load a video first, then track with YOLO.")

    model = get_yolo_model()
    class_ids = [
        idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
    ]
    if not class_ids:
        raise gr.Error("YOLO model does not contain the sports ball class.")

    frames = state.video_frames
    total = len(frames)
    centers: dict[int, tuple[float, float]] = {}
    boxes: dict[int, tuple[int, int, int, int]] = {}
    confs: dict[int, float] = {}
    areas: dict[int, float] = {}
    first_detection_frame: int | None = None

    # Pass 1: per-frame YOLO detection of the single best sports-ball box.
    for idx, frame in enumerate(frames):
        if progress is not None:
            progress((idx + 1) / total)

        results = model.predict(
            source=frame,
            conf=YOLO_CONF_THRESHOLD,
            iou=YOLO_IOU_THRESHOLD,
            max_det=1,
            classes=class_ids,
            imgsz=640,
            device="cpu",
            verbose=False,
        )
        if not results:
            continue
        boxes_result = results[0].boxes
        if boxes_result is None or len(boxes_result) == 0:
            continue

        box = boxes_result[0]
        xywh = box.xywh[0].cpu().tolist()
        conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0
        x_center, y_center, width, height = xywh
        x_center = float(x_center)
        y_center = float(y_center)
        width = max(1.0, float(width))
        height = max(1.0, float(height))

        frame_width, frame_height = frame.size
        x_min = int(round(max(0.0, x_center - width / 2.0)))
        y_min = int(round(max(0.0, y_center - height / 2.0)))
        x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0)))
        y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0)))
        if x_max <= x_min or y_max <= y_min:
            continue

        centers[idx] = (x_center, y_center)
        boxes[idx] = (x_min, y_min, x_max, y_max)
        confs[idx] = conf
        areas[idx] = float((x_max - x_min) * (y_max - y_min))
        if first_detection_frame is None:
            first_detection_frame = idx

    state.yolo_ball_centers = centers
    state.yolo_ball_boxes = boxes
    state.yolo_ball_conf = confs
    state.yolo_mask_area_proxy = [areas.get(k, 0.0) for k in sorted(centers.keys())]
    state.yolo_initial_frame = first_detection_frame

    if len(centers) < 3:
        state.yolo_smoothed_centers = {}
        state.yolo_speeds = {}
        state.yolo_distance_from_start = {}
        state.yolo_threshold = None
        state.yolo_baseline_speed = None
        state.yolo_speed_std = None
        state.yolo_kick_frame = None
        state.yolo_status = "❌ YOLO13: insufficient detections to estimate kick. Please retry or annotate manually."
        state.sam_window = None
        return

    # Pass 2: smooth the raw centers with an EMA and difference them into speeds.
    items = sorted(centers.items())
    dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
    alpha = 0.35

    smoothed: dict[int, tuple[float, float]] = {}
    speeds: dict[int, float] = {}

    prev_frame = None
    prev_smooth = None
    for frame_idx, (cx, cy) in items:
        if prev_smooth is None:
            smooth_x, smooth_y = float(cx), float(cy)
        else:
            smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
            smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
        smoothed[frame_idx] = (smooth_x, smooth_y)
        if prev_smooth is None or prev_frame is None:
            speeds[frame_idx] = 0.0
        else:
            frame_delta = max(1, frame_idx - prev_frame)
            time_delta = frame_delta * dt
            dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
            speed = dist / time_delta if time_delta > 0 else dist
            speeds[frame_idx] = speed
        prev_smooth = (smooth_x, smooth_y)
        prev_frame = frame_idx

    frames_ordered = [frame_idx for frame_idx, _ in items]
    speed_series = [speeds.get(f, 0.0) for f in frames_ordered]

    # Adaptive threshold from the pre-kick baseline (first ~10 detections).
    baseline_window = min(10, len(frames_ordered) // 3 or 1)
    baseline_speeds = speed_series[:baseline_window]
    baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
    speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0
    base_threshold = baseline_speed + 4.0 * speed_std
    if base_threshold < baseline_speed * 3.0:
        base_threshold = baseline_speed * 3.0
    speed_threshold = max(base_threshold, 15.0)

    distance_dict: dict[int, float] = {}
    if smoothed:
        first_frame = frames_ordered[0]
        origin = smoothed[first_frame]
        for frame_idx, (sx, sy) in smoothed.items():
            distance_dict[frame_idx] = math.hypot(sx - origin[0], sy - origin[1])

    areas_dict = {idx: areas.get(idx, 0.0) for idx in frames_ordered}
    initial_area = areas_dict.get(frames_ordered[0], 1.0) or 1.0
    radius_estimate = math.sqrt(initial_area / math.pi)
    adaptive_return_distance = max(8.0, min(radius_estimate * 1.5, 40.0))

    sustain_frames = 3
    holdout_frames = 8
    area_window = 4
    area_drop_ratio = 0.75

    kalman_pos, kalman_speed, _ = _run_kalman_filter(items, dt)
    # Not consumed below; retained for parity with the SAM-side diagnostics.
    kalman_speed_series = [kalman_speed.get(f, 0.0) for f in frames_ordered]

    # Pass 3: scan for the first candidate frame that passes all kick gates.
    kick_frame: int | None = None
    for idx, candidate_frame in enumerate(frames_ordered[baseline_window:], start=baseline_window):
        speed = speed_series[idx]
        if speed < speed_threshold:
            continue
        sustain_ok = True
        for j in range(1, sustain_frames + 1):
            if idx + j >= len(frames_ordered):
                break
            if speed_series[idx + j] < speed_threshold * 0.7:
                sustain_ok = False
                break
        if not sustain_ok:
            continue

        area_pass = True
        current_area = areas_dict.get(candidate_frame)
        if current_area:
            prev_areas = [
                areas_dict.get(f)
                for f in frames_ordered[max(0, idx - area_window):idx]
                if areas_dict.get(f) is not None
            ]
            if prev_areas:
                median_prev = statistics.median(prev_areas)
                if median_prev > 0:
                    ratio = current_area / median_prev
                    if ratio > area_drop_ratio:
                        area_pass = False
        if not area_pass and speed < speed_threshold * 1.2:
            continue

        future_slice = frames_ordered[idx: min(len(frames_ordered), idx + holdout_frames)]
        max_future_dist = 0.0
        for future_frame in future_slice:
            dist = distance_dict.get(future_frame, 0.0)
            if dist > max_future_dist:
                max_future_dist = dist
        if max_future_dist < adaptive_return_distance:
            continue

        kick_frame = candidate_frame
        break

    state.yolo_smoothed_centers = smoothed
    state.yolo_speeds = speeds
    state.yolo_distance_from_start = distance_dict
    state.yolo_threshold = speed_threshold
    state.yolo_baseline_speed = baseline_speed
    state.yolo_speed_std = speed_std
    state.yolo_kick_frames = frames_ordered
    state.yolo_kick_speeds = speed_series
    state.yolo_kick_distance = [distance_dict.get(f, 0.0) for f in frames_ordered]
    state.yolo_mask_area_proxy = [areas_dict.get(f, 0.0) for f in frames_ordered]
    state.yolo_kick_frame = kick_frame
    coverage = len(centers) / total if total else 0.0
    if kick_frame is not None:
        state.yolo_status = f"✅ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%})."
    else:
        state.yolo_status = (
            f"⚠️ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%}) but did not find a definitive kick."
        )
    state.kalman_centers[BALL_OBJECT_ID] = kalman_pos
    state.kalman_speeds[BALL_OBJECT_ID] = kalman_speed

    if kick_frame is not None:
        state.kick_frame = kick_frame
        _compute_sam_window_from_kick(state, kick_frame)
    else:
        state.sam_window = None

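# The kick heuristic above, in brief: a candidate detection must
# (1) exceed the adaptive speed threshold max(median + 4*sigma, 3*median, 15 px/s)
#     estimated from the first ~10 detections,
# (2) hold at least 70% of that threshold over the next 3 detections,
# (3) show the box area dropping to <= 75% of its recent median (the ball
#     shrinking as it departs), a gate waived when speed exceeds 1.2x the
#     threshold, and
# (4) actually travel beyond ~1.5 ball radii (clamped to 8..40 px) within the
#     next 8 detections, so jitter around a resting ball is rejected.
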
def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
    """Generate a deterministic pastel RGB color for a given object id.

    Uses golden ratio to distribute hues; low-medium saturation, high value.
    """
    golden_ratio_conjugate = 0.61803398875

    hue = (obj_id * golden_ratio_conjugate) % 1.0
    saturation = 0.45
    value = 1.0
    r_f, g_f, b_f = colorsys.hsv_to_rgb(hue, saturation, value)
    return int(r_f * 255), int(g_f * 255), int(b_f * 255)

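# Example of the golden-ratio hue walk: obj_id 1 lands at hue 0.618...,
# obj_id 2 at 0.236..., obj_id 3 at 0.854... — consecutive ids are pushed
# ~0.618 apart around the hue wheel, keeping neighbouring object colors
# visually distinct without a lookup table.
#
#     pastel_color_for_object(1)  # deterministic pastel (R, G, B)
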
def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
    """Load video frames as PIL Images using OpenCV.

    Returns (frames, info) where info carries the frame count and FPS
    (None when the container does not report a usable FPS).
    """
    cap = cv2.VideoCapture(video_path_or_url)
    frames = []
    print("loading video frames")
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame_rgb))

    fps_val = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    print("loaded video frames")
    info = {
        "num_frames": len(frames),
        "fps": float(fps_val) if fps_val and fps_val > 0 else None,
    }
    return frames, info

def overlay_masks_on_frame(
    frame: Image.Image,
    masks_per_object: dict[int, np.ndarray],
    color_by_obj: dict[int, tuple[int, int, int]],
    alpha: float = 0.5,
) -> Image.Image:
    """Overlay per-object soft masks onto the RGB frame.

    masks_per_object: mapping of obj_id -> (H, W) float mask in [0, 1]
    color_by_obj: mapping of obj_id -> (R, G, B)
    """
    base = np.array(frame).astype(np.float32) / 255.0
    overlay = base.copy()

    for obj_id, mask in masks_per_object.items():
        if mask is None:
            continue
        if mask.dtype != np.float32:
            mask = mask.astype(np.float32)

        if mask.ndim == 3:
            mask = mask.squeeze()
        mask = np.clip(mask, 0.0, 1.0)
        color = np.array(color_by_obj.get(obj_id, (255, 0, 0)), dtype=np.float32) / 255.0

        a = alpha
        m = mask[..., None]
        overlay = (1.0 - a * m) * overlay + (a * m) * color

    out = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(out)

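# The per-object blend above is plain alpha compositing, applied once per
# mask: out = (1 - a*m) * out + (a*m) * color, with m the soft mask in
# [0, 1]. A quick sanity check with one full-strength mask:
#
#     mask = np.ones((4, 4), dtype=np.float32)
#     img = Image.new("RGB", (4, 4), (0, 0, 0))
#     out = overlay_masks_on_frame(img, {1: mask}, {1: (255, 0, 0)}, alpha=0.5)
#     # every pixel is now half red: (127, 0, 0)
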
def get_device_and_dtype() -> tuple[str, torch.dtype]:
    # Inference runs on CPU with bfloat16 weights.
    device = "cpu"
    dtype = torch.bfloat16
    return device, dtype

class AppState:
    def __init__(self):
        self.reset()

    def reset(self):
        self.video_frames: list[Image.Image] = []
        self.inference_session = None
        self.model: Optional[AutoModel] = None
        self.processor: Optional[Sam2VideoProcessor] = None
        self.device: str = "cpu"
        self.dtype: torch.dtype = torch.bfloat16
        self.video_fps: float | None = None
        self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
        self.color_by_obj: dict[int, tuple[int, int, int]] = {}
        self.clicks_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int]]]] = {}
        self.boxes_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int, int]]]] = {}

        self.composited_frames: dict[int, Image.Image] = {}

        self.current_frame_idx: int = 0
        self.current_obj_id: int = 1
        self.current_label: str = "positive"
        self.current_clear_old: bool = True
        self.current_prompt_type: str = "Points"
        self.pending_box_start: tuple[int, int] | None = None
        self.pending_box_start_frame_idx: int | None = None
        self.pending_box_start_obj_id: int | None = None
        self.is_switching_model: bool = False
        self.ball_centers: dict[int, dict[int, tuple[int, int]]] = {}
        self.mask_areas: dict[int, dict[int, float]] = {}
        self.smoothed_centers: dict[int, dict[int, tuple[float, float]]] = {}
        self.ball_speeds: dict[int, dict[int, float]] = {}
        self.distance_from_start: dict[int, dict[int, float]] = {}
        self.direction_change: dict[int, dict[int, float]] = {}
        self.kick_frame: int | None = None
        self.kick_debug_frames: list[int] = []
        self.kick_debug_speeds: list[float] = []
        self.kick_debug_threshold: float | None = None
        self.kick_debug_baseline: float | None = None
        self.kick_debug_speed_std: float | None = None
        self.kick_debug_area: list[float] = []
        self.kick_debug_kick_frame: int | None = None
        self.kick_debug_distance: list[float] = []
        self.kick_debug_kalman_speeds: list[float] = []
        self.kalman_centers: dict[int, dict[int, tuple[float, float]]] = {}
        self.kalman_speeds: dict[int, dict[int, float]] = {}
        self.kalman_residuals: dict[int, dict[int, float]] = {}
        self.min_impact_speed_kmh: float = 20.0
        self.goal_distance_m: float = 18.0
        self.impact_frame: int | None = None
        self.impact_debug_frames: list[int] = []
        self.impact_debug_innovation: list[float] = []
        self.impact_debug_innovation_threshold: float | None = None
        self.impact_debug_direction: list[float] = []
        self.impact_debug_direction_threshold: float | None = None
        self.impact_debug_speed_kmh: list[float] = []
        self.impact_debug_speed_threshold_px: float | None = None
        self.impact_meters_per_px: float | None = None

        self.model_repo_key: str = "tiny"
        self.model_repo_id: str | None = None
        self.session_repo_id: str | None = None
        self.player_obj_id: int | None = None
        self.player_detection_frame: int | None = None
        self.player_detection_conf: float | None = None

        self.yolo_ball_centers: dict[int, tuple[float, float]] = {}
        self.yolo_ball_boxes: dict[int, tuple[int, int, int, int]] = {}
        self.yolo_ball_conf: dict[int, float] = {}
        self.yolo_smoothed_centers: dict[int, tuple[float, float]] = {}
        self.yolo_speeds: dict[int, float] = {}
        self.yolo_distance_from_start: dict[int, float] = {}
        self.yolo_threshold: float | None = None
        self.yolo_baseline_speed: float | None = None
        self.yolo_speed_std: float | None = None
        self.yolo_kick_frame: int | None = None
        self.yolo_status: str = ""
        self.yolo_kick_frames: list[int] = []
        self.yolo_kick_speeds: list[float] = []
        self.yolo_kick_distance: list[float] = []
        self.yolo_mask_area_proxy: list[float] = []
        self.yolo_initial_frame: int | None = None

        self.sam_window: tuple[int, int] | None = None

    def __repr__(self):
        return (
            f"AppState(video_frames={self.video_frames}, "
            f"inference_session={self.inference_session is not None}, "
            f"model={self.model is not None}, "
            f"processor={self.processor is not None}, "
            f"device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, "
            f"masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, "
            f"clicks_by_frame_obj={self.clicks_by_frame_obj}, "
            f"boxes_by_frame_obj={self.boxes_by_frame_obj}, "
            f"composited_frames={self.composited_frames}, "
            f"current_frame_idx={self.current_frame_idx}, "
            f"current_obj_id={self.current_obj_id}, current_label={self.current_label}, "
            f"current_clear_old={self.current_clear_old}, "
            f"current_prompt_type={self.current_prompt_type}, "
            f"pending_box_start={self.pending_box_start}, "
            f"pending_box_start_frame_idx={self.pending_box_start_frame_idx}, "
            f"pending_box_start_obj_id={self.pending_box_start_obj_id}, "
            f"is_switching_model={self.is_switching_model}, "
            f"model_repo_key={self.model_repo_key}, "
            f"model_repo_id={self.model_repo_id}, "
            f"session_repo_id={self.session_repo_id})"
        )

    @property
    def num_frames(self) -> int:
        return len(self.video_frames)

def _model_repo_from_key(key: str) -> str:
    mapping = {
        "tiny": "facebook/sam2.1-hiera-tiny",
        "small": "facebook/sam2.1-hiera-small",
        "base_plus": "facebook/sam2.1-hiera-base-plus",
        "large": "facebook/sam2.1-hiera-large",
    }
    return mapping.get(key, mapping["base_plus"])

def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoProcessor, str, torch.dtype]:
    desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
    if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
        if GLOBAL_STATE.model_repo_id == desired_repo:
            return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype

    GLOBAL_STATE.model = None
    GLOBAL_STATE.processor = None
    print(f"Loading model from {desired_repo}")
    device, dtype = get_device_and_dtype()

    model = AutoModel.from_pretrained(desired_repo)
    processor = Sam2VideoProcessor.from_pretrained(desired_repo)
    model.to(device, dtype=dtype)

    GLOBAL_STATE.model = model
    GLOBAL_STATE.processor = processor
    GLOBAL_STATE.device = device
    GLOBAL_STATE.dtype = dtype
    GLOBAL_STATE.model_repo_id = desired_repo
    # The annotated return type promises a 4-tuple on every path; return the
    # freshly loaded objects as well, matching the cached branch above.
    return model, processor, device, dtype

def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
    """Ensure the model/processor match the selected repo and inference_session exists.
    If a video is already loaded, re-initialize the inference session when needed.
    """
    load_model_if_needed(GLOBAL_STATE)
    desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
    if GLOBAL_STATE.inference_session is None or GLOBAL_STATE.session_repo_id != desired_repo:
        if GLOBAL_STATE.video_frames:
            GLOBAL_STATE.masks_by_frame.clear()
            GLOBAL_STATE.clicks_by_frame_obj.clear()
            GLOBAL_STATE.boxes_by_frame_obj.clear()
            GLOBAL_STATE.composited_frames.clear()
        GLOBAL_STATE.inference_session = None
        GLOBAL_STATE.inference_session = GLOBAL_STATE.processor.init_video_session(
            inference_device=GLOBAL_STATE.device,
            video_storage_device="cpu",
            dtype=GLOBAL_STATE.dtype,
        )
        GLOBAL_STATE.session_repo_id = desired_repo

def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
    """Gradio handler: load video, init session, return state, slider bounds, and first frame."""
    GLOBAL_STATE.video_frames = []
    GLOBAL_STATE.inference_session = None
    GLOBAL_STATE.masks_by_frame = {}
    GLOBAL_STATE.color_by_obj = {}
    GLOBAL_STATE.ball_centers = {}
    GLOBAL_STATE.mask_areas = {}
    GLOBAL_STATE.smoothed_centers = {}
    GLOBAL_STATE.ball_speeds = {}
    GLOBAL_STATE.distance_from_start = {}
    GLOBAL_STATE.direction_change = {}
    GLOBAL_STATE.kick_frame = None
    GLOBAL_STATE.kalman_centers = {}
    GLOBAL_STATE.kalman_speeds = {}
    GLOBAL_STATE.kalman_residuals = {}
    GLOBAL_STATE.kick_debug_kalman_speeds = []
    GLOBAL_STATE.kick_debug_frames = []
    GLOBAL_STATE.kick_debug_speeds = []
    GLOBAL_STATE.kick_debug_threshold = None
    GLOBAL_STATE.kick_debug_baseline = None
    GLOBAL_STATE.kick_debug_speed_std = None
    GLOBAL_STATE.kick_debug_area = []
    GLOBAL_STATE.kick_debug_kick_frame = None
    GLOBAL_STATE.kick_debug_distance = []
    GLOBAL_STATE.impact_frame = None
    GLOBAL_STATE.impact_debug_frames = []
    GLOBAL_STATE.impact_debug_innovation = []
    GLOBAL_STATE.impact_debug_innovation_threshold = None
    GLOBAL_STATE.impact_debug_direction = []
    GLOBAL_STATE.impact_debug_direction_threshold = None
    GLOBAL_STATE.impact_debug_speed_kmh = []
    GLOBAL_STATE.impact_debug_speed_threshold_px = None
    GLOBAL_STATE.impact_meters_per_px = None
    GLOBAL_STATE.yolo_ball_centers = {}
    GLOBAL_STATE.yolo_ball_boxes = {}
    GLOBAL_STATE.yolo_ball_conf = {}
    GLOBAL_STATE.yolo_smoothed_centers = {}
    GLOBAL_STATE.yolo_speeds = {}
    GLOBAL_STATE.yolo_distance_from_start = {}
    GLOBAL_STATE.yolo_threshold = None
    GLOBAL_STATE.yolo_baseline_speed = None
    GLOBAL_STATE.yolo_speed_std = None
    GLOBAL_STATE.yolo_kick_frame = None
    GLOBAL_STATE.yolo_status = ""
    GLOBAL_STATE.yolo_kick_frames = []
    GLOBAL_STATE.yolo_kick_speeds = []
    GLOBAL_STATE.yolo_kick_distance = []
    GLOBAL_STATE.yolo_mask_area_proxy = []
    GLOBAL_STATE.yolo_initial_frame = None
    GLOBAL_STATE.sam_window = None
    GLOBAL_STATE.player_obj_id = None
    GLOBAL_STATE.player_detection_frame = None
    GLOBAL_STATE.player_detection_conf = None

    load_model_if_needed(GLOBAL_STATE)

    video_path: Optional[str] = None
    if isinstance(video, dict):
        video_path = video.get("name") or video.get("path") or video.get("data")
    elif isinstance(video, str):
        video_path = video

    if not video_path:
        raise gr.Error("Invalid video input.")

    frames, info = try_load_video_frames(video_path)
    if len(frames) == 0:
        raise gr.Error("No frames could be loaded from the video.")

    MAX_SECONDS = 8.0
    trimmed_note = ""
    fps_in = info.get("fps")
    # Guard against containers that report no FPS: trim with an assumed
    # 25 fps, but keep video_fps as None so the status can report "unknown".
    effective_fps = float(fps_in) if fps_in and fps_in > 0 else 25.0
    max_frames_allowed = int(MAX_SECONDS * effective_fps)
    if len(frames) > max_frames_allowed:
        frames = frames[:max_frames_allowed]
        trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
        if isinstance(info, dict):
            info["num_frames"] = len(frames)
    GLOBAL_STATE.video_frames = frames

    GLOBAL_STATE.video_fps = float(fps_in) if fps_in and fps_in > 0 else None

    inference_session = GLOBAL_STATE.processor.init_video_session(
        inference_device=GLOBAL_STATE.device,
        video_storage_device="cpu",
        dtype=GLOBAL_STATE.dtype,
    )
    GLOBAL_STATE.inference_session = inference_session

    first_frame = frames[0]
    max_idx = len(frames) - 1
    status = (
        f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
        f"Device: {GLOBAL_STATE.device}, dtype: bfloat16"
    )
    return GLOBAL_STATE, 0, max_idx, first_frame, status

def _speed_to_color(ratio: float) -> tuple[int, int, int]:
    ratio = float(np.clip(ratio, 0.0, 1.0))
    gradient = [
        (255, 0, 0),
        (255, 165, 0),
        (255, 255, 0),
        (0, 255, 0),
    ]
    segment = ratio * (len(gradient) - 1)
    idx = int(segment)
    frac = segment - idx
    if idx >= len(gradient) - 1:
        return gradient[-1]
    c1 = np.array(gradient[idx], dtype=float)
    c2 = np.array(gradient[idx + 1], dtype=float)
    blended = (1 - frac) * c1 + frac * c2
    return tuple(int(v) for v in blended)

def _angle_between(v1: tuple[float, float], v2: tuple[float, float]) -> float:
    x1, y1 = v1
    x2, y2 = v2
    mag1 = math.hypot(x1, y1)
    mag2 = math.hypot(x2, y2)
    if mag1 < 1e-6 or mag2 < 1e-6:
        return 0.0
    cos_val = (x1 * x2 + y1 * y2) / (mag1 * mag2)
    cos_val = max(-1.0, min(1.0, cos_val))
    return math.degrees(math.acos(cos_val))

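# Worked example: _angle_between((1.0, 0.0), (0.0, 1.0)) -> 90.0, since
# cos(theta) = (x1*x2 + y1*y2) / (|v1||v2|) = 0. Near-zero vectors return
# 0.0 instead of dividing by a vanishing magnitude.
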
DISPLAY_MIN_WIDTH = 640
DISPLAY_MAX_WIDTH = 1280

def _maybe_upscale_for_display(image: Image.Image) -> Image.Image:
    if image is None:
        return image
    original_width, original_height = image.size
    if original_width <= 0 or original_height <= 0:
        return image
    target_width = original_width
    if original_width < DISPLAY_MIN_WIDTH:
        target_width = DISPLAY_MIN_WIDTH
    elif original_width > DISPLAY_MAX_WIDTH:
        target_width = DISPLAY_MAX_WIDTH
    if target_width == original_width:
        return image
    scale = target_width / float(original_width)
    target_height = int(round(original_height * scale))
    return image.resize((target_width, target_height), Image.BILINEAR)

def _annotate_frame_index(image: Image.Image, frame_idx: int) -> Image.Image:
    if image is None:
        return image
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    text = f"Frame {frame_idx}"
    padding = 6
    try:
        bbox = draw.textbbox((0, 0), text)
        text_w = bbox[2] - bbox[0]
        text_h = bbox[3] - bbox[1]
    except AttributeError:
        # Older Pillow releases lack textbbox; fall back to textsize.
        text_w, text_h = draw.textsize(text)
    x0, y0 = padding, padding
    x1, y1 = x0 + text_w + padding, y0 + text_h + padding
    draw.rectangle([(x0 - padding // 2, y0 - padding // 2), (x1, y1)], fill=(0, 0, 0))
    draw.text((x0, y0), text, fill=(255, 255, 255))
    return annotated

def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> Image.Image:
    if state is None or state.video_frames is None or len(state.video_frames) == 0:
        return None
    frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
    frame = state.video_frames[frame_idx]
    masks = state.masks_by_frame.get(frame_idx, {})
    out_img = frame

    if len(masks) != 0:
        if remove_bg:
            # Keep only the masked pixels: union all object masks, then multiply.
            frame_np = np.array(frame)
            combined_mask = np.zeros((frame_np.shape[0], frame_np.shape[1]), dtype=np.float32)
            for obj_id, mask in masks.items():
                if mask is not None:
                    if mask.dtype != np.float32:
                        mask = mask.astype(np.float32)
                    if mask.ndim == 3:
                        mask = mask.squeeze()
                    combined_mask = np.maximum(combined_mask, np.clip(mask, 0.0, 1.0))

            mask_3d = np.repeat(combined_mask[:, :, np.newaxis], 3, axis=2)
            result_np = (frame_np * mask_3d).astype(np.uint8)
            out_img = Image.fromarray(result_np)
        else:
            out_img = overlay_masks_on_frame(out_img, masks, state.color_by_obj, alpha=0.65)

    # Click prompts as small crosses: green for positive, red for negative.
    clicks_map = state.clicks_by_frame_obj.get(frame_idx)
    if clicks_map:
        draw = ImageDraw.Draw(out_img)
        cross_half = 6
        for obj_id, pts in clicks_map.items():
            for x, y, lbl in pts:
                color = (0, 255, 0) if int(lbl) == 1 else (255, 0, 0)
                draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2)
                draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2)

    # First corner of a box prompt that is still awaiting its second click.
    if (
        state.pending_box_start is not None
        and state.pending_box_start_frame_idx == frame_idx
        and state.pending_box_start_obj_id is not None
    ):
        draw = ImageDraw.Draw(out_img)
        x, y = state.pending_box_start
        cross_half = 6
        color = state.color_by_obj.get(state.pending_box_start_obj_id, (255, 255, 255))
        draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2)
        draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2)

    box_map = state.boxes_by_frame_obj.get(frame_idx)
    if box_map:
        draw = ImageDraw.Draw(out_img)
        for obj_id, boxes in box_map.items():
            color = state.color_by_obj.get(obj_id, (255, 255, 255))
            for x1, y1, x2, y2 in boxes:
                draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)

    # Trajectory overlay: raw centroids in gray, smoothed track speed-colored,
    # with the current frame's point highlighted in red.
    if state.ball_centers:
        draw = ImageDraw.Draw(out_img)
        cross_half = 4
        for obj_id, centers in state.ball_centers.items():
            if not centers:
                continue
            raw_items = sorted(centers.items())
            for _, (rx, ry) in raw_items:
                draw.line([(rx - cross_half, ry), (rx + cross_half, ry)], fill=(160, 160, 160), width=1)
                draw.line([(rx, ry - cross_half), (rx, ry + cross_half)], fill=(160, 160, 160), width=1)
            smooth_dict = state.smoothed_centers.get(obj_id, {})
            if not smooth_dict:
                continue
            smooth_items = sorted(smooth_dict.items())
            distances: list[float] = []
            prev_center = None
            for _, (sx, sy) in smooth_items:
                if prev_center is None:
                    distances.append(0.0)
                else:
                    dx = sx - prev_center[0]
                    dy = sy - prev_center[1]
                    distances.append(float(np.hypot(dx, dy)))
                prev_center = (sx, sy)
            max_dist = max(distances[1:], default=0.0)
            color_by_frame: dict[int, tuple[int, int, int]] = {}
            for (f_idx, _), dist in zip(smooth_items, distances):
                ratio = dist / max_dist if max_dist > 0 else 0.0
                color_by_frame[f_idx] = _speed_to_color(ratio)
            for f_idx, (sx, sy) in reversed(smooth_items):
                highlight = (f_idx == frame_idx)
                color = (255, 0, 0) if highlight else color_by_frame.get(f_idx, (255, 255, 0))
                line_width = 1 if not highlight else 2
                draw.line([(sx - cross_half, sy), (sx + cross_half, sy)], fill=color, width=line_width)
                draw.line([(sx, sy - cross_half), (sx, sy + cross_half)], fill=color, width=line_width)

    state.composited_frames[frame_idx] = out_img
    return out_img

def update_frame_display(state: AppState, frame_idx: int) -> Image.Image:
    if state is None or state.video_frames is None or len(state.video_frames) == 0:
        return None
    frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))

    cached = state.composited_frames.get(frame_idx)
    if cached is not None:
        return _maybe_upscale_for_display(cached)
    composed = compose_frame(state, frame_idx)
    return _maybe_upscale_for_display(composed)

def _ensure_color_for_obj(state: AppState, obj_id: int):
    if obj_id not in state.color_by_obj:
        state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)

def _compute_mask_centroid(mask: np.ndarray) -> tuple[int, int] | None:
    if mask is None:
        return None
    mask_np = np.array(mask)
    if mask_np.ndim == 3:
        mask_np = mask_np.squeeze()
    if mask_np.size == 0:
        return None
    mask_float = np.clip(mask_np, 0.0, 1.0).astype(np.float32)
    moments = cv2.moments(mask_float)
    if moments["m00"] == 0:
        return None
    cx = int(moments["m10"] / moments["m00"])
    cy = int(moments["m01"] / moments["m00"])
    return cx, cy

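# The centroid above is the standard image-moment ratio: cx = m10 / m00,
# cy = m01 / m00, with m00 the mask's total "mass". A 2x2 all-ones mask at
# the top-left corner yields (0, 0) after integer truncation:
#
#     _compute_mask_centroid(np.ones((2, 2), dtype=np.float32))  # -> (0, 0)
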
def _update_centroids_for_frame(state: AppState, frame_idx: int):
    if state is None:
        return
    masks = state.masks_by_frame.get(int(frame_idx), {})
    seen_obj_ids: set[int] = set()
    for obj_id, mask in masks.items():
        centroid = _compute_mask_centroid(mask)
        centers = state.ball_centers.setdefault(int(obj_id), {})
        if centroid is not None:
            centers[int(frame_idx)] = centroid
        else:
            centers.pop(int(frame_idx), None)
        seen_obj_ids.add(int(obj_id))
        _ensure_color_for_obj(state, int(obj_id))
        mask_np = np.array(mask)
        if mask_np.ndim == 3:
            mask_np = mask_np.squeeze()
        mask_np = np.clip(mask_np, 0.0, 1.0)
        area = float(np.count_nonzero(mask_np > 0.3))
        areas = state.mask_areas.setdefault(int(obj_id), {})
        areas[int(frame_idx)] = area

    # Objects with no mask on this frame also lose their cached metrics here.
    for obj_id, centers in state.ball_centers.items():
        if obj_id not in seen_obj_ids:
            centers.pop(int(frame_idx), None)
    for obj_id, areas in state.mask_areas.items():
        if obj_id not in seen_obj_ids:
            areas.pop(int(frame_idx), None)
    _recompute_motion_metrics(state)

def _run_kalman_filter(
    ordered_items: list[tuple[int, tuple[float, float]]],
    base_dt: float,
) -> tuple[dict[int, tuple[float, float]], dict[int, float], dict[int, float]]:
    if not ordered_items:
        return {}, {}, {}

    # Constant-velocity model: state = [x, y, vx, vy]; only position is observed.
    H = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=float)
    R = np.eye(2, dtype=float) * 25.0  # measurement noise (px^2)

    state_vec = np.array(
        [ordered_items[0][1][0], ordered_items[0][1][1], 0.0, 0.0], dtype=float
    )
    P = np.eye(4, dtype=float) * 50.0

    positions: dict[int, tuple[float, float]] = {}
    speeds: dict[int, float] = {}
    residuals: dict[int, float] = {}

    prev_frame = ordered_items[0][0]

    for frame_idx, (cx, cy) in ordered_items:
        frame_delta = max(1, frame_idx - prev_frame) if frame_idx != prev_frame else 1
        dt = frame_delta * base_dt
        F = np.array(
            [
                [1, 0, dt, 0],
                [0, 1, 0, dt],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ],
            dtype=float,
        )
        # Simplified process noise coupling position and velocity; a crude
        # stand-in for the usual discrete white-noise-acceleration Q.
        q = 0.5 * dt**2
        Q = np.array(
            [
                [q, 0, dt, 0],
                [0, q, 0, dt],
                [dt, 0, 1, 0],
                [0, dt, 0, 1],
            ],
            dtype=float,
        ) * 0.05

        # Predict.
        state_vec = F @ state_vec
        P = F @ P @ F.T + Q

        # Update with the measured center.
        z = np.array([cx, cy], dtype=float)
        innovation = z - H @ state_vec
        S = H @ P @ H.T + R
        K = P @ H.T @ np.linalg.inv(S)
        state_vec = state_vec + K @ innovation
        P = (np.eye(4) - K @ H) @ P

        positions[frame_idx] = (float(state_vec[0]), float(state_vec[1]))
        speeds[frame_idx] = float(math.hypot(state_vec[2], state_vec[3]))
        residuals[frame_idx] = float(math.hypot(innovation[0], innovation[1]))

        prev_frame = frame_idx

    return positions, speeds, residuals

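# In equation form, each step above is the standard predict/update pair for
# the constant-velocity model:
#
#     predict:  x' = F x,            P' = F P F^T + Q
#     update:   y  = z - H x'        (innovation, logged as the residual)
#               S  = H P' H^T + R
#               K  = P' H^T S^{-1}
#               x  = x' + K y,       P = (I - K H) P'
#
# The innovation magnitude |y| is what the impact detector later thresholds:
# a ball abruptly changing direction at impact violates the constant-velocity
# assumption, so the residual spikes at the impact frame.
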
def _build_kick_plot(state: AppState):
    fig = go.Figure()
    if state is None or not state.kick_debug_frames or not state.kick_debug_speeds:
        fig.update_layout(
            title="Kick & impact diagnostics",
            xaxis_title="Frame",
            yaxis_title="Speed (px/s)",
        )
        return fig

    frames = state.kick_debug_frames
    speeds = state.kick_debug_speeds
    areas = state.kick_debug_area if state.kick_debug_area else [0.0] * len(frames)
    threshold = state.kick_debug_threshold or 0.0
    baseline = state.kick_debug_baseline or 0.0
    kick_frame = state.kick_debug_kick_frame
    distance = state.kick_debug_distance if state.kick_debug_distance else [0.0] * len(frames)
    impact_frames = state.impact_debug_frames if state.impact_debug_frames else frames

    fig.add_trace(
        go.Scatter(
            x=frames,
            y=speeds,
            mode="lines+markers",
            name="Speed (px/s)",
            line=dict(color="#1f77b4"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[frames[0], frames[-1]],
            y=[threshold, threshold],
            mode="lines",
            name="Adaptive threshold",
            line=dict(color="#d62728", dash="dash"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[frames[0], frames[-1]],
            y=[baseline, baseline],
            mode="lines",
            name="Baseline speed",
            line=dict(color="#ff7f0e", dash="dot"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=frames,
            y=areas,
            mode="lines",
            name="Mask area",
            line=dict(color="#2ca02c"),
            yaxis="y2",
        )
    )
    max_primary = max(
        max(speeds) if speeds else 0.0,
        threshold,
        baseline,
        max(state.kick_debug_kalman_speeds) if state.kick_debug_kalman_speeds else 0.0,
        state.impact_debug_innovation_threshold or 0.0,
        state.impact_debug_direction_threshold or 0.0,
        state.impact_debug_speed_threshold_px or 0.0,
        1.0,
    )
    max_distance = max(distance) if distance else 0.0
    if max_distance > 0 and max_primary > 0:
        distance_scaled = [d * (max_primary / max_distance) for d in distance]
    else:
        distance_scaled = distance

    fig.add_trace(
        go.Scatter(
            x=frames,
            y=distance_scaled,
            mode="lines",
            name="Distance from start (scaled)",
            line=dict(color="#9467bd"),
        )
    )
    if state.kick_debug_kalman_speeds:
        fig.add_trace(
            go.Scatter(
                x=frames,
                y=state.kick_debug_kalman_speeds,
                mode="lines",
                name="Kalman speed",
                line=dict(color="#8c564b"),
            )
        )
    if state.impact_debug_innovation:
        fig.add_trace(
            go.Scatter(
                x=impact_frames,
                y=state.impact_debug_innovation,
                mode="lines",
                name="Kalman innovation",
                line=dict(color="#17becf"),
            )
        )
        max_primary = max(max_primary, max(state.impact_debug_innovation))
    if (
        state.impact_debug_innovation_threshold is not None
        and impact_frames
        and len(impact_frames) >= 2
    ):
        fig.add_trace(
            go.Scatter(
                x=[impact_frames[0], impact_frames[-1]],
                y=[
                    state.impact_debug_innovation_threshold,
                    state.impact_debug_innovation_threshold,
                ],
                mode="lines",
                name="Innovation threshold",
                line=dict(color="#17becf", dash="dash"),
            )
        )
        max_primary = max(max_primary, state.impact_debug_innovation_threshold or 0.0)
    if state.impact_debug_direction:
        fig.add_trace(
            go.Scatter(
                x=impact_frames,
                y=state.impact_debug_direction,
                mode="lines",
                name="Direction change (deg)",
                line=dict(color="#bcbd22"),
            )
        )
        max_primary = max(max_primary, max(state.impact_debug_direction))
    if (
        state.impact_debug_direction_threshold is not None
        and impact_frames
        and len(impact_frames) >= 2
    ):
        fig.add_trace(
            go.Scatter(
                x=[impact_frames[0], impact_frames[-1]],
                y=[
                    state.impact_debug_direction_threshold,
                    state.impact_debug_direction_threshold,
                ],
                mode="lines",
                name="Direction threshold (deg)",
                line=dict(color="#bcbd22", dash="dot"),
            )
        )
        max_primary = max(max_primary, state.impact_debug_direction_threshold or 0.0)
    if state.impact_debug_speed_threshold_px:
        fig.add_trace(
            go.Scatter(
                x=[frames[0], frames[-1]],
                y=[state.impact_debug_speed_threshold_px] * 2,
                mode="lines",
                name="Min impact speed (px/s)",
                line=dict(color="#b82e2e", dash="dot"),
            )
        )
        max_primary = max(max_primary, state.impact_debug_speed_threshold_px or 0.0)
    if kick_frame is not None:
        fig.add_trace(
            go.Scatter(
                x=[kick_frame, kick_frame],
                y=[0, max_primary * 1.05],
                mode="lines",
                name="Detected kick",
                line=dict(color="#ff00ff", dash="solid", width=3),
            )
        )
    impact_frame = state.impact_frame
    if impact_frame is not None:
        fig.add_trace(
            go.Scatter(
                x=[impact_frame, impact_frame],
                y=[0, max_primary * 1.05],
                mode="lines",
                name="Detected impact",
                line=dict(color="#ff1493", width=3),
            )
        )
    fig.update_layout(
        title="Kick & impact diagnostics",
        xaxis_title="Frame",
        yaxis_title="Speed (px/s)",
        yaxis=dict(side="left"),
        yaxis2=dict(
            title="Mask area / Distance / Direction",
            overlaying="y",
            side="right",
            showgrid=False,
        ),
        legend=dict(orientation="h"),
        margin=dict(t=40, l=40, r=40, b=40),
    )
    return fig

def _ensure_ball_prompt_from_yolo(state: AppState):
    if (
        state is None
        or state.inference_session is None
        or not state.yolo_ball_centers
    ):
        return

    # If the user already clicked a ball prompt anywhere, do nothing.
    for frame_clicks in state.clicks_by_frame_obj.values():
        if frame_clicks.get(BALL_OBJECT_ID):
            return
    anchor_frame = state.yolo_initial_frame
    if anchor_frame is None and state.yolo_ball_centers:
        anchor_frame = min(state.yolo_ball_centers.keys())
    if anchor_frame is None or anchor_frame >= state.num_frames:
        return
    center = state.yolo_ball_centers.get(anchor_frame)
    if center is None:
        return
    x_center, y_center = center
    frame_width, frame_height = state.video_frames[anchor_frame].size
    x_center = int(np.clip(round(x_center), 0, frame_width - 1))
    y_center = int(np.clip(round(y_center), 0, frame_height - 1))
    # Synthesize a click event at the YOLO center and route it through the
    # regular click handler, so SAM receives a positive ball prompt.
    event = SimpleNamespace(
        index=(x_center, y_center),
        value={"x": x_center, "y": y_center},
    )
    state.current_obj_id = BALL_OBJECT_ID
    state.current_label = "positive"
    state.current_frame_idx = anchor_frame
    on_image_click(
        update_frame_display(state, anchor_frame),
        state,
        anchor_frame,
        BALL_OBJECT_ID,
        "positive",
        False,
        event,
    )

def _build_yolo_plot(state: AppState):
    fig = go.Figure()
    if state is None or not state.yolo_kick_frames or not state.yolo_kick_speeds:
        fig.update_layout(
            title="YOLO kick diagnostics",
            xaxis_title="Frame",
            yaxis_title="Speed (px/s)",
        )
        return fig

    frames = state.yolo_kick_frames
    speeds = state.yolo_kick_speeds
    distance = state.yolo_kick_distance if state.yolo_kick_distance else [0.0] * len(frames)
    areas = state.yolo_mask_area_proxy if state.yolo_mask_area_proxy else [0.0] * len(frames)
    threshold = state.yolo_threshold or 0.0
    baseline = state.yolo_baseline_speed or 0.0
    kick_frame = state.yolo_kick_frame

    fig.add_trace(
        go.Scatter(
            x=frames,
            y=speeds,
            mode="lines+markers",
            name="YOLO speed",
            line=dict(color="#4caf50"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=frames,
            y=[threshold] * len(frames),
            mode="lines",
            name="Adaptive threshold",
            line=dict(color="#ff9800", dash="dash"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=frames,
            y=[baseline] * len(frames),
            mode="lines",
            name="Baseline speed",
            line=dict(color="#9e9e9e", dash="dot"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=frames,
            y=distance,
            mode="lines",
            name="Distance from start",
            line=dict(color="#03a9f4"),
            yaxis="y2",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=frames,
            y=areas,
            mode="lines",
            name="Box area proxy",
            line=dict(color="#ab47bc", dash="dot"),
            yaxis="y2",
        )
    )

    if kick_frame is not None:
        fig.add_vline(
            x=kick_frame,
            line=dict(color="#e91e63", width=2),
            annotation_text=f"Kick {kick_frame}",
            annotation_position="top right",
        )

    fig.update_layout(
        title="YOLO kick diagnostics",
        xaxis=dict(title="Frame"),
        yaxis=dict(title="Speed (px/s)"),
        yaxis2=dict(
            title="Distance / Area",
            overlaying="y",
            side="right",
            showgrid=False,
        ),
        legend=dict(orientation="h"),
        margin=dict(t=40, l=40, r=40, b=40),
    )
    return fig

def _format_impact_status(state: AppState) -> str:
    if state is None:
        return "Impact frame: not computed"
    if not state.impact_debug_frames:
        return "Impact frame: not computed"
    if state.impact_frame is None:
        return "Impact frame: not detected"
    frame = state.impact_frame
    time_part = ""
    if state.video_fps and state.video_fps > 1e-6:
        seconds = frame / state.video_fps
        time_part = f" (~{seconds:.2f}s)"
    speed_text = ""
    meters_per_px = state.impact_meters_per_px
    target_obj_id = getattr(state, "current_obj_id", 1) or 1
    speed_px = state.kalman_speeds.get(int(target_obj_id), {}).get(frame, 0.0)
    if meters_per_px and meters_per_px > 0 and speed_px > 0:
        speed_kmh = speed_px * meters_per_px * 3.6
        if speed_kmh > 0.1:
            speed_text = f", est. speed ≈ {speed_kmh:.1f} km/h"
    return f"Impact frame: {frame}{time_part}{speed_text}"

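# Unit note for the estimate above: Kalman speed is in px/s, so
# px/s * (m/px) * 3.6 = km/h. Going the other way, the default 20 km/h floor
# at a scale of 0.01 m/px corresponds to (20 / 3.6) / 0.01 ≈ 556 px/s.
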
def _format_kick_status(state: AppState) -> str:
    if state is None or not isinstance(state, AppState):
        return "Kick frame: not computed"
    frame = state.kick_frame
    if frame is None:
        frame = getattr(state, "kick_debug_kick_frame", None)
    if frame is None:
        if state.kick_debug_frames:
            return "Kick frame: not detected"
        return "Kick frame: not computed"
    if state.kick_frame is None:
        state.kick_frame = frame
    time_part = ""
    if state.video_fps and state.video_fps > 1e-6:
        time_part = f" (~{frame / state.video_fps:.2f}s)"
    return f"Kick frame ≈ {frame}{time_part}"

def _ball_has_masks(state: AppState, target_obj_id: int = BALL_OBJECT_ID) -> bool:
    if state is None:
        return False
    for masks in state.masks_by_frame.values():
        if target_obj_id in masks:
            return True
    return False

def _player_has_masks(state: AppState) -> bool:
    if state is None or state.player_obj_id is None:
        return False
    player_id = state.player_obj_id
    for masks in state.masks_by_frame.values():
        if player_id in masks:
            return True
    return False

def _button_updates(state: AppState) -> tuple[Any, Any, Any]:
    yolo_ready = isinstance(state, AppState) and state.yolo_kick_frame is not None
    propagate_main_enabled = _ball_has_masks(state) or yolo_ready
    detect_player_enabled = yolo_ready
    propagate_player_enabled = _player_has_masks(state)
    return (
        gr.update(interactive=propagate_main_enabled),
        gr.update(interactive=detect_player_enabled),
        gr.update(interactive=propagate_player_enabled),
    )

def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
    centers = state.ball_centers.get(target_obj_id)
    if not centers or len(centers) < 3:
        # Too few centroids to smooth or differentiate; clear all derived metrics.
        state.smoothed_centers[target_obj_id] = {}
        state.ball_speeds[target_obj_id] = {}
        state.kick_frame = None
        state.kick_debug_frames = []
        state.kick_debug_speeds = []
        state.kick_debug_threshold = None
        state.kick_debug_baseline = None
        state.kick_debug_speed_std = None
        state.kick_debug_area = []
        state.kick_debug_kick_frame = None
        state.kick_debug_distance = []
        state.kalman_centers[target_obj_id] = {}
        state.kalman_speeds[target_obj_id] = {}
        state.kalman_residuals[target_obj_id] = {}
        state.kick_debug_kalman_speeds = []
        state.distance_from_start[target_obj_id] = {}
        state.direction_change[target_obj_id] = {}
        state.impact_frame = None
        state.impact_debug_frames = []
        state.impact_debug_innovation = []
        state.impact_debug_innovation_threshold = None
        state.impact_debug_direction = []
        state.impact_debug_direction_threshold = None
        state.impact_debug_speed_kmh = []
        state.impact_debug_speed_threshold_px = None
        state.impact_meters_per_px = None
        return

    items = sorted(centers.items())
    dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
    alpha = 0.35

    smoothed: dict[int, tuple[float, float]] = {}
    speeds: dict[int, float] = {}

    prev_frame = None
    prev_smooth = None
    for frame_idx, (cx, cy) in items:
        if prev_smooth is None:
            smooth_x, smooth_y = float(cx), float(cy)
        else:
            smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
            smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
        smoothed[frame_idx] = (smooth_x, smooth_y)
        if prev_smooth is None or prev_frame is None:
            speeds[frame_idx] = 0.0
        else:
            frame_delta = max(1, frame_idx - prev_frame)
            time_delta = frame_delta * dt
            dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
            speed = dist / time_delta if time_delta > 0 else dist
            speeds[frame_idx] = speed
        prev_smooth = (smooth_x, smooth_y)
        prev_frame = frame_idx

    state.smoothed_centers[target_obj_id] = smoothed
    state.ball_speeds[target_obj_id] = speeds

    if smoothed:
        first_frame = min(smoothed.keys())
        origin = smoothed[first_frame]
        distance_dict: dict[int, float] = {}
        for frame_idx, (sx, sy) in smoothed.items():
            distance_dict[frame_idx] = math.hypot(sx - origin[0], sy - origin[1])
        state.distance_from_start[target_obj_id] = distance_dict
        state.kick_debug_distance = [distance_dict.get(f, 0.0) for f in sorted(smoothed.keys())]

    kalman_pos, kalman_speed, kalman_res = _run_kalman_filter(items, dt)
    state.kalman_centers[target_obj_id] = kalman_pos
    state.kalman_speeds[target_obj_id] = kalman_speed
    state.kalman_residuals[target_obj_id] = kalman_res
    state.kick_frame = _detect_kick_frame(state, target_obj_id)
    state.impact_frame = _detect_impact_frame(state, target_obj_id)

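# Smoothing note: both this pipeline and the YOLO one use the same
# exponential moving average, s_t = s_{t-1} + alpha * (c_t - s_{t-1}) with
# alpha = 0.35, so roughly the last five samples dominate each smoothed
# center before speeds are differenced from it.
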
def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None: |
|
|
smoothed = state.smoothed_centers.get(target_obj_id, {}) |
|
|
speeds = state.ball_speeds.get(target_obj_id, {}) |
|
|
if len(smoothed) < 5: |
|
|
return None |
|
|
|
|
|
frames = sorted(smoothed.keys()) |
|
|
speed_series = [speeds.get(f, 0.0) for f in frames] |
|
|
|
|
|
baseline_window = min(10, len(frames) // 3 or 1) |
|
|
baseline_speeds = speed_series[:baseline_window] |
|
|
baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0 |
|
|
speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0 |
|
|
base_threshold = baseline_speed + 4.0 * speed_std |
|
|
if base_threshold < baseline_speed * 3.0: |
|
|
base_threshold = baseline_speed * 3.0 |
|
|
speed_threshold = max(base_threshold, 15.0) |
|
|
sustain_frames = 3 |
|
|
holdout_frames = 8 |
|
|
area_window = 4 |
|
|
area_drop_ratio = 0.75 |
|
|
areas_dict = state.mask_areas.get(target_obj_id, {}) |
|
|
initial_center = smoothed[frames[0]] |
|
|
initial_area = areas_dict.get(frames[0], 1.0) or 1.0 |
|
|
radius_estimate = math.sqrt(initial_area / math.pi) |
|
|
adaptive_return_distance = max(8.0, min(radius_estimate * 1.5, 40.0)) |
|
|
|
|
|
state.kick_debug_frames = frames |
|
|
state.kick_debug_speeds = speed_series |
|
|
state.kick_debug_threshold = speed_threshold |
|
|
state.kick_debug_baseline = baseline_speed |
|
|
state.kick_debug_speed_std = speed_std |
|
|
state.kick_debug_area = [areas_dict.get(f, 0.0) for f in frames] |
|
|
state.kick_debug_distance = [ |
|
|
math.hypot(smoothed[f][0] - initial_center[0], smoothed[f][1] - initial_center[1]) |
|
|
for f in frames |
|
|
] |
|
|
kalman_speed_dict = state.kalman_speeds.get(target_obj_id, {}) |
|
|
state.kick_debug_kalman_speeds = [kalman_speed_dict.get(f, 0.0) for f in frames] |
|
|
state.kick_debug_kick_frame = None |
|
|
|
|
|
for idx in range(baseline_window, len(frames)): |
|
|
frame = frames[idx] |
|
|
speed = speed_series[idx] |
|
|
if speed < speed_threshold: |
|
|
continue |
|
|
|
|
|
sustain_ok = True |
|
|
for j in range(1, sustain_frames + 1): |
|
|
if idx + j >= len(frames): |
|
|
break |
|
|
if speed_series[idx + j] < speed_threshold * 0.7: |
|
|
sustain_ok = False |
|
|
break |
|
|
if not sustain_ok: |
|
|
continue |
|
|
|
|
|
current_area = areas_dict.get(frame) |
|
|
area_pass = True |
|
|
if current_area: |
|
|
prev_areas = [ |
|
|
areas_dict.get(f) |
|
|
for f in frames[max(0, idx - area_window):idx] |
|
|
if areas_dict.get(f) is not None |
|
|
] |
|
|
if prev_areas: |
|
|
median_prev = statistics.median(prev_areas) |
|
|
if median_prev > 0: |
|
|
ratio = current_area / median_prev |
|
|
if ratio > area_drop_ratio: |
|
|
area_pass = False |
|
|
if not area_pass and speed < speed_threshold * 1.2: |
|
|
continue |
|
|
|
|
|
future_frames = frames[idx:min(len(frames), idx + holdout_frames)] |
|
|
max_future_dist = 0.0 |
|
|
for future_frame in future_frames: |
|
|
cx, cy = smoothed[future_frame] |
|
|
dist = math.hypot(cx - initial_center[0], cy - initial_center[1]) |
|
|
if dist > max_future_dist: |
|
|
max_future_dist = dist |
|
|
if max_future_dist < adaptive_return_distance: |
|
|
continue |
|
|
|
|
|
state.kick_debug_kick_frame = frame |
|
|
return frame |
|
|
|
|
|
state.kick_debug_kick_frame = None |
|
|
return None |
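

# The adaptive threshold above in one expression, with a worked example
# (illustrative numbers): a baseline median of 2 px/frame with pstdev 1 gives
# max(2 + 4*1, 3*2, 15.0) = 15.0, so the absolute floor dominates for slow,
# noisy baselines. This helper mirrors the logic above and is not called anywhere.
def _kick_speed_threshold(baseline_speeds: list[float]) -> float:
    baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
    speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0
    return max(baseline_speed + 4.0 * speed_std, baseline_speed * 3.0, 15.0)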
|
|
|
|
|
|
|
|
def _detect_impact_frame(state: AppState, target_obj_id: int) -> int | None: |
|
|
residuals = state.kalman_residuals.get(target_obj_id, {}) |
|
|
frames = sorted(residuals.keys()) |
|
|
state.impact_debug_frames = frames |
|
|
state.impact_debug_innovation = [residuals.get(f, 0.0) for f in frames] |
|
|
state.impact_debug_innovation_threshold = None |
|
|
state.impact_debug_direction = [] |
|
|
state.impact_debug_direction_threshold = None |
|
|
state.impact_debug_speed_kmh = [] |
|
|
state.impact_debug_speed_threshold_px = None |
|
|
state.impact_meters_per_px = None |
|
|
|
|
|
if not frames or state.kick_frame is None: |
|
|
state.impact_frame = None |
|
|
return None |
|
|
|
|
|
kalman_positions = state.kalman_centers.get(target_obj_id, {}) |
|
|
direction_dict: dict[int, float] = {} |
|
|
prev_pos: tuple[float, float] | None = None |
|
|
prev_vec: tuple[float, float] | None = None |
|
|
for frame in frames: |
|
|
pos = kalman_positions.get(frame) |
|
|
if pos is None: |
|
|
direction_dict[frame] = 0.0 |
|
|
continue |
|
|
if prev_pos is None: |
|
|
direction_dict[frame] = 0.0 |
|
|
prev_vec = (0.0, 0.0) |
|
|
else: |
|
|
vec = (pos[0] - prev_pos[0], pos[1] - prev_pos[1]) |
|
|
if prev_vec is None: |
|
|
direction_dict[frame] = 0.0 |
|
|
else: |
|
|
direction_dict[frame] = _angle_between(prev_vec, vec) |
|
|
prev_vec = vec |
|
|
prev_pos = pos |
|
|
state.direction_change[target_obj_id] = direction_dict |
|
|
state.impact_debug_direction = [direction_dict.get(f, 0.0) for f in frames] |
|
|
|
|
|
distance_dict = state.distance_from_start.get(target_obj_id, {}) |
|
|
max_distance_px = max(distance_dict.values()) if distance_dict else 0.0 |
|
|
goal_distance_m = max(state.goal_distance_m, 0.0) |
|
|
meters_per_px = goal_distance_m / max_distance_px if goal_distance_m > 0 and max_distance_px > 1e-6 else None |
|
|
state.impact_meters_per_px = meters_per_px |
|
|
|
|
|
kalman_speed_dict = state.kalman_speeds.get(target_obj_id, {}) |
|
|
if meters_per_px: |
|
|
state.impact_debug_speed_kmh = [ |
|
|
kalman_speed_dict.get(f, 0.0) * meters_per_px * 3.6 for f in frames |
|
|
] |
|
|
if state.min_impact_speed_kmh > 0: |
|
|
state.impact_debug_speed_threshold_px = (state.min_impact_speed_kmh / 3.6) / meters_per_px |
|
|
else: |
|
|
state.impact_debug_speed_kmh = [0.0 for _ in frames] |
|
|
state.impact_debug_speed_threshold_px = None |
|
|
|
|
|
baseline_frames = [f for f in frames if f <= state.kick_frame] |
|
|
if not baseline_frames: |
|
|
baseline_frames = frames[: max(1, min(len(frames), 10))] |
|
|
baseline_vals = [residuals.get(f, 0.0) for f in baseline_frames] |
|
|
baseline_median = statistics.median(baseline_vals) if baseline_vals else 0.0 |
|
|
baseline_std = statistics.pstdev(baseline_vals) if len(baseline_vals) > 1 else 0.0 |
|
|
innovation_threshold = baseline_median + 4.0 * baseline_std |
|
|
innovation_threshold = max(innovation_threshold, baseline_median * 3.0, 5.0) |
|
|
state.impact_debug_innovation_threshold = innovation_threshold |
|
|
|
|
|
direction_threshold = 25.0 |
|
|
state.impact_debug_direction_threshold = direction_threshold |
|
|
|
|
|
post_kick_buffer = 3 |
|
|
candidates: list[tuple[float, float, int]] = [] |
|
|
meters_limit = goal_distance_m * 1.1 if goal_distance_m > 0 else None |
|
|
frame_list_len = len(frames) |
|
|
for idx, frame in enumerate(frames): |
|
|
if frame <= state.kick_frame + post_kick_buffer: |
|
|
continue |
|
|
innovation = residuals.get(frame, 0.0) |
|
|
if innovation < innovation_threshold: |
|
|
continue |
|
|
direction_delta = direction_dict.get(frame, 0.0) |
|
|
if direction_delta < direction_threshold: |
|
|
continue |
|
|
speed_px = kalman_speed_dict.get(frame, 0.0) |
|
|
if state.impact_debug_speed_threshold_px and speed_px < state.impact_debug_speed_threshold_px: |
|
|
continue |
|
|
if meters_per_px and meters_limit is not None: |
|
|
distance_m = distance_dict.get(frame, 0.0) * meters_per_px |
|
|
if distance_m > meters_limit: |
|
|
continue |
|
|
|
|
|
# Keep only local peaks of the innovation signal.
prev_innovation = residuals.get(frames[idx - 1], innovation) if idx > 0 else innovation
|
|
next_innovation = residuals.get(frames[idx + 1], innovation) if idx + 1 < frame_list_len else innovation |
|
|
if innovation < prev_innovation and innovation < next_innovation: |
|
|
continue |
|
|
candidates.append((innovation, -frame, frame))  # sort key: strongest innovation first, earliest frame on ties
|
|
|
|
|
if not candidates: |
|
|
state.impact_frame = None |
|
|
return None |
|
|
|
|
|
candidates.sort(reverse=True) |
|
|
impact_frame = candidates[0][2] |
|
|
state.impact_frame = impact_frame |
|
|
return impact_frame |
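

# `_angle_between` is defined elsewhere in this file. A minimal implementation
# consistent with its use above (direction change in degrees between two motion
# vectors, 0.0 for degenerate vectors) could look like this sketch:
def _angle_between_sketch(v1: tuple[float, float], v2: tuple[float, float]) -> float:
    n1 = math.hypot(v1[0], v1[1])
    n2 = math.hypot(v2[0], v2[1])
    if n1 < 1e-9 or n2 < 1e-9:
        return 0.0
    cos_theta = (v1[0] * v2[0] + v1[1] * v2[1]) / (n1 * n2)
    return math.degrees(math.acos(max(-1.0, min(1.0, cos_theta))))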
|
|
|
|
|
|
|
|
def on_image_click( |
|
|
img: Image.Image | np.ndarray, |
|
|
state: AppState, |
|
|
frame_idx: int, |
|
|
obj_id: int, |
|
|
label: str, |
|
|
clear_old: bool, |
|
|
evt: gr.SelectData, |
|
|
) -> Image.Image: |
|
|
if state is None or state.inference_session is None: |
|
|
return img |
|
|
if state.is_switching_model:


# Ignore clicks while a checkpoint swap is in progress.


return update_frame_display(state, int(frame_idx))
|
|
|
|
|
|
|
|
x = y = None |
|
|
if evt is not None: |
|
|
|
|
|
try: |
|
|
if hasattr(evt, "index") and isinstance(evt.index, (list, tuple)) and len(evt.index) == 2: |
|
|
x, y = int(evt.index[0]), int(evt.index[1]) |
|
|
elif hasattr(evt, "value") and isinstance(evt.value, dict) and "x" in evt.value and "y" in evt.value: |
|
|
x, y = int(evt.value["x"]), int(evt.value["y"]) |
|
|
except Exception: |
|
|
x = y = None |
|
|
|
|
|
if x is None or y is None: |
|
|
raise gr.Error("Could not read click coordinates.") |
|
|
|
|
|
_ensure_color_for_obj(state, int(obj_id)) |
|
|
|
|
|
processor = state.processor |
|
|
model = state.model |
|
|
inference_session = state.inference_session |
|
|
original_size = None |
|
|
pixel_values = None |
|
|
if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames: |
|
|
inputs = processor(images=state.video_frames[frame_idx], device=state.device, return_tensors="pt") |
|
|
original_size = inputs.original_sizes[0] |
|
|
pixel_values = inputs.pixel_values[0] |
|
|
|
|
|
if state.current_prompt_type == "Boxes": |
|
|
|
|
|
if state.pending_box_start is None: |
|
|
|
|
|
frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {}) |
|
|
frame_clicks[int(obj_id)] = [] |
|
|
state.composited_frames.pop(int(frame_idx), None) |
|
|
state.pending_box_start = (int(x), int(y)) |
|
|
state.pending_box_start_frame_idx = int(frame_idx) |
|
|
state.pending_box_start_obj_id = int(obj_id) |
|
|
|
|
|
state.composited_frames.pop(int(frame_idx), None) |
|
|
return update_frame_display(state, int(frame_idx)) |
|
|
else: |
|
|
x1, y1 = state.pending_box_start |
|
|
x2, y2 = int(x), int(y) |
|
|
|
|
|
state.pending_box_start = None |
|
|
state.pending_box_start_frame_idx = None |
|
|
state.pending_box_start_obj_id = None |
|
|
state.composited_frames.pop(int(frame_idx), None) |
|
|
x_min, y_min = min(x1, x2), min(y1, y2) |
|
|
x_max, y_max = max(x1, x2), max(y1, y2) |
|
|
|
|
|
processor.add_inputs_to_inference_session( |
|
|
inference_session=inference_session, |
|
|
frame_idx=int(frame_idx), |
|
|
obj_ids=int(obj_id), |
|
|
input_boxes=[[[x_min, y_min, x_max, y_max]]], |
|
|
clear_old_inputs=True, |
|
|
original_size=original_size, |
|
|
) |
|
|
|
|
|
frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {}) |
|
|
obj_boxes = frame_boxes.setdefault(int(obj_id), []) |
|
|
|
|
|
obj_boxes.clear() |
|
|
obj_boxes.append((x_min, y_min, x_max, y_max)) |
|
|
state.composited_frames.pop(int(frame_idx), None) |
|
|
else: |
|
|
|
|
|
label_int = 1 if str(label).lower().startswith("pos") else 0 |
|
|
|
|
|
if bool(clear_old): |
|
|
frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {}) |
|
|
frame_boxes[int(obj_id)] = [] |
|
|
state.composited_frames.pop(int(frame_idx), None) |
|
|
processor.add_inputs_to_inference_session( |
|
|
inference_session=inference_session, |
|
|
frame_idx=int(frame_idx), |
|
|
obj_ids=int(obj_id), |
|
|
input_points=[[[[int(x), int(y)]]]], |
|
|
input_labels=[[[int(label_int)]]], |
|
|
original_size=original_size, |
|
|
clear_old_inputs=bool(clear_old), |
|
|
) |
|
|
|
|
|
frame_clicks = state.clicks_by_frame_obj.setdefault(int(frame_idx), {}) |
|
|
obj_clicks = frame_clicks.setdefault(int(obj_id), []) |
|
|
if bool(clear_old): |
|
|
obj_clicks.clear() |
|
|
obj_clicks.append((int(x), int(y), int(label_int))) |
|
|
state.composited_frames.pop(int(frame_idx), None) |
|
|
|
|
|
|
|
|
with torch.inference_mode(): |
|
|
outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=int(frame_idx)) |
|
|
|
|
|
H = inference_session.video_height |
|
|
W = inference_session.video_width |
|
|
|
|
|
pred_masks = outputs.pred_masks.detach().cpu() |
|
|
video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0] |
|
|
|
|
|
|
|
|
|
|
|
masks_for_frame: dict[int, np.ndarray] = {} |
|
|
obj_ids_order = list(inference_session.obj_ids) |
|
|
for i, oid in enumerate(obj_ids_order): |
|
|
mask_i = video_res_masks[i] |
|
|
|
|
|
mask_2d = mask_i.cpu().numpy().squeeze() |
|
|
masks_for_frame[int(oid)] = mask_2d |
|
|
|
|
|
state.masks_by_frame[int(frame_idx)] = masks_for_frame |
|
|
_update_centroids_for_frame(state, int(frame_idx)) |
|
|
|
|
|
state.composited_frames.pop(int(frame_idx), None) |
|
|
|
|
|
|
|
|
return update_frame_display(state, int(frame_idx)) |
|
|
|
|
|
|
|
|
def _on_image_click_with_updates( |
|
|
img: Image.Image | np.ndarray, |
|
|
state: AppState, |
|
|
frame_idx: int, |
|
|
obj_id: int, |
|
|
label: str, |
|
|
clear_old: bool, |
|
|
evt: gr.SelectData, |
|
|
): |
|
|
preview_img = on_image_click(img, state, frame_idx, obj_id, label, clear_old, evt) |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state) |
|
|
return preview_img, propagate_main_update, detect_btn_update, propagate_player_update |
|
|
|
|
|
|
|
|
@spaces.GPU() |
|
|
def propagate_masks(GLOBAL_STATE: gr.State): |
|
|
if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None: |
|
|
|
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
yield (
|
|
GLOBAL_STATE, |
|
|
"Load a video first.", |
|
|
gr.update(), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update,


)


return
|
|
|
|
|
_ensure_ball_prompt_from_yolo(GLOBAL_STATE) |
|
|
|
|
|
processor = deepcopy(GLOBAL_STATE.processor) |
|
|
model = deepcopy(GLOBAL_STATE.model) |
|
|
inference_session = deepcopy(GLOBAL_STATE.inference_session) |
|
|
|
|
|
inference_session.inference_device = "cuda" |
|
|
inference_session.cache.inference_device = "cuda" |
|
|
model.to("cuda") |
|
|
|
|
|
if not GLOBAL_STATE.sam_window: |
|
|
_compute_sam_window_from_kick( |
|
|
GLOBAL_STATE, |
|
|
GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None), |
|
|
) |
|
|
start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames) |
|
|
start_idx = max(0, int(start_idx)) |
|
|
end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx))) |
|
|
total = max(1, end_idx - start_idx) |
|
|
processed = 0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
yield ( |
|
|
GLOBAL_STATE, |
|
|
f"Propagating masks: {processed}/{total}", |
|
|
gr.update(), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
last_frame_idx = start_idx |
|
|
with torch.inference_mode(): |
|
|
for frame_idx in range(start_idx, end_idx): |
|
|
frame = GLOBAL_STATE.video_frames[frame_idx] |
|
|
pixel_values = None |
|
|
if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames: |
|
|
pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0] |
|
|
sam2_video_output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx) |
|
|
H = inference_session.video_height |
|
|
W = inference_session.video_width |
|
|
pred_masks = sam2_video_output.pred_masks.detach().cpu() |
|
|
video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0] |
|
|
last_frame_idx = frame_idx |
|
|
masks_for_frame: dict[int, np.ndarray] = {} |
|
|
obj_ids_order = list(inference_session.obj_ids) |
|
|
for i, oid in enumerate(obj_ids_order): |
|
|
mask_2d = video_res_masks[i].cpu().numpy().squeeze() |
|
|
masks_for_frame[int(oid)] = mask_2d |
|
|
GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame |
|
|
_update_centroids_for_frame(GLOBAL_STATE, frame_idx) |
|
|
|
|
|
GLOBAL_STATE.composited_frames.pop(frame_idx, None) |
|
|
|
|
|
processed += 1 |
|
|
|
|
|
if processed % 30 == 0 or processed == total: |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
yield ( |
|
|
GLOBAL_STATE, |
|
|
f"Propagating masks: {processed}/{total}", |
|
|
gr.update(value=frame_idx), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects." |
|
|
|
|
|
|
|
|
target_frame = GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None) |
|
|
if target_frame is None: |
|
|
target_frame = last_frame_idx |
|
|
target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1))) |
|
|
GLOBAL_STATE.current_frame_idx = target_frame |
|
|
|
|
|
|
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
yield ( |
|
|
GLOBAL_STATE, |
|
|
text, |
|
|
gr.update(value=target_frame), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
|
|
|
def reset_session(GLOBAL_STATE: gr.State) -> tuple[Any, ...]:
|
|
|
|
|
if not GLOBAL_STATE.video_frames: |
|
|
|
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
return ( |
|
|
GLOBAL_STATE, |
|
|
None, |
|
|
0, |
|
|
0, |
|
|
"Session reset. Load a new video.", |
|
|
gr.update(visible=False, value=""), |
|
|
_build_kick_plot(GLOBAL_STATE),


_build_yolo_plot(GLOBAL_STATE),


_format_impact_status(GLOBAL_STATE),
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
|
|
|
GLOBAL_STATE.masks_by_frame.clear() |
|
|
GLOBAL_STATE.clicks_by_frame_obj.clear() |
|
|
GLOBAL_STATE.boxes_by_frame_obj.clear() |
|
|
GLOBAL_STATE.composited_frames.clear() |
|
|
GLOBAL_STATE.pending_box_start = None |
|
|
GLOBAL_STATE.pending_box_start_frame_idx = None |
|
|
GLOBAL_STATE.pending_box_start_obj_id = None |
|
|
GLOBAL_STATE.ball_centers.clear() |
|
|
GLOBAL_STATE.mask_areas.clear() |
|
|
GLOBAL_STATE.smoothed_centers.clear() |
|
|
GLOBAL_STATE.ball_speeds.clear() |
|
|
GLOBAL_STATE.distance_from_start.clear() |
|
|
GLOBAL_STATE.direction_change.clear() |
|
|
GLOBAL_STATE.kick_frame = None |
|
|
|
|
GLOBAL_STATE.kalman_centers.clear() |
|
|
GLOBAL_STATE.kalman_speeds.clear() |
|
|
GLOBAL_STATE.kalman_residuals.clear() |
|
|
GLOBAL_STATE.kick_debug_frames = [] |
|
|
GLOBAL_STATE.kick_debug_speeds = [] |
|
|
GLOBAL_STATE.kick_debug_threshold = None |
|
|
GLOBAL_STATE.kick_debug_baseline = None |
|
|
GLOBAL_STATE.kick_debug_speed_std = None |
|
|
GLOBAL_STATE.kick_debug_area = [] |
|
|
GLOBAL_STATE.kick_debug_kick_frame = None |
|
|
GLOBAL_STATE.kick_debug_distance = [] |
|
|
GLOBAL_STATE.kick_debug_kalman_speeds = [] |
|
|
GLOBAL_STATE.impact_frame = None |
|
|
GLOBAL_STATE.impact_debug_frames = [] |
|
|
GLOBAL_STATE.impact_debug_innovation = [] |
|
|
GLOBAL_STATE.impact_debug_innovation_threshold = None |
|
|
GLOBAL_STATE.impact_debug_direction = [] |
|
|
GLOBAL_STATE.impact_debug_direction_threshold = None |
|
|
GLOBAL_STATE.impact_debug_speed_kmh = [] |
|
|
GLOBAL_STATE.impact_debug_speed_threshold_px = None |
|
|
GLOBAL_STATE.impact_meters_per_px = None |
|
|
|
|
|
|
|
|
try: |
|
|
if GLOBAL_STATE.inference_session is not None: |
|
|
GLOBAL_STATE.inference_session.reset_inference_session() |
|
|
except Exception: |
|
|
pass |
|
|
GLOBAL_STATE.inference_session = None |
|
|
gc.collect() |
|
|
ensure_session_for_current_model(GLOBAL_STATE) |
|
|
|
|
|
|
|
|
current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0)) |
|
|
current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1)) |
|
|
preview_img = update_frame_display(GLOBAL_STATE, current_idx) |
|
|
slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True) |
|
|
slider_value = gr.update(value=current_idx) |
|
|
status = "Session reset. Prompts cleared; video preserved." |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
|
|
|
return ( |
|
|
GLOBAL_STATE, |
|
|
preview_img, |
|
|
slider_minmax, |
|
|
slider_value, |
|
|
status, |
|
|
gr.update(visible=False, value=""), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
|
|
|
def create_annotation_preview(video_file, annotations): |
|
|
""" |
|
|
Create a preview image showing annotation points on video frames. |
|
|
|
|
|
Args: |
|
|
video_file: Path to video file |
|
|
annotations: List of annotation dicts |
|
|
|
|
|
Returns: |
|
|
PIL Image with annotations visualized |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
cap = cv2.VideoCapture(video_file) |
|
|
if not cap.isOpened(): |
|
|
return None |
|
|
|
|
|
|
|
|
frames_to_show = {} |
|
|
for ann in annotations: |
|
|
frame_idx = ann.get("frame", 0) |
|
|
if frame_idx not in frames_to_show: |
|
|
frames_to_show[frame_idx] = [] |
|
|
frames_to_show[frame_idx].append(ann) |
|
|
|
|
|
|
|
|
annotated_frames = [] |
|
|
for frame_idx in sorted(frames_to_show.keys())[:3]: |
|
|
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) |
|
|
ret, frame = cap.read() |
|
|
if not ret: |
|
|
continue |
|
|
|
|
|
|
|
|
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
|
|
pil_img = Image.fromarray(frame_rgb) |
|
|
draw = ImageDraw.Draw(pil_img) |
|
|
|
|
|
|
|
|
for ann in frames_to_show[frame_idx]: |
|
|
x, y = ann.get("x", 0), ann.get("y", 0) |
|
|
obj_id = ann.get("object_id", 1) |
|
|
label = ann.get("label", "positive") |
|
|
|
|
|
|
|
|
color = pastel_color_for_object(obj_id) |
|
|
|
|
|
|
|
|
size = 20 |
|
|
draw.line([(x-size, y), (x+size, y)], fill=color, width=3) |
|
|
draw.line([(x, y-size), (x, y+size)], fill=color, width=3) |
|
|
draw.ellipse([(x-10, y-10), (x+10, y+10)], outline=color, width=3) |
|
|
|
|
|
|
|
|
text = f"Obj{obj_id} F{frame_idx}" |
|
|
draw.text((x+15, y-15), text, fill=color) |
|
|
|
|
|
|
|
|
draw.text((10, 10), f"Frame {frame_idx}", fill=(255, 255, 255)) |
|
|
|
|
|
annotated_frames.append(pil_img) |
|
|
|
|
|
cap.release() |
|
|
|
|
|
|
|
|
if not annotated_frames: |
|
|
return None |
|
|
|
|
|
total_width = sum(img.width for img in annotated_frames) |
|
|
max_height = max(img.height for img in annotated_frames) |
|
|
|
|
|
combined = Image.new('RGB', (total_width, max_height)) |
|
|
x_offset = 0 |
|
|
for img in annotated_frames: |
|
|
combined.paste(img, (x_offset, 0)) |
|
|
x_offset += img.width |
|
|
|
|
|
return combined |
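

# Example usage (hypothetical paths and coordinates):
#     preview = create_annotation_preview(
#         "clip.mp4",
#         [{"object_id": 1, "frame": 0, "x": 363, "y": 631, "label": "positive"}],
#     )
#     if preview is not None:
#         preview.save("annotations_preview.png")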
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120) |
|
|
def process_video_api( |
|
|
video_file, |
|
|
annotations_json_str: str, |
|
|
checkpoint: str = "base_plus", |
|
|
remove_background: bool = True |
|
|
): |
|
|
""" |
|
|
Single-endpoint API for programmatic video processing. |
|
|
|
|
|
Args: |
|
|
video_file: Uploaded video file |
|
|
annotations_json_str: JSON string with format: |
|
|
{ |
|
|
"annotations": [ |
|
|
{"object_id": 1, "frame": 139, "x": 369, "y": 652, "label": "positive"}, |
|
|
{"object_id": 1, "frame": 156, "x": 374, "y": 513, "label": "positive"}, |
|
|
{"object_id": 2, "frame": 156, "x": 374, "y": 257, "label": "positive"} |
|
|
] |
|
|
} |
|
|
checkpoint: SAM2 model checkpoint (tiny, small, base_plus, large) |
|
|
remove_background: Whether to remove background (default: True) |
|
|
|
|
|
Returns: |
|
|
Tuple of (preview_image, processed_video_path) |
|
|
""" |
|
|
import json |
|
|
|
|
|
try: |
|
|
|
|
|
annotations_data = json.loads(annotations_json_str) |
|
|
annotations = annotations_data.get("annotations", []) |
|
|
client_fps = annotations_data.get("fps", None) |
|
|
|
|
|
print(f"[API] Processing video with {len(annotations)} annotations") |
|
|
print(f"[API] Client FPS: {client_fps}") |
|
|
print(f"[API] Checkpoint: {checkpoint}") |
|
|
print(f"[API] Remove background: {remove_background}") |
|
|
|
|
|
|
|
|
preview_img = create_annotation_preview(video_file, annotations) |
|
|
|
|
|
|
|
|
api_state = AppState() |
|
|
api_state.model_repo_key = checkpoint |
|
|
|
|
|
|
|
|
api_state, min_idx, max_idx, first_frame, status = init_video_session(api_state, video_file) |
|
|
space_fps = api_state.video_fps |
|
|
print(f"[API] Video loaded: {status}") |
|
|
print(f"[API] ⚠️ FPS mismatch check: Client={client_fps}, Space={space_fps}") |
|
|
|
|
|
|
|
|
if client_fps and space_fps and abs(client_fps - space_fps) > 0.5: |
|
|
offset_estimate = abs(int((client_fps - space_fps) * (api_state.num_frames / client_fps))) |
|
|
print(f"[API] ⚠️ FPS mismatch detected! Frame indices may be off by ~{offset_estimate} frames") |
|
|
print(f"[API] ℹ️ Recommendation: Use timestamps instead of frame indices for accuracy") |
|
|
|
|
|
|
|
|
for i, ann in enumerate(annotations): |
|
|
object_id = ann.get("object_id", 1) |
|
|
timestamp_ms = ann.get("timestamp_ms", None) |
|
|
frame_idx = ann.get("frame", None) |
|
|
x = ann.get("x", 0) |
|
|
y = ann.get("y", 0) |
|
|
label = ann.get("label", "positive") |
|
|
|
|
|
|
|
|
if timestamp_ms is not None and space_fps and space_fps > 0: |
|
|
calculated_frame = int((timestamp_ms / 1000.0) * space_fps) |
|
|
if frame_idx is not None and calculated_frame != frame_idx: |
|
|
print(f"[API] ✅ Using timestamp: {timestamp_ms}ms → Frame {calculated_frame} (client sent frame {frame_idx})") |
|
|
else: |
|
|
print(f"[API] ✅ Calculated frame from timestamp: {timestamp_ms}ms → Frame {calculated_frame}") |
|
|
frame_idx = calculated_frame |
|
|
elif frame_idx is None: |
|
|
print(f"[API] ⚠️ Warning: No timestamp or frame provided, using frame 0") |
|
|
frame_idx = 0 |
|
|
|
|
|
print(f"[API] Adding annotation {i+1}/{len(annotations)}: " |
|
|
f"Object {object_id}, Frame {frame_idx}, ({x}, {y}), {label}") |
|
|
|
|
|
|
|
|
api_state.current_frame_idx = int(frame_idx) |
|
|
api_state.current_obj_id = int(object_id) |
|
|
api_state.current_label = str(label) |
|
|
|
|
|
|
|
|
# Minimal stand-in for gr.SelectData carrying the click coordinates.
class MockEvent:
|
|
def __init__(self, x, y): |
|
|
self.index = (x, y) |
|
|
|
|
|
mock_evt = MockEvent(x, y) |
|
|
|
|
|
|
|
|
preview_img = on_image_click( |
|
|
first_frame, |
|
|
api_state, |
|
|
frame_idx, |
|
|
object_id, |
|
|
label, |
|
|
clear_old=False, |
|
|
evt=mock_evt |
|
|
) |
|
|
|
|
|
|
|
|
print("[API] Propagating masks across video...") |
|
|
|
|
|
for outputs in propagate_masks(api_state): |
|
|
if not outputs: |
|
|
continue |
|
|
api_state = outputs[0] |
|
|
status_msg = outputs[1] if len(outputs) > 1 else "" |
|
|
if status_msg: |
|
|
print(f"[API] Progress: {status_msg}") |
|
|
|
|
|
|
|
|
print(f"[API] Rendering video with remove_background={remove_background}...") |
|
|
result_video_path = _render_video(api_state, remove_background) |
|
|
|
|
|
print(f"[API] ✅ Processing complete: {result_video_path}") |
|
|
return preview_img, result_video_path |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[API] ❌ Error: {str(e)}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
raise gr.Error(f"Processing failed: {str(e)}") |
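

# Timestamp-to-frame mapping used above, for reference:
#     frame_idx = int((timestamp_ms / 1000.0) * space_fps)
# e.g. an annotation at 1500 ms in a 30 fps video maps to frame 45, regardless
# of the frame rate the client assumed.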
|
|
|
|
|
|
|
|
theme = Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate") |
|
|
CUSTOM_CSS = "" |
|
|
|
|
|
with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme, css=CUSTOM_CSS) as demo: |
|
|
GLOBAL_STATE = gr.State(AppState()) |
|
|
|
|
|
gr.Markdown( |
|
|
""" |
|
|
### SAM2 Video Tracking · powered by Hugging Face 🤗 Transformers |
|
|
Segment and track objects across a video with SAM2 (Segment Anything 2). This demo runs the official implementation from the Hugging Face Transformers library for interactive, promptable video segmentation. |
|
|
""" |
|
|
) |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
gr.Markdown( |
|
|
""" |
|
|
**Quick start** |
|
|
- **Load a video**: Upload your own or pick an example below. |
|
|
- **Checkpoint**: Tiny / Small / Base+ / Large (trade speed vs. accuracy). |
|
|
- **Points mode**: Select an Object ID and point label (positive/negative), then click the frame to add guidance. You can add **multiple points per object** and define **multiple objects** across frames. |
|
|
- **Boxes mode**: Click two opposite corners to draw a box. Old inputs for that object are cleared automatically. |
|
|
""" |
|
|
) |
|
|
with gr.Column(): |
|
|
gr.Markdown( |
|
|
""" |
|
|
**Working with results** |
|
|
- **Preview**: Use the slider to navigate frames and see the current masks. |
|
|
- **Track**: Click “Track ball (SAM2)” to track all defined objects across the selected window. The preview updates periodically during tracking to keep the UI responsive.
|
|
- **Export**: Render an MP4 for smooth playback using the original video FPS. |
|
|
- **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video). |
|
|
""" |
|
|
) |
|
|
|
|
|
with gr.Row(equal_height=True): |
|
|
with gr.Column(scale=1): |
|
|
video_in = gr.Video( |
|
|
label="Upload video", |
|
|
sources=["upload", "webcam"], |
|
|
interactive=True, |
|
|
elem_id="video-pane", |
|
|
) |
|
|
ckpt_radio = gr.Radio( |
|
|
choices=["tiny", "small", "base_plus", "large"], |
|
|
value="tiny", |
|
|
label="SAM2.1 checkpoint", |
|
|
) |
|
|
ckpt_progress = gr.Markdown(visible=False) |
|
|
load_status = gr.Markdown(visible=True) |
|
|
reset_btn = gr.Button("Reset Session", variant="secondary") |
|
|
with gr.Column(scale=1): |
|
|
gr.Markdown("**Preview**") |
|
|
preview = gr.Image( |
|
|
interactive=True, |
|
|
elem_id="preview-pane", |
|
|
container=False, |
|
|
show_label=False, |
|
|
) |
|
|
frame_slider = gr.Slider( |
|
|
label="Frame", |
|
|
minimum=0, |
|
|
maximum=0, |
|
|
step=1, |
|
|
value=0, |
|
|
interactive=True, |
|
|
elem_id="frame-slider", |
|
|
) |
|
|
with gr.Row(): |
|
|
min_impact_speed_slider = gr.Slider( |
|
|
label="Min impact speed (km/h)", |
|
|
minimum=0, |
|
|
maximum=120, |
|
|
step=1, |
|
|
value=20, |
|
|
interactive=True, |
|
|
) |
|
|
goal_distance_slider = gr.Slider( |
|
|
label="Distance to goal (m)", |
|
|
minimum=1, |
|
|
maximum=60, |
|
|
step=0.5, |
|
|
value=18, |
|
|
interactive=True, |
|
|
) |
|
|
with gr.Row(): |
|
|
detect_ball_btn = gr.Button("Detect Ball", variant="secondary") |
|
|
track_ball_yolo_btn = gr.Button("Track ball (YOLO13)", variant="secondary") |
|
|
propagate_btn = gr.Button("Track ball (SAM2)", variant="primary", interactive=False) |
|
|
detect_player_btn = gr.Button("Detect Player", variant="secondary", interactive=False) |
|
|
propagate_player_btn = gr.Button("Propagate Player", variant="primary", interactive=False) |
|
|
ball_status = gr.Markdown(visible=False) |
|
|
propagate_status = gr.Markdown(visible=True) |
|
|
impact_status = gr.Markdown("Impact frame: not computed") |
|
|
with gr.Row(): |
|
|
obj_id_inp = gr.Number(value=1, precision=0, label="Object ID", scale=0) |
|
|
label_radio = gr.Radio(choices=["positive", "negative"], value="positive", label="Point label") |
|
|
clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object") |
|
|
prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type") |
|
|
kick_plot = gr.Plot(label="Kick & impact diagnostics", show_label=True) |
|
|
yolo_plot = gr.Plot(label="YOLO kick diagnostics", show_label=True) |
|
|
|
|
|
|
|
|
def _on_video_change(GLOBAL_STATE: gr.State, video): |
|
|
GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video) |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
return ( |
|
|
GLOBAL_STATE, |
|
|
gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True), |
|
|
first_frame, |
|
|
status, |
|
|
gr.update(visible=False, value=""), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
video_in.change( |
|
|
_on_video_change, |
|
|
inputs=[GLOBAL_STATE, video_in], |
|
|
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn], |
|
|
show_progress=True, |
|
|
) |
|
|
|
|
|
example_video_path = ensure_example_video() |
|
|
examples_list = [ |
|
|
[None, example_video_path], |
|
|
] |
|
|
with gr.Row(): |
|
|
gr.Examples( |
|
|
examples=examples_list, |
|
|
inputs=[GLOBAL_STATE, video_in], |
|
|
fn=_on_video_change, |
|
|
outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn], |
|
|
label="Examples", |
|
|
cache_examples=False, |
|
|
examples_per_page=5, |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
remove_bg_checkbox = gr.Checkbox( |
|
|
label="Remove Background", |
|
|
value=False, |
|
|
info="If checked, shows only tracked objects on black background. If unchecked, overlays colored masks on original video." |
|
|
) |
|
|
with gr.Row(): |
|
|
render_btn = gr.Button("Render MP4 for smooth playback", variant="primary") |
|
|
playback_video = gr.Video(label="Rendered Playback", interactive=False) |
|
|
|
|
|
def _on_ckpt_change(s: AppState, key: str): |
|
|
if s is not None and key: |
|
|
key = str(key) |
|
|
if key != s.model_repo_key: |
|
|
|
|
|
s.is_switching_model = True |
|
|
s.model_repo_key = key |
|
|
s.model_repo_id = None |
|
|
s.model = None |
|
|
s.processor = None |
|
|
|
|
|
yield gr.update(visible=True, value=f"Loading checkpoint: {key}...") |
|
|
ensure_session_for_current_model(s) |
|
|
if s is not None: |
|
|
s.is_switching_model = False |
|
|
|
|
|
yield gr.update(visible=False, value="") |
|
|
|
|
|
ckpt_radio.change(_on_ckpt_change, inputs=[GLOBAL_STATE, ckpt_radio], outputs=[ckpt_progress]) |
|
|
|
|
|
def _sync_frame_idx(state_in: AppState, idx: int): |
|
|
if state_in is not None: |
|
|
state_in.current_frame_idx = int(idx) |
|
|
return update_frame_display(state_in, int(idx)) |
|
|
|
|
|
frame_slider.change( |
|
|
_sync_frame_idx, |
|
|
inputs=[GLOBAL_STATE, frame_slider], |
|
|
outputs=preview, |
|
|
) |
|
|
|
|
|
def _sync_obj_id(s: AppState, oid): |
|
|
if s is not None and oid is not None: |
|
|
s.current_obj_id = int(oid) |
|
|
return gr.update() |
|
|
|
|
|
obj_id_inp.change(_sync_obj_id, inputs=[GLOBAL_STATE, obj_id_inp], outputs=[]) |
|
|
|
|
|
def _sync_label(s: AppState, lab: str): |
|
|
if s is not None and lab is not None: |
|
|
s.current_label = str(lab) |
|
|
return gr.update() |
|
|
|
|
|
label_radio.change(_sync_label, inputs=[GLOBAL_STATE, label_radio], outputs=[]) |
|
|
|
|
|
def _sync_prompt_type(s: AppState, val: str): |
|
|
if s is not None and val is not None: |
|
|
s.current_prompt_type = str(val) |
|
|
s.pending_box_start = None |
|
|
is_points = str(val).lower() == "points" |
|
|
|
|
|
updates = [ |
|
|
gr.update(visible=is_points), |
|
|
gr.update(interactive=is_points) if is_points else gr.update(value=True, interactive=False), |
|
|
] |
|
|
return updates |
|
|
|
|
|
prompt_type.change( |
|
|
_sync_prompt_type, |
|
|
inputs=[GLOBAL_STATE, prompt_type], |
|
|
outputs=[label_radio, clear_old_chk], |
|
|
) |
|
|
|
|
|
def _update_min_impact_speed(s: AppState, val: float): |
|
|
if s is not None and val is not None: |
|
|
s.min_impact_speed_kmh = float(val) |
|
|
_recompute_motion_metrics(s) |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(s) |
|
|
return ( |
|
|
_build_kick_plot(s), |
|
|
_format_impact_status(s), |
|
|
gr.update(value=_format_kick_status(s), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
def _update_goal_distance(s: AppState, val: float): |
|
|
if s is not None and val is not None: |
|
|
s.goal_distance_m = float(val) |
|
|
_recompute_motion_metrics(s) |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(s) |
|
|
return ( |
|
|
_build_kick_plot(s), |
|
|
_format_impact_status(s), |
|
|
gr.update(value=_format_kick_status(s), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
min_impact_speed_slider.change( |
|
|
_update_min_impact_speed, |
|
|
inputs=[GLOBAL_STATE, min_impact_speed_slider], |
|
|
outputs=[kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn], |
|
|
) |
|
|
|
|
|
goal_distance_slider.change( |
|
|
_update_goal_distance, |
|
|
inputs=[GLOBAL_STATE, goal_distance_slider], |
|
|
outputs=[kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn], |
|
|
) |
|
|
|
|
|
def _auto_detect_ball( |
|
|
state_in: AppState, |
|
|
obj_id, |
|
|
label_value: str, |
|
|
clear_old_value: bool, |
|
|
): |
|
|
if state_in is None or state_in.num_frames == 0: |
|
|
raise gr.Error("Load a video first, then try auto-detect.") |
|
|
|
|
|
frame_idx = 0 |
|
|
frame = state_in.video_frames[frame_idx] |
|
|
detection = detect_ball_center(frame) |
|
|
if detection is None: |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in) |
|
|
return ( |
|
|
update_frame_display(state_in, frame_idx), |
|
|
gr.update( |
|
|
value="❌ Unable to auto-detect the ball. Please add a point manually.", |
|
|
visible=True, |
|
|
), |
|
|
gr.update(value=frame_idx), |
|
|
_build_kick_plot(state_in), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
x_center, y_center, _, _, conf = detection |
|
|
frame_width, frame_height = frame.size |
|
|
x_center = max(0, min(frame_width - 1, int(x_center))) |
|
|
y_center = max(0, min(frame_height - 1, int(y_center))) |
|
|
obj_id_int = int(obj_id) if obj_id is not None else state_in.current_obj_id |
|
|
label_str = label_value if label_value else state_in.current_label |
|
|
clear_old_flag = bool(clear_old_value) |
|
|
|
|
|
|
|
|
synthetic_evt = SimpleNamespace( |
|
|
index=(x_center, y_center), |
|
|
value={"x": x_center, "y": y_center}, |
|
|
) |
|
|
|
|
|
state_in.current_frame_idx = frame_idx |
|
|
preview_img = on_image_click( |
|
|
update_frame_display(state_in, frame_idx), |
|
|
state_in, |
|
|
frame_idx, |
|
|
obj_id_int, |
|
|
label_str, |
|
|
clear_old_flag, |
|
|
synthetic_evt, |
|
|
) |
|
|
|
|
|
status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})" |
|
|
status_text += f" | {_format_kick_status(state_in)}" |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in) |
|
|
return ( |
|
|
preview_img, |
|
|
gr.update(value=status_text, visible=True), |
|
|
gr.update(value=frame_idx), |
|
|
_build_kick_plot(state_in), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
detect_ball_btn.click( |
|
|
_auto_detect_ball, |
|
|
inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk], |
|
|
outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn], |
|
|
) |
|
|
|
|
|
def _track_ball_yolo(state_in: AppState): |
|
|
if state_in is None or state_in.num_frames == 0: |
|
|
raise gr.Error("Load a video first, then track the ball with YOLO.") |
|
|
progress = gr.Progress(track_tqdm=False) |
|
|
_perform_yolo_ball_tracking(state_in, progress=progress) |
|
|
target_frame = (


state_in.yolo_kick_frame


if state_in.yolo_kick_frame is not None


else (state_in.yolo_initial_frame if state_in.yolo_initial_frame is not None else 0)


)
|
|
if state_in.num_frames: |
|
|
target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1)) |
|
|
state_in.current_frame_idx = target_frame |
|
|
preview_img = update_frame_display(state_in, target_frame) |
|
|
base_msg = state_in.yolo_status or "" |
|
|
kick_msg = _format_kick_status(state_in) |
|
|
status_text = f"{base_msg} | {kick_msg}" if base_msg else kick_msg |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in) |
|
|
return ( |
|
|
preview_img, |
|
|
gr.update(value=status_text, visible=True), |
|
|
gr.update(value=target_frame), |
|
|
_build_kick_plot(state_in), |
|
|
_build_yolo_plot(state_in), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
track_ball_yolo_btn.click( |
|
|
_track_ball_yolo, |
|
|
inputs=[GLOBAL_STATE], |
|
|
outputs=[preview, ball_status, frame_slider, kick_plot, yolo_plot, propagate_btn, detect_player_btn, propagate_player_btn], |
|
|
) |
|
|
|
|
|
def _auto_detect_player(state_in: AppState): |
|
|
if state_in is None or state_in.num_frames == 0: |
|
|
raise gr.Error("Load a video first, then try auto-detect.") |
|
|
if state_in.inference_session is None or state_in.processor is None or state_in.model is None: |
|
|
raise gr.Error("Model session is not ready. Load a video and propagate masks first.") |
|
|
|
|
|
kick_frame = state_in.kick_frame or getattr(state_in, "kick_debug_kick_frame", None) |
|
|
if kick_frame is None: |
|
|
raise gr.Error("Detect the kick frame first by propagating the ball masks.") |
|
|
|
|
|
frame_idx = int(np.clip(int(kick_frame), 0, state_in.num_frames - 1)) |
|
|
frame = state_in.video_frames[frame_idx] |
|
|
detection = detect_person_box(frame) |
|
|
if detection is None: |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in) |
|
|
status_text = ( |
|
|
f"{_format_kick_status(state_in)} | ⚠️ Unable to auto-detect the player on frame {frame_idx}. " |
|
|
"Please add a box manually." |
|
|
) |
|
|
return ( |
|
|
update_frame_display(state_in, frame_idx), |
|
|
gr.update(value=status_text, visible=True), |
|
|
gr.update(value=frame_idx), |
|
|
_build_kick_plot(state_in), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
gr.update(), |
|
|
) |
|
|
|
|
|
x_min, y_min, x_max, y_max, conf = detection |
|
|
state_in.player_obj_id = PLAYER_OBJECT_ID |
|
|
state_in.player_detection_frame = frame_idx |
|
|
state_in.player_detection_conf = conf |
|
|
state_in.current_obj_id = PLAYER_OBJECT_ID |
|
|
|
|
|
|
|
|
for frame_boxes in state_in.boxes_by_frame_obj.values(): |
|
|
frame_boxes.pop(PLAYER_OBJECT_ID, None) |
|
|
for frame_clicks in state_in.clicks_by_frame_obj.values(): |
|
|
frame_clicks.pop(PLAYER_OBJECT_ID, None) |
|
|
for frame_masks in state_in.masks_by_frame.values(): |
|
|
frame_masks.pop(PLAYER_OBJECT_ID, None) |
|
|
|
|
|
_ensure_color_for_obj(state_in, PLAYER_OBJECT_ID) |
|
|
|
|
|
processor = state_in.processor |
|
|
model = state_in.model |
|
|
inference_session = state_in.inference_session |
|
|
|
|
|
inputs = processor(images=frame, device=state_in.device, return_tensors="pt") |
|
|
original_size = inputs.original_sizes[0] |
|
|
pixel_values = inputs.pixel_values[0] |
|
|
|
|
|
processor.add_inputs_to_inference_session( |
|
|
inference_session=inference_session, |
|
|
frame_idx=frame_idx, |
|
|
obj_ids=PLAYER_OBJECT_ID, |
|
|
input_boxes=[[[x_min, y_min, x_max, y_max]]], |
|
|
clear_old_inputs=True, |
|
|
original_size=original_size, |
|
|
) |
|
|
|
|
|
frame_boxes = state_in.boxes_by_frame_obj.setdefault(frame_idx, {}) |
|
|
frame_boxes[PLAYER_OBJECT_ID] = [(x_min, y_min, x_max, y_max)] |
|
|
|
|
|
state_in.composited_frames.pop(frame_idx, None) |
|
|
|
|
|
with torch.inference_mode(): |
|
|
outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx) |
|
|
|
|
|
H = inference_session.video_height |
|
|
W = inference_session.video_width |
|
|
pred_masks = outputs.pred_masks.detach().cpu() |
|
|
video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0] |
|
|
|
|
|
masks_for_frame = state_in.masks_by_frame.get(frame_idx, {}).copy() |
|
|
obj_ids_order = list(inference_session.obj_ids) |
|
|
for i, oid in enumerate(obj_ids_order): |
|
|
mask_i = video_res_masks[i].cpu().numpy().squeeze() |
|
|
masks_for_frame[int(oid)] = mask_i |
|
|
state_in.masks_by_frame[frame_idx] = masks_for_frame |
|
|
_update_centroids_for_frame(state_in, frame_idx) |
|
|
state_in.composited_frames.pop(frame_idx, None) |
|
|
state_in.current_frame_idx = frame_idx |
|
|
|
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in) |
|
|
status_text = ( |
|
|
f"{_format_kick_status(state_in)} | ✅ Player auto-detected on frame {frame_idx} (conf={conf:.2f})" |
|
|
) |
|
|
return ( |
|
|
update_frame_display(state_in, frame_idx), |
|
|
gr.update(value=status_text, visible=True), |
|
|
gr.update(value=frame_idx), |
|
|
_build_kick_plot(state_in), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
gr.update(value=PLAYER_OBJECT_ID), |
|
|
) |
|
|
|
|
|
detect_player_btn.click( |
|
|
_auto_detect_player, |
|
|
inputs=[GLOBAL_STATE], |
|
|
outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn, obj_id_inp], |
|
|
) |
|
|
|
|
|
@spaces.GPU() |
|
|
def propagate_player_masks(GLOBAL_STATE: gr.State): |
|
|
if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None: |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
yield (
|
|
GLOBAL_STATE, |
|
|
"Load a video first.", |
|
|
gr.update(), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update,


)


return
|
|
if GLOBAL_STATE.player_obj_id is None or not _player_has_masks(GLOBAL_STATE): |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
yield (
|
|
GLOBAL_STATE, |
|
|
"Detect the player before propagating.", |
|
|
gr.update(), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update,


)


return
|
|
|
|
|
processor = deepcopy(GLOBAL_STATE.processor) |
|
|
model = deepcopy(GLOBAL_STATE.model) |
|
|
inference_session = deepcopy(GLOBAL_STATE.inference_session) |
|
|
inference_session.inference_device = "cuda" |
|
|
inference_session.cache.inference_device = "cuda" |
|
|
model.to("cuda") |
|
|
|
|
|
if not GLOBAL_STATE.sam_window: |
|
|
_compute_sam_window_from_kick( |
|
|
GLOBAL_STATE, |
|
|
GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None), |
|
|
) |
|
|
start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames) |
|
|
start_idx = max(0, int(start_idx)) |
|
|
end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx))) |
|
|
total = max(1, end_idx - start_idx) |
|
|
processed = 0 |
|
|
last_frame_idx = start_idx |
|
|
|
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
yield ( |
|
|
GLOBAL_STATE, |
|
|
f"Propagating player: {processed}/{total}", |
|
|
gr.update(), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
player_id = GLOBAL_STATE.player_obj_id or PLAYER_OBJECT_ID |
|
|
|
|
|
with torch.inference_mode(): |
|
|
for frame_idx in range(start_idx, end_idx): |
|
|
frame = GLOBAL_STATE.video_frames[frame_idx] |
|
|
pixel_values = None |
|
|
if ( |
|
|
inference_session.processed_frames is None |
|
|
or frame_idx not in inference_session.processed_frames |
|
|
): |
|
|
pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0] |
|
|
sam2_video_output = model( |
|
|
inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx |
|
|
) |
|
|
H = inference_session.video_height |
|
|
W = inference_session.video_width |
|
|
pred_masks = sam2_video_output.pred_masks.detach().cpu() |
|
|
video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0] |
|
|
|
|
|
masks_for_frame = GLOBAL_STATE.masks_by_frame.get(frame_idx, {}).copy() |
|
|
obj_ids_order = list(inference_session.obj_ids) |
|
|
for i, oid in enumerate(obj_ids_order): |
|
|
mask_2d = video_res_masks[i].cpu().numpy().squeeze() |
|
|
if int(oid) == int(player_id): |
|
|
masks_for_frame[int(player_id)] = mask_2d |
|
|
GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame |
|
|
_update_centroids_for_frame(GLOBAL_STATE, frame_idx) |
|
|
GLOBAL_STATE.composited_frames.pop(frame_idx, None) |
|
|
|
|
|
processed += 1 |
|
|
last_frame_idx = frame_idx |
|
|
if processed % 30 == 0 or processed == total: |
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
yield ( |
|
|
GLOBAL_STATE, |
|
|
f"Propagating player: {processed}/{total}", |
|
|
gr.update(value=frame_idx), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
text = f"Propagated player across {processed} frames." |
|
|
target_frame = GLOBAL_STATE.player_detection_frame |
|
|
if target_frame is None: |
|
|
target_frame = GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None) |
|
|
if target_frame is None: |
|
|
target_frame = last_frame_idx |
|
|
target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1))) |
|
|
GLOBAL_STATE.current_frame_idx = target_frame |
|
|
|
|
|
propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) |
|
|
yield ( |
|
|
GLOBAL_STATE, |
|
|
text, |
|
|
gr.update(value=target_frame), |
|
|
_build_kick_plot(GLOBAL_STATE), |
|
|
_build_yolo_plot(GLOBAL_STATE), |
|
|
_format_impact_status(GLOBAL_STATE), |
|
|
gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), |
|
|
propagate_main_update, |
|
|
detect_btn_update, |
|
|
propagate_player_update, |
|
|
) |
|
|
|
|
|
propagate_player_btn.click( |
|
|
propagate_player_masks, |
|
|
inputs=[GLOBAL_STATE], |
|
|
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, yolo_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn], |
|
|
) |
|
|
|
|
|
|
|
|
preview.select( |
|
|
_on_image_click_with_updates, |
|
|
[preview, GLOBAL_STATE, frame_slider, obj_id_inp, label_radio, clear_old_chk], |
|
|
[preview, propagate_btn, detect_player_btn, propagate_player_btn], |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _render_video(s: AppState, remove_bg: bool = False): |
|
|
if s is None or s.num_frames == 0: |
|
|
raise gr.Error("Load a video first.") |
|
|
fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12 |
|
|
trim_duration_sec = 4.0 |
|
|
target_window_frames = max(1, int(round(fps * trim_duration_sec))) |
|
|
half_window = target_window_frames // 2 |
|
|
kick_frame = s.kick_frame or getattr(s, "kick_debug_kick_frame", None) |
|
|
start_idx = 0 |
|
|
end_idx = min(s.num_frames, target_window_frames) |
|
|
if kick_frame is not None: |
|
|
start_idx = max(0, int(kick_frame) - half_window) |
|
|
end_idx = start_idx + target_window_frames |
|
|
if end_idx > s.num_frames: |
|
|
end_idx = s.num_frames |
|
|
start_idx = max(0, end_idx - target_window_frames) |
|
|
else: |
|
|
end_idx = min(s.num_frames, start_idx + target_window_frames) |
|
|
if end_idx <= start_idx: |
|
|
end_idx = min(s.num_frames, start_idx + 1) |
|
|
|
|
|
frames_np = [] |
|
|
first = compose_frame(s, start_idx, remove_bg=remove_bg) |
|
|
h, w = first.size[1], first.size[0] |
|
|
for idx in range(start_idx, end_idx): |
|
|
|
|
|
if remove_bg: |
|
|
img = compose_frame(s, idx, remove_bg=True) |
|
|
else: |
|
|
img = s.composited_frames.get(idx) |
|
|
if img is None: |
|
|
img = compose_frame(s, idx, remove_bg=False) |
|
|
img_with_idx = _annotate_frame_index(img, idx) |
|
|
frames_np.append(np.array(img_with_idx)[:, :, ::-1]) |
|
|
|
|
|
if (idx + 1) % 60 == 0: |
|
|
gc.collect() |
|
|
out_path = "/tmp/sam2_playback.mp4" |
|
|
|
|
|
try: |
|
|
fourcc = cv2.VideoWriter_fourcc(*"mp4v") |
|
|
writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h)) |
|
|
for fr_bgr in frames_np: |
|
|
writer.write(fr_bgr) |
|
|
writer.release() |
|
|
return out_path |
|
|
except Exception as e: |
|
|
print(f"Failed to render video with cv2: {e}") |
|
|
raise gr.Error(f"Failed to render video: {e}") |
|
|
|
|
|
render_btn.click(_render_video, inputs=[GLOBAL_STATE, remove_bg_checkbox], outputs=[playback_video]) |
|
|
|
|
|
|
|
|
propagate_btn.click( |
|
|
propagate_masks, |
|
|
inputs=[GLOBAL_STATE], |
|
|
outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, yolo_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn], |
|
|
) |
|
|
|
|
|
reset_btn.click( |
|
|
reset_session, |
|
|
inputs=GLOBAL_STATE, |
|
|
outputs=[GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn], |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api_interface = gr.Interface( |
|
|
fn=process_video_api, |
|
|
inputs=[ |
|
|
gr.Video(label="Video File"), |
|
|
gr.Textbox( |
|
|
label="Annotations JSON", |
|
|
placeholder='{"annotations": [{"object_id": 1, "frame": 139, "x": 369, "y": 652, "label": "positive"}]}', |
|
|
lines=5 |
|
|
), |
|
|
gr.Radio( |
|
|
choices=["tiny", "small", "base_plus", "large"], |
|
|
value="base_plus", |
|
|
label="SAM2 Checkpoint" |
|
|
), |
|
|
gr.Checkbox(label="Remove Background", value=True) |
|
|
], |
|
|
outputs=[ |
|
|
gr.Image(label="Annotation Preview (shows where points are placed)"), |
|
|
gr.Video(label="Processed Video") |
|
|
], |
|
|
title="SAM2 API", |
|
|
description=""" |
|
|
## Programmatic API for Video Background Removal |
|
|
|
|
|
**The preview image shows where your annotation points are placed on the video frames.** |
|
|
|
|
|
**Annotations JSON Format:** |
|
|
```json |
|
|
{ |
|
|
"annotations": [ |
|
|
{"object_id": 1, "frame": 0, "x": 363, "y": 631, "label": "positive"}, |
|
|
{"object_id": 1, "frame": 187, "x": 296, "y": 485, "label": "positive"}, |
|
|
{"object_id": 2, "frame": 187, "x": 296, "y": 412, "label": "positive"} |
|
|
] |
|
|
} |
|
|
``` |
|
|
|
|
|
- **Object 1** (Ball): Frame 0 + Impact frame |
|
|
- **Object 2** (Player): Impact frame |
|
|
- Colors represent different objects |
|
|
""" |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Blocks(title="SAM2 Video Tracking") as combined_demo: |
|
|
gr.Markdown("# SAM2 Video Tracking") |
|
|
|
|
|
with gr.Tabs(): |
|
|
with gr.TabItem("Interactive UI"): |
|
|
demo.render() |
|
|
|
|
|
with gr.TabItem("API"): |
|
|
api_interface.render() |
|
|
|
|
|
|
|
|
|
|
|
api_video_input_hidden = gr.Video(visible=False) |
|
|
api_annotations_input_hidden = gr.Textbox(visible=False) |
|
|
api_checkpoint_input_hidden = gr.Radio(choices=["tiny", "small", "base_plus", "large"], visible=False) |
|
|
api_remove_bg_input_hidden = gr.Checkbox(visible=False) |
|
|
api_preview_output_hidden = gr.Image(visible=False) |
|
|
api_video_output_hidden = gr.Video(visible=False) |
|
|
|
|
|
|
|
|
api_dummy_btn = gr.Button("API", visible=False) |
|
|
api_dummy_btn.click( |
|
|
fn=process_video_api, |
|
|
inputs=[api_video_input_hidden, api_annotations_input_hidden, api_checkpoint_input_hidden, api_remove_bg_input_hidden], |
|
|
outputs=[api_preview_output_hidden, api_video_output_hidden], |
|
|
api_name="predict" |
|
|
) |
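

# Calling the hidden `/predict` endpoint programmatically (a sketch: the Space
# id and file path are placeholders, and `gradio_client` must be installed):
#
#     from gradio_client import Client, handle_file
#
#     client = Client("user/space-name")  # hypothetical Space id
#     preview, video = client.predict(
#         handle_file("clip.mp4"),
#         '{"annotations": [{"object_id": 1, "frame": 0, "x": 363, "y": 631, "label": "positive"}]}',
#         "base_plus",  # checkpoint
#         True,         # remove_background
#         api_name="/predict",
#     )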
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
combined_demo.queue(api_open=True).launch() |
|
|
|