from __future__ import annotations import colorsys import gc from copy import deepcopy import base64 import math import statistics from pathlib import Path import json import plotly.graph_objects as go BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64") EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4") def ensure_example_video() -> str: """ Ensure the Kickit example video exists locally by decoding the base64 text file. Returns the path to the decoded MP4. """ if EXAMPLE_VIDEO_PATH.exists(): return str(EXAMPLE_VIDEO_PATH) if not BASE64_VIDEO_PATH.exists(): raise FileNotFoundError("Base64 video asset not found.") data = BASE64_VIDEO_PATH.read_text() EXAMPLE_VIDEO_PATH.write_bytes(base64.b64decode(data)) return str(EXAMPLE_VIDEO_PATH) from types import SimpleNamespace from typing import Optional, Any import cv2 import gradio as gr import numpy as np try: import spaces except ImportError: class _SpacesFallback: @staticmethod def GPU(*args, **kwargs): def decorator(fn): return fn return decorator spaces = _SpacesFallback() import torch from gradio.themes import Soft from PIL import Image, ImageDraw from transformers import AutoModel, Sam2VideoProcessor from ultralytics import YOLO from huggingface_hub import hf_hub_download YOLO_MODEL_CACHE: dict[str, YOLO] = {} YOLO_DEFAULT_MODEL = "yolov13n.pt" YOLO_REPO_ID = "atalaydenknalbant/Yolov13" YOLO_TARGET_NAME = "sports ball" YOLO_CONF_THRESHOLD = 0.0 YOLO_IOU_THRESHOLD = 0.02 PLAYER_TARGET_NAME = "person" PLAYER_OBJECT_ID = 2 BALL_OBJECT_ID = 1 GOAL_MODE_IDLE = "idle" GOAL_MODE_PLACING_FIRST = "placing_first" GOAL_MODE_PLACING_SECOND = "placing_second" GOAL_MODE_EDITING = "editing" GOAL_HANDLE_RADIUS_PX = 8 GOAL_HANDLE_HIT_RADIUS_PX = 28 GOAL_LINE_COLOR = (255, 214, 64) GOAL_HANDLE_FILL = (10, 10, 10) def get_yolo_model(model_filename: str = YOLO_DEFAULT_MODEL) -> YOLO: """ Lazily download and load a YOLOv13 model, caching it for reuse. 
""" if model_filename in YOLO_MODEL_CACHE: return YOLO_MODEL_CACHE[model_filename] model_path = hf_hub_download(repo_id=YOLO_REPO_ID, filename=model_filename) model = YOLO(model_path) YOLO_MODEL_CACHE[model_filename] = model return model def detect_ball_center( frame: Image.Image, model_filename: str = YOLO_DEFAULT_MODEL, conf_threshold: float = YOLO_CONF_THRESHOLD, iou_threshold: float = YOLO_IOU_THRESHOLD, ) -> Optional[tuple[int, int, int, int, float]]: """ Run YOLO on a single frame and return (x_center, y_center, width, height, confidence) for the highest-confidence sports ball detection. """ model = get_yolo_model(model_filename) class_ids = [ idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME ] if not class_ids: return None results = model.predict( source=frame, conf=conf_threshold, iou=iou_threshold, max_det=1, classes=class_ids, imgsz=640, device="cpu", verbose=False, ) if not results: return None boxes = results[0].boxes if boxes is None or len(boxes) == 0: return None box = boxes[0] # xywh format: x_center, y_center, width, height xywh = box.xywh[0].cpu().tolist() conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0 x_center, y_center, width, height = xywh return ( int(round(x_center)), int(round(y_center)), int(round(width)), int(round(height)), conf, ) def detect_all_balls( frame: Image.Image, model_filename: str = YOLO_DEFAULT_MODEL, conf_threshold: float = 0.05, # Minimum 5% confidence to filter noise iou_threshold: float = YOLO_IOU_THRESHOLD, max_detections: int = 10, # Get more from YOLO, then filter to top 5 max_candidates: int = 5, # Return only top 5 by confidence ) -> list[dict]: """ Detect all ball candidates in a frame. 
- Minimum 5% confidence to filter noise - Returns top 5 candidates by confidence - No ROI filtering - scoring happens later Returns list of dicts with keys: - id: int (candidate index) - center: (x, y) tuple - box: (x_min, y_min, x_max, y_max) tuple - width: float - height: float - conf: float (YOLO confidence) - x_ratio: float (horizontal position as fraction of frame width) - y_ratio: float (vertical position as fraction of frame height) """ model = get_yolo_model(model_filename) class_ids = [ idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME ] if not class_ids: return [] results = model.predict( source=frame, conf=conf_threshold, iou=iou_threshold, max_det=max_detections, classes=class_ids, imgsz=640, device="cpu", verbose=False, ) if not results: return [] boxes = results[0].boxes if boxes is None or len(boxes) == 0: return [] frame_width, frame_height = frame.size candidates = [] for i, box in enumerate(boxes): xywh = box.xywh[0].cpu().tolist() conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0 x_center, y_center, width, height = xywh # Compute bounding box x_min = int(round(max(0.0, x_center - width / 2.0))) y_min = int(round(max(0.0, y_center - height / 2.0))) x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0))) y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0))) if x_max <= x_min or y_max <= y_min: continue # Compute position ratios x_ratio = x_center / frame_width y_ratio = y_center / frame_height # NO ROI filtering - accept all balls # (ROI scoring will happen later in the scoring phase) candidates.append({ "id": len(candidates), "center": (float(x_center), float(y_center)), "box": (x_min, y_min, x_max, y_max), "width": float(width), "height": float(height), "conf": conf, "x_ratio": x_ratio, "y_ratio": y_ratio, }) # Sort by confidence descending candidates.sort(key=lambda c: c["conf"], reverse=True) # Keep only top N candidates candidates = candidates[:max_candidates] # 
Re-assign IDs after sorting and filtering for i, c in enumerate(candidates): c["id"] = i # Debug logging print(f"[detect_all_balls] Found {len(candidates)} ball candidates (top {max_candidates}, conf >= {conf_threshold:.0%}):") for c in candidates: print(f" Ball {c['id']}: center={c['center']}, conf={c['conf']:.1%}, box={c['box']}") return candidates def detect_person_box( frame: Image.Image, model_filename: str = YOLO_DEFAULT_MODEL, conf_threshold: float = YOLO_CONF_THRESHOLD, iou_threshold: float = YOLO_IOU_THRESHOLD, ) -> Optional[tuple[int, int, int, int, float]]: """ Run YOLO on a single frame and return (x_min, y_min, x_max, y_max, confidence) for the highest-confidence person detection. """ model = get_yolo_model(model_filename) class_ids = [ idx for idx, name in model.names.items() if name.lower() == PLAYER_TARGET_NAME ] if not class_ids: return None results = model.predict( source=frame, conf=conf_threshold, iou=iou_threshold, max_det=5, classes=class_ids, imgsz=640, device="cpu", verbose=False, ) if not results: return None boxes = results[0].boxes if boxes is None or len(boxes) == 0: return None box = boxes[0] xyxy = box.xyxy[0].cpu().tolist() conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0 x_min, y_min, x_max, y_max = xyxy frame_width, frame_height = frame.size x_min = max(0, min(frame_width - 1, int(round(x_min)))) y_min = max(0, min(frame_height - 1, int(round(y_min)))) x_max = max(0, min(frame_width - 1, int(round(x_max)))) y_max = max(0, min(frame_height - 1, int(round(y_max)))) if x_max <= x_min or y_max <= y_min: return None return x_min, y_min, x_max, y_max, conf def _compute_sam_window_from_kick(state: AppState, kick_frame: int | None) -> tuple[int, int]: total_frames = state.num_frames if total_frames == 0: return 0, 0 # If no kick detected, use ALL frames if kick_frame is None: start_idx = 0 end_idx = total_frames print(f"[_compute_sam_window_from_kick] No kick detected → using ALL {total_frames} frames") else: # If 
kick detected, use 4-second window around kick fps = state.video_fps if state.video_fps and state.video_fps > 0 else 25.0 target_window_frames = max(1, int(round(fps * 4.0))) half_window = target_window_frames // 2 start_idx = max(0, int(kick_frame) - half_window) end_idx = min(total_frames, start_idx + target_window_frames) if end_idx <= start_idx: end_idx = min(total_frames, start_idx + 1) print(f"[_compute_sam_window_from_kick] Kick @ {kick_frame} → window [{start_idx}, {end_idx}] ({end_idx - start_idx} frames)") state.sam_window = (start_idx, end_idx) return start_idx, end_idx def _goal_frame_dims(state: AppState, frame_idx: int | None = None) -> tuple[int, int]: if state is None or not state.video_frames: return 1, 1 idx = 0 if frame_idx is None else int(np.clip(frame_idx, 0, len(state.video_frames) - 1)) frame = state.video_frames[idx] return frame.size def _goal_norm_from_xy(state: AppState, frame_idx: int, x: int, y: int) -> tuple[float, float]: width, height = _goal_frame_dims(state, frame_idx) if width <= 0: width = 1 if height <= 0: height = 1 return ( float(np.clip(x / width, 0.0, 1.0)), float(np.clip(y / height, 0.0, 1.0)), ) def _goal_xy_from_norm(state: AppState, frame_idx: int, pt: tuple[float, float]) -> tuple[int, int]: width, height = _goal_frame_dims(state, frame_idx) return ( int(round(float(pt[0]) * width)), int(round(float(pt[1]) * height)), ) def _goal_points_for_drawing(state: AppState) -> list[tuple[float, float]]: if state is None: return [] if state.goal_mode in {GOAL_MODE_PLACING_FIRST, GOAL_MODE_PLACING_SECOND, GOAL_MODE_EDITING}: return list(state.goal_points_norm) return list(state.goal_overlay_points) def _goal_clear_preview_cache(state: AppState) -> None: if state is None: return state.composited_frames.clear() def _goal_has_confirmed(state: AppState) -> bool: return isinstance(state, AppState) and len(state.goal_confirmed_points_norm) == 2 def _goal_set_status(state: AppState, text: str) -> None: if state is None: return 
state.goal_status_text = text def _goal_status_text(state: AppState) -> str: if state is None: return "Goal crossbar unavailable." if state.goal_status_text: return state.goal_status_text if _goal_has_confirmed(state): return "Goal crossbar confirmed. Click Start Mapping to adjust." return "Goal crossbar inactive." def _goal_button_updates(state: AppState) -> tuple[Any, Any, Any, Any, Any]: if state is None: return ( gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(value="Goal crossbar unavailable.", visible=True), ) start_enabled = state.goal_mode == GOAL_MODE_IDLE confirm_enabled = len(state.goal_points_norm) == 2 and state.goal_mode in { GOAL_MODE_PLACING_SECOND, GOAL_MODE_EDITING, } clear_enabled = bool(state.goal_points_norm or state.goal_confirmed_points_norm) back_enabled = bool(state.goal_prev_confirmed_points_norm) status_update = gr.update(value=_goal_status_text(state), visible=True) return ( gr.update(interactive=start_enabled), gr.update(interactive=confirm_enabled), gr.update(interactive=clear_enabled), gr.update(interactive=back_enabled), status_update, ) def _goal_handle_hit_index(state: AppState, frame_idx: int, x: int, y: int) -> int | None: points = state.goal_points_norm if state is None or len(points) == 0: return None width, height = _goal_frame_dims(state, frame_idx) max_dist = GOAL_HANDLE_HIT_RADIUS_PX for idx, pt in enumerate(points): px, py = _goal_xy_from_norm(state, frame_idx, pt) dist = math.hypot(px - x, py - y) if dist <= max_dist: return idx return None def _goal_current_frame_idx(state: AppState) -> int: if state is None or state.num_frames == 0: return 0 idx = int(getattr(state, "current_frame_idx", 0)) return int(np.clip(idx, 0, state.num_frames - 1)) def _goal_output_tuple(state: AppState, preview_img: Image.Image | None = None) -> tuple[Image.Image, Any, Any, Any, Any, Any]: if state is None: return (preview_img, *(gr.update(interactive=False) for _ 
in range(4)), gr.update(value="Goal crossbar unavailable.", visible=True)) idx = _goal_current_frame_idx(state) if preview_img is None: preview_img = update_frame_display(state, idx) return (preview_img, *_goal_button_updates(state)) def _goal_start_mapping(state: AppState) -> tuple[Image.Image, Any, Any, Any, Any, Any]: if state is None or not state.video_frames: raise gr.Error("Load a video first, then map the goal crossbar.") state.goal_prev_confirmed_points_norm = list(state.goal_confirmed_points_norm) if state.goal_confirmed_points_norm: state.goal_points_norm = list(state.goal_confirmed_points_norm) else: state.goal_points_norm = [] state.goal_overlay_points = [] state.goal_mode = GOAL_MODE_PLACING_FIRST state.goal_dragging_idx = None _goal_set_status(state, "Click the left goalpost to start the crossbar.") _goal_clear_preview_cache(state) return _goal_output_tuple(state) def _goal_confirm_mapping(state: AppState) -> tuple[Image.Image, Any, Any, Any, Any, Any]: if state is None: return (None, *_goal_button_updates(state)) if len(state.goal_points_norm) != 2: _goal_set_status(state, "Select both goal corners before confirming.") return _goal_output_tuple(state) state.goal_confirmed_points_norm = list(state.goal_points_norm) state.goal_overlay_points = list(state.goal_points_norm) state.goal_points_norm = [] state.goal_mode = GOAL_MODE_IDLE state.goal_dragging_idx = None _goal_set_status(state, "Goal crossbar saved. 
Click Start Mapping to adjust again.") _goal_clear_preview_cache(state) return _goal_output_tuple(state) def _goal_clear_mapping(state: AppState) -> tuple[Image.Image, Any, Any, Any, Any, Any]: if state is None: return (None, *_goal_button_updates(state)) state.goal_points_norm.clear() state.goal_confirmed_points_norm.clear() state.goal_prev_confirmed_points_norm.clear() state.goal_overlay_points.clear() state.goal_points_norm = [] state.goal_confirmed_points_norm = [] state.goal_prev_confirmed_points_norm = [] state.goal_overlay_points = [] state.goal_mode = GOAL_MODE_IDLE state.goal_dragging_idx = None _goal_set_status(state, "Goal crossbar cleared.") _goal_clear_preview_cache(state) return _goal_output_tuple(state) def _goal_back_mapping(state: AppState) -> tuple[Image.Image, Any, Any, Any, Any, Any]: if state is None: return (None, *_goal_button_updates(state)) if not state.goal_prev_confirmed_points_norm: _goal_set_status(state, "No previous goal crossbar to restore.") return _goal_output_tuple(state) state.goal_confirmed_points_norm = list(state.goal_prev_confirmed_points_norm) state.goal_overlay_points = list(state.goal_prev_confirmed_points_norm) state.goal_points_norm = [] state.goal_prev_confirmed_points_norm = [] state.goal_mode = GOAL_MODE_IDLE state.goal_dragging_idx = None _goal_set_status(state, "Restored the previous goal crossbar.") _goal_clear_preview_cache(state) return _goal_output_tuple(state) def _goal_process_preview_click( state: AppState, frame_idx: int, evt: gr.SelectData | None, ) -> tuple[Image.Image | None, bool]: if state is None or state.goal_mode == GOAL_MODE_IDLE: return None, False x = y = None if evt is not None: try: if hasattr(evt, "index") and isinstance(evt.index, (list, tuple)) and len(evt.index) == 2: x, y = int(evt.index[0]), int(evt.index[1]) elif hasattr(evt, "value") and isinstance(evt.value, dict): data = evt.value if "x" in data and "y" in data: x, y = int(data["x"]), int(data["y"]) except Exception: x = y = None if x 
is None or y is None: _goal_set_status(state, "Could not read click coordinates. Please try again.") return _goal_output_tuple(state)[0], True norm_pt = _goal_norm_from_xy(state, frame_idx, x, y) points = state.goal_points_norm if state.goal_mode == GOAL_MODE_PLACING_FIRST: state.goal_points_norm = [norm_pt] state.goal_mode = GOAL_MODE_PLACING_SECOND _goal_set_status(state, "Click the right goalpost to finish the crossbar.") elif state.goal_mode == GOAL_MODE_PLACING_SECOND: handle_idx = _goal_handle_hit_index(state, frame_idx, x, y) if handle_idx is not None and handle_idx < len(points): state.goal_points_norm[handle_idx] = norm_pt _goal_set_status(state, "Adjusted the first corner. Click the other post.") else: if len(points) == 0: state.goal_points_norm = [norm_pt] _goal_set_status(state, "Click the next goalpost to finish the crossbar.") elif len(points) == 1: state.goal_points_norm.append(norm_pt) state.goal_mode = GOAL_MODE_EDITING _goal_set_status(state, "Adjust handles if needed, then Confirm.") else: state.goal_points_norm[1] = norm_pt state.goal_mode = GOAL_MODE_EDITING _goal_set_status(state, "Adjust handles if needed, then Confirm.") elif state.goal_mode == GOAL_MODE_EDITING: handle_idx = _goal_handle_hit_index(state, frame_idx, x, y) if handle_idx is None and len(points) == 2: # fall back to whichever endpoint is closest to click px0, py0 = _goal_xy_from_norm(state, frame_idx, points[0]) px1, py1 = _goal_xy_from_norm(state, frame_idx, points[1]) dist0 = math.hypot(px0 - x, py0 - y) dist1 = math.hypot(px1 - x, py1 - y) handle_idx = 0 if dist0 <= dist1 else 1 if handle_idx is not None and handle_idx < len(points): state.goal_points_norm[handle_idx] = norm_pt _goal_set_status(state, "Handle moved. 
Press Confirm to save.") state.goal_points_norm = state.goal_points_norm[:2] _goal_clear_preview_cache(state) preview_img = update_frame_display(state, frame_idx) return preview_img, True def _draw_goal_overlay(state: AppState, frame_idx: int, image: Image.Image) -> None: if state is None or image is None: return points = _goal_points_for_drawing(state) if not points: return draw = ImageDraw.Draw(image) px_points = [_goal_xy_from_norm(state, frame_idx, pt) for pt in points[:2]] if len(px_points) >= 2: draw.line( [px_points[0], px_points[1]], fill=GOAL_LINE_COLOR, width=4, ) handle_radius = max(4, GOAL_HANDLE_RADIUS_PX) for cx, cy in px_points: bbox = [ (cx - handle_radius, cy - handle_radius), (cx + handle_radius, cy + handle_radius), ] draw.ellipse(bbox, outline=GOAL_LINE_COLOR, fill=GOAL_HANDLE_FILL, width=2) def _perform_yolo_ball_tracking(state: AppState, progress: gr.Progress | None = None) -> None: if state is None or state.num_frames == 0: raise gr.Error("Load a video first, then track with YOLO.") model = get_yolo_model() class_ids = [ idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME ] if not class_ids: raise gr.Error("YOLO model does not contain the sports ball class.") frames = state.video_frames total = len(frames) centers: dict[int, tuple[float, float]] = {} boxes: dict[int, tuple[int, int, int, int]] = {} confs: dict[int, float] = {} areas: dict[int, float] = {} first_detection_frame: int | None = None for idx, frame in enumerate(frames): if progress is not None: progress((idx + 1) / total) results = model.predict( source=frame, conf=YOLO_CONF_THRESHOLD, iou=YOLO_IOU_THRESHOLD, max_det=1, classes=class_ids, imgsz=640, device="cpu", verbose=False, ) if not results: continue boxes_result = results[0].boxes if boxes_result is None or len(boxes_result) == 0: continue box = boxes_result[0] xywh = box.xywh[0].cpu().tolist() conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0 x_center, y_center, width, height = 
xywh x_center = float(x_center) y_center = float(y_center) width = max(1.0, float(width)) height = max(1.0, float(height)) frame_width, frame_height = frame.size x_min = int(round(max(0.0, x_center - width / 2.0))) y_min = int(round(max(0.0, y_center - height / 2.0))) x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0))) y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0))) if x_max <= x_min or y_max <= y_min: continue centers[idx] = (x_center, y_center) boxes[idx] = (x_min, y_min, x_max, y_max) confs[idx] = conf areas[idx] = float((x_max - x_min) * (y_max - y_min)) if first_detection_frame is None: first_detection_frame = idx state.yolo_ball_centers = centers state.yolo_ball_boxes = boxes state.yolo_ball_conf = confs state.yolo_mask_area_proxy = [areas.get(k, 0.0) for k in sorted(centers.keys())] state.yolo_initial_frame = first_detection_frame if len(centers) < 3: state.yolo_smoothed_centers = {} state.yolo_speeds = {} state.yolo_distance_from_start = {} state.yolo_threshold = None state.yolo_baseline_speed = None state.yolo_speed_std = None state.yolo_kick_frame = None state.yolo_status = "❌ YOLO13: insufficient detections to estimate kick. Please retry or annotate manually." 
state.sam_window = None return items = sorted(centers.items()) dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0 alpha = 0.35 smoothed: dict[int, tuple[float, float]] = {} speeds: dict[int, float] = {} prev_frame = None prev_smooth = None for frame_idx, (cx, cy) in items: if prev_smooth is None: smooth_x, smooth_y = float(cx), float(cy) else: smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0]) smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1]) smoothed[frame_idx] = (smooth_x, smooth_y) if prev_smooth is None or prev_frame is None: speeds[frame_idx] = 0.0 else: frame_delta = max(1, frame_idx - prev_frame) time_delta = frame_delta * dt dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1]) speed = dist / time_delta if time_delta > 0 else dist speeds[frame_idx] = speed prev_smooth = (smooth_x, smooth_y) prev_frame = frame_idx frames_ordered = [frame_idx for frame_idx, _ in items] speed_series = [speeds.get(f, 0.0) for f in frames_ordered] baseline_window = min(10, len(frames_ordered) // 3 or 1) baseline_speeds = speed_series[:baseline_window] baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0 speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0 base_threshold = baseline_speed + 4.0 * speed_std if base_threshold < baseline_speed * 3.0: base_threshold = baseline_speed * 3.0 speed_threshold = max(base_threshold, 15.0) distance_dict: dict[int, float] = {} if smoothed: first_frame = frames_ordered[0] origin = smoothed[first_frame] for frame_idx, (sx, sy) in smoothed.items(): distance_dict[frame_idx] = math.hypot(sx - origin[0], sy - origin[1]) areas_dict = {idx: areas.get(idx, 0.0) for idx in frames_ordered} initial_area = areas_dict.get(frames_ordered[0], 1.0) or 1.0 radius_estimate = math.sqrt(initial_area / math.pi) adaptive_return_distance = max(8.0, min(radius_estimate * 1.5, 40.0)) sustain_frames = 3 holdout_frames = 8 area_window = 4 
area_drop_ratio = 0.75 kalman_pos, kalman_speed, _ = _run_kalman_filter(items, dt) kalman_speed_series = [kalman_speed.get(f, 0.0) for f in frames_ordered] kick_frame: int | None = None for idx, frame in enumerate(frames_ordered[baseline_window:], start=baseline_window): speed = speed_series[idx] if speed < speed_threshold: continue sustain_ok = True for j in range(1, sustain_frames + 1): if idx + j >= len(frames_ordered): break if speed_series[idx + j] < speed_threshold * 0.7: sustain_ok = False break if not sustain_ok: continue area_pass = True current_area = areas_dict.get(frame) if current_area: prev_areas = [ areas_dict.get(f) for f in frames_ordered[max(0, idx - area_window):idx] if areas_dict.get(f) is not None ] if prev_areas: median_prev = statistics.median(prev_areas) if median_prev > 0: ratio = current_area / median_prev if ratio > area_drop_ratio: area_pass = False if not area_pass and speed < speed_threshold * 1.2: continue future_slice = frames_ordered[idx: min(len(frames_ordered), idx + holdout_frames)] max_future_dist = 0.0 for future_frame in future_slice: dist = distance_dict.get(future_frame, 0.0) if dist > max_future_dist: max_future_dist = dist if max_future_dist < adaptive_return_distance: continue kick_frame = frame break state.yolo_smoothed_centers = smoothed state.yolo_speeds = speeds state.yolo_distance_from_start = distance_dict state.yolo_threshold = speed_threshold state.yolo_baseline_speed = baseline_speed state.yolo_speed_std = speed_std state.yolo_kick_frames = frames_ordered state.yolo_kick_speeds = speed_series state.yolo_kick_distance = [distance_dict.get(f, 0.0) for f in frames_ordered] state.yolo_mask_area_proxy = [areas_dict.get(f, 0.0) for f in frames_ordered] state.yolo_kick_frame = kick_frame coverage = len(centers) / total if total else 0.0 if kick_frame is not None: state.yolo_status = f"✅ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%})." 
else: state.yolo_status = ( f"⚠️ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%}) but did not find a definitive kick." ) state.kalman_centers[BALL_OBJECT_ID] = kalman_pos state.kalman_speeds[BALL_OBJECT_ID] = kalman_speed if kick_frame is not None: state.kick_frame = kick_frame _compute_sam_window_from_kick(state, kick_frame) else: state.sam_window = None def _track_single_ball_candidate( state: AppState, candidate: dict, progress: gr.Progress | None = None, ) -> dict: """ Track a single ball candidate across ALL frames using YOLO. Uses proximity matching to follow the same ball. Returns dict with tracking results: - centers: dict[frame_idx, (x, y)] - speeds: dict[frame_idx, speed] - kick_frame: int | None - max_velocity: float - has_kick: bool - coverage: float (fraction of frames with detection) """ model = get_yolo_model() class_ids = [ idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME ] frames = state.video_frames total = len(frames) print(f"[_track_single_ball_candidate] Tracking Ball {candidate['id']} across {total} frames...") # Initial position from candidate last_center = candidate["center"] max_distance_threshold = 100 # Max pixels to consider same ball centers: dict[int, tuple[float, float]] = {} boxes: dict[int, tuple[int, int, int, int]] = {} confs: dict[int, float] = {} areas: dict[int, float] = {} for idx, frame in enumerate(frames): if progress is not None: progress((idx + 1) / total) results = model.predict( source=frame, conf=0.05, # Lower threshold to catch more iou=YOLO_IOU_THRESHOLD, max_det=10, # Allow multiple detections classes=class_ids, imgsz=640, device="cpu", verbose=False, ) if not results: continue boxes_result = results[0].boxes if boxes_result is None or len(boxes_result) == 0: continue # Find the detection closest to last known position best_box = None best_distance = float("inf") for box in boxes_result: xywh = box.xywh[0].cpu().tolist() x_center, y_center = xywh[0], xywh[1] dist = 
math.hypot(x_center - last_center[0], y_center - last_center[1]) if dist < best_distance and dist < max_distance_threshold: best_distance = dist best_box = box if best_box is None: continue xywh = best_box.xywh[0].cpu().tolist() conf = float(best_box.conf[0].cpu().item()) if best_box.conf is not None else 0.0 x_center, y_center, width, height = xywh x_center = float(x_center) y_center = float(y_center) width = max(1.0, float(width)) height = max(1.0, float(height)) frame_width, frame_height = frame.size x_min = int(round(max(0.0, x_center - width / 2.0))) y_min = int(round(max(0.0, y_center - height / 2.0))) x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0))) y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0))) if x_max <= x_min or y_max <= y_min: continue centers[idx] = (x_center, y_center) boxes[idx] = (x_min, y_min, x_max, y_max) confs[idx] = conf areas[idx] = float((x_max - x_min) * (y_max - y_min)) last_center = (x_center, y_center) # Compute speeds if len(centers) < 3: return { "centers": centers, "boxes": boxes, "confs": confs, "areas": areas, "speeds": {}, "smoothed_centers": {}, "frames_ordered": [], "speed_series": [], "kick_frame": None, "max_velocity": 0.0, "has_kick": False, "coverage": len(centers) / total if total else 0.0, } items = sorted(centers.items()) dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0 alpha = 0.35 smoothed: dict[int, tuple[float, float]] = {} speeds: dict[int, float] = {} prev_frame = None prev_smooth = None for frame_idx, (cx, cy) in items: if prev_smooth is None: smooth_x, smooth_y = float(cx), float(cy) else: smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0]) smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1]) smoothed[frame_idx] = (smooth_x, smooth_y) if prev_smooth is None or prev_frame is None: speeds[frame_idx] = 0.0 else: frame_delta = max(1, frame_idx - prev_frame) time_delta = frame_delta * dt dist = math.hypot(smooth_x - prev_smooth[0], 
smooth_y - prev_smooth[1]) speed = dist / time_delta if time_delta > 0 else dist speeds[frame_idx] = speed prev_smooth = (smooth_x, smooth_y) prev_frame = frame_idx frames_ordered = [frame_idx for frame_idx, _ in items] speed_series = [speeds.get(f, 0.0) for f in frames_ordered] # Detect kick (velocity spike) baseline_window = min(10, len(frames_ordered) // 3 or 1) baseline_speeds = speed_series[:baseline_window] baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0 speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0 base_threshold = baseline_speed + 4.0 * speed_std if base_threshold < baseline_speed * 3.0: base_threshold = baseline_speed * 3.0 speed_threshold = max(base_threshold, 15.0) kick_frame: int | None = None max_velocity = max(speed_series) if speed_series else 0.0 for idx, frame in enumerate(frames_ordered[baseline_window:], start=baseline_window): speed = speed_series[idx] if speed < speed_threshold: continue # Check sustain sustain_ok = True for j in range(1, 4): if idx + j >= len(frames_ordered): break if speed_series[idx + j] < speed_threshold * 0.7: sustain_ok = False break if sustain_ok: kick_frame = frame break result = { "centers": centers, "boxes": boxes, "confs": confs, "areas": areas, "speeds": speeds, "smoothed_centers": smoothed, "frames_ordered": frames_ordered, "speed_series": speed_series, "threshold": speed_threshold, "baseline": baseline_speed, "kick_frame": kick_frame, "max_velocity": max_velocity, "has_kick": kick_frame is not None, "coverage": len(centers) / total if total else 0.0, } # Summary logging kick_info = f"Kick @ frame {kick_frame}" if kick_frame else "No kick" print(f"[_track_single_ball_candidate] Ball {candidate['id']} done: " f"{len(centers)}/{total} frames ({result['coverage']:.0%}), " f"max_vel={max_velocity:.1f}px/s, {kick_info}") return result def _detect_and_track_all_ball_candidates( state: AppState, progress: gr.Progress | None = None, ) -> None: """ Detect 
all ball candidates in first frame, track each with YOLO, score them, and auto-select the best candidate. """ if state is None or state.num_frames == 0: raise gr.Error("Load a video first.") first_frame = state.video_frames[0] frame_width, frame_height = first_frame.size # Step 1: Detect all balls in first frame candidates = detect_all_balls(first_frame) if not candidates: state.ball_candidates = [] state.multi_ball_status = "❌ No ball candidates detected in first frame." return state.multi_ball_status = f"🔍 Found {len(candidates)} ball candidate(s). Tracking..." # Step 2: Track each candidate tracking_results: dict[int, dict] = {} for i, candidate in enumerate(candidates): if progress is not None: progress((i + 1) / len(candidates), desc=f"Tracking ball {i+1}/{len(candidates)}") result = _track_single_ball_candidate(state, candidate, progress=None) tracking_results[candidate["id"]] = result # Add tracking summary to candidate candidate["tracking"] = result candidate["has_kick"] = result["has_kick"] candidate["kick_frame"] = result["kick_frame"] candidate["max_velocity"] = result["max_velocity"] candidate["coverage"] = result["coverage"] # Step 3: Score candidates frame_center_x = frame_width / 2 for candidate in candidates: score = 0.0 # 1. Has a detected kick (velocity spike) — most important if candidate["has_kick"]: score += 50 # 2. Higher max velocity — ball that moves most score += min(30, candidate["max_velocity"] / 10) # 3. Centered horizontally x_offset = abs(candidate["center"][0] - frame_center_x) / frame_center_x score += 20 * (1 - x_offset) # 4. YOLO confidence as tiebreaker score += candidate["conf"] * 10 # 5. 
def _apply_selected_ball_to_yolo_state(state: AppState) -> None:
    """
    Copy the selected ball candidate's tracking data to the main YOLO state.

    This allows the rest of the pipeline to work unchanged. Does nothing when
    ``state.ball_candidates`` is empty; an out-of-range ``selected_ball_idx``
    falls back to candidate 0.

    Side effects on ``state``: overwrites the ``yolo_*`` tracking fields and
    ``kick_frame``, recomputes the SAM window through
    ``_compute_sam_window_from_kick``, sets ``is_yolo_tracked`` and
    ``ball_selection_confirmed`` to True and writes ``yolo_status``.
    """
    if not state.ball_candidates:
        return

    idx = state.selected_ball_idx
    if idx < 0 or idx >= len(state.ball_candidates):
        idx = 0
    candidate = state.ball_candidates[idx]
    tracking = candidate.get("tracking", {})

    # Looked up once; the original re-queried these keys several times.
    frames_ordered = tracking.get("frames_ordered", [])
    smoothed = tracking.get("smoothed_centers", {})
    kick_frame = tracking.get("kick_frame")

    # Copy to main YOLO state
    state.yolo_ball_centers = tracking.get("centers", {})
    state.yolo_ball_boxes = tracking.get("boxes", {})
    state.yolo_ball_conf = tracking.get("confs", {})
    state.yolo_smoothed_centers = smoothed
    state.yolo_speeds = tracking.get("speeds", {})
    state.yolo_kick_frames = frames_ordered
    state.yolo_kick_speeds = tracking.get("speed_series", [])
    state.yolo_threshold = tracking.get("threshold")
    state.yolo_baseline_speed = tracking.get("baseline")
    state.yolo_kick_frame = kick_frame
    state.yolo_initial_frame = frames_ordered[0] if frames_ordered else None

    # Per-frame box area series, aligned with frames_ordered
    areas = tracking.get("areas", {})
    state.yolo_mask_area_proxy = [areas.get(f, 0.0) for f in frames_ordered]

    # Distance of every smoothed centre from the first tracked frame's centre.
    # The fields are always rewritten so that switching to a candidate without
    # smoothed centres cannot leave the previous candidate's distances behind.
    if smoothed and frames_ordered:
        origin_x, origin_y = smoothed.get(frames_ordered[0], (0, 0))
        distance_dict = {
            f: math.hypot(sx - origin_x, sy - origin_y)
            for f, (sx, sy) in smoothed.items()
        }
    else:
        distance_dict = {}
    state.yolo_distance_from_start = distance_dict
    state.yolo_kick_distance = [distance_dict.get(f, 0.0) for f in frames_ordered]

    # Update kick frame and SAM window
    state.kick_frame = kick_frame  # Can be None
    # Always compute SAM window - if no kick, it will use ALL frames
    _compute_sam_window_from_kick(state, kick_frame)

    # Mark as tracked
    state.is_yolo_tracked = True
    state.ball_selection_confirmed = True

    coverage = tracking.get("coverage", 0.0)
    if kick_frame is not None:
        state.yolo_status = f"✅ Ball {idx+1} tracked. Kick @ frame {kick_frame}."
    else:
        state.yolo_status = f"⚠️ Ball {idx+1} tracked ({coverage:.0%} coverage) but no kick detected. SAM2 will analyze ALL frames."
+ padding, text_bbox[3] + padding, ] draw.rectangle(bg_box, fill=(0, 0, 0, 200)) # Draw label text draw.text((x_min, y_min - 22), label, fill=text_color, font=font) # Draw center crosshair cx, cy = candidate.get("center", (0, 0)) cx, cy = int(cx), int(cy) cross_size = 8 draw.line([(cx - cross_size, cy), (cx + cross_size, cy)], fill=box_color, width=2) draw.line([(cx, cy - cross_size), (cx, cy + cross_size)], fill=box_color, width=2) return result def _format_ball_candidates_for_radio(candidates: list[dict]) -> list[str]: """Format ball candidates as radio button choices.""" choices = [] for i, c in enumerate(candidates): kick_info = f"⚽ Kick@{c['kick_frame']}" if c.get('has_kick') else "No kick" vel_info = f"v={c.get('max_velocity', 0):.0f}px/s" conf_info = f"conf={c.get('conf', 0):.0%}" cov_info = f"cov={c.get('coverage', 0):.0%}" pos_info = f"x={c.get('x_ratio', 0.5):.0%}" label = f"Ball {i+1}: {kick_info} | {vel_info} | {pos_info} | {conf_info}" choices.append(label) return choices def _format_ball_candidates_markdown(candidates: list[dict], selected_idx: int = 0) -> str: """Format ball candidates as markdown summary.""" if not candidates: return "" lines = [f"**{len(candidates)} ball candidate(s) detected:**\n"] for i, c in enumerate(candidates): marker = "✅" if i == selected_idx else "○" kick_info = f"⚽ Kick @ frame {c['kick_frame']}" if c.get('has_kick') else "No kick detected" vel_info = f"Max velocity: {c.get('max_velocity', 0):.0f} px/s" conf_info = f"YOLO conf: {c.get('conf', 0):.0%}" pos_x = c.get('x_ratio', 0.5) pos_info = f"Position: {pos_x:.0%} from left" lines.append(f"{marker} **Ball {i+1}**: {kick_info}") lines.append(f" - {vel_info} | {pos_info} | {conf_info}") return "\n".join(lines) def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]: """Generate a deterministic pastel RGB color for a given object id. Uses golden ratio to distribute hues; low-medium saturation, high value. 
def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
    """Decode every frame of a video into RGB PIL images with OpenCV.

    Args:
        video_path_or_url: Source handed unchanged to ``cv2.VideoCapture``.

    Returns:
        ``(frames, info)``. ``info["num_frames"]`` is the number of decoded
        frames and ``info["fps"]`` is the frame rate reported by OpenCV, or
        ``None`` when it reports no positive value. A source that cannot be
        opened yields an empty frame list rather than an exception.
    """
    cap = cv2.VideoCapture(video_path_or_url)
    frames: list[Image.Image] = []
    print("loading video frames")
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; PIL expects RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        # Gather fps if available
        fps_val = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the capture even if a frame conversion fails midway.
        cap.release()
    print("loaded video frames")
    info = {
        "num_frames": len(frames),
        "fps": float(fps_val) if fps_val and fps_val > 0 else None,
    }
    return frames, info
class AppState:
    """Mutable per-session state shared by the UI handlers.

    Holds the loaded video frames, the model/processor handles, per-frame
    masks and prompts, kick/impact analysis results, YOLO tracking caches,
    compositing FX settings and goal-overlay state. ``reset()`` restores every
    field to its default.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """(Re)initialise every attribute to its default value."""
        self.video_frames: list[Image.Image] = []
        self.inference_session = None
        self.model: Optional[AutoModel] = None
        self.processor: Optional[Sam2VideoProcessor] = None
        self.device: str = "cpu"
        self.dtype: torch.dtype = torch.bfloat16
        self.video_fps: float | None = None
        self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
        self.color_by_obj: dict[int, tuple[int, int, int]] = {}
        self.clicks_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int]]]] = {}
        self.boxes_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int, int]]]] = {}
        # Cache of composited frames (original + masks + clicks)
        self.composited_frames: dict[int, Image.Image] = {}
        # UI state for click handler
        self.current_frame_idx: int = 0
        self.current_obj_id: int = 1
        self.current_label: str = "positive"
        self.current_clear_old: bool = True
        self.current_prompt_type: str = "Points"  # or "Boxes"
        self.pending_box_start: tuple[int, int] | None = None
        self.pending_box_start_frame_idx: int | None = None
        self.pending_box_start_obj_id: int | None = None
        self.is_switching_model: bool = False
        # Per-object trajectory data, keyed obj_id -> frame_idx -> value
        self.ball_centers: dict[int, dict[int, tuple[int, int]]] = {}
        self.mask_areas: dict[int, dict[int, float]] = {}
        self.smoothed_centers: dict[int, dict[int, tuple[float, float]]] = {}
        self.ball_speeds: dict[int, dict[int, float]] = {}
        self.distance_from_start: dict[int, dict[int, float]] = {}
        self.direction_change: dict[int, dict[int, float]] = {}
        # Kick detection result and its debug series
        self.kick_frame: int | None = None
        self.kick_debug_frames: list[int] = []
        self.kick_debug_speeds: list[float] = []
        self.kick_debug_threshold: float | None = None
        self.kick_debug_baseline: float | None = None
        self.kick_debug_speed_std: float | None = None
        self.kick_debug_area: list[float] = []
        self.kick_debug_kick_frame: int | None = None
        self.kick_debug_distance: list[float] = []
        self.kick_debug_kalman_speeds: list[float] = []
        self.kalman_centers: dict[int, dict[int, tuple[float, float]]] = {}
        self.kalman_speeds: dict[int, dict[int, float]] = {}
        self.kalman_residuals: dict[int, dict[int, float]] = {}
        # Impact detection settings, result and debug series
        self.min_impact_speed_kmh: float = 20.0
        self.goal_distance_m: float = 18.0
        self.impact_frame: int | None = None
        self.impact_debug_frames: list[int] = []
        self.impact_debug_innovation: list[float] = []
        self.impact_debug_innovation_threshold: float | None = None
        self.impact_debug_direction: list[float] = []
        self.impact_debug_direction_threshold: float | None = None
        self.impact_debug_speed_kmh: list[float] = []
        self.impact_debug_speed_threshold_px: float | None = None
        self.impact_meters_per_px: float | None = None
        # Model selection
        self.model_repo_key: str = "tiny"
        self.model_repo_id: str | None = None
        self.session_repo_id: str | None = None
        self.player_obj_id: int | None = None
        self.player_detection_frame: int | None = None
        self.player_detection_conf: float | None = None
        # YOLO tracking caches
        self.yolo_ball_centers: dict[int, tuple[float, float]] = {}
        self.yolo_ball_boxes: dict[int, tuple[int, int, int, int]] = {}
        self.yolo_ball_conf: dict[int, float] = {}
        self.yolo_smoothed_centers: dict[int, tuple[float, float]] = {}
        self.yolo_speeds: dict[int, float] = {}
        self.yolo_distance_from_start: dict[int, float] = {}
        self.yolo_threshold: float | None = None
        self.yolo_baseline_speed: float | None = None
        self.yolo_speed_std: float | None = None
        self.yolo_kick_frame: int | None = None
        self.yolo_status: str = ""
        self.yolo_kick_frames: list[int] = []
        self.yolo_kick_speeds: list[float] = []
        self.yolo_kick_distance: list[float] = []
        self.yolo_mask_area_proxy: list[float] = []
        self.yolo_initial_frame: int | None = None
        # SAM window (start_idx inclusive, end_idx exclusive)
        self.sam_window: tuple[int, int] | None = None
        # Cutout / compositing effects
        self.fx_soft_matte_enabled: bool = True
        self.fx_soft_matte_feather: float = 4.0
        self.fx_soft_matte_erode: float = 0.5
        self.fx_blur_enabled: bool = True
        self.fx_blur_sigma: float = 0.0
        self.fx_bg_darkening: float = 0.75
        self.fx_light_wrap_enabled: bool = False
        self.fx_light_wrap_strength: float = 0.6
        self.fx_light_wrap_width: float = 15.0
        self.fx_glow_enabled: bool = False
        self.fx_glow_strength: float = 0.4
        self.fx_glow_radius: float = 10.0
        self.fx_ghost_trail_enabled: bool = False
        self.fx_ball_ring_enabled: bool = True
        self.show_click_marks: bool = False
        # Ring FX parameters (initialized with defaults, but can be overridden by UI)
        self.fx_ring_thickness: float = BALL_RING_THICKNESS_PX
        self.fx_ring_alpha: float = BALL_RING_ALPHA
        self.fx_ring_feather: float = BALL_RING_FEATHER_SIGMA
        self.fx_ring_gamma: float = BALL_RING_INTENSITY_GAMMA
        self.fx_ring_duration: int = 30  # Default duration in frames
        self.fx_ring_scale_pct: float = RING_SIZE_SCALE_DEFAULT
        self.manual_kick_frame: int | None = None
        self.manual_impact_frame: int | None = None
        # Pipeline progress flags
        self.is_ball_detected: bool = False
        self.is_yolo_tracked: bool = False
        self.is_sam_tracked: bool = False
        self.is_player_detected: bool = False
        self.is_player_propagated: bool = False
        # Multi-ball candidate tracking
        self.ball_candidates: list[dict] = []  # All detected ball candidates
        self.ball_candidates_tracking: dict[int, dict] = {}  # Per-candidate tracking data
        self.selected_ball_idx: int = 0  # Currently selected candidate index
        self.ball_selection_confirmed: bool = False  # True after user confirms selection
        self.multi_ball_status: str = ""  # Status message for multi-ball detection
        # Goal crossbar overlay
        self.goal_mode: str = GOAL_MODE_IDLE
        self.goal_points_norm: list[tuple[float, float]] = []
        self.goal_confirmed_points_norm: list[tuple[float, float]] = []
        self.goal_prev_confirmed_points_norm: list[tuple[float, float]] = []
        self.goal_overlay_points: list[tuple[float, float]] = []
        self.goal_status_text: str = "Goal crossbar inactive."
        self.goal_dragging_idx: int | None = None

    def __repr__(self):
        # Frames, masks and cached composites are summarised by count: dumping
        # their contents would print every image and array held by the session.
        return (
            "AppState("
            f"num_frames={len(self.video_frames)}, "
            f"inference_session={self.inference_session is not None}, "
            f"model={self.model is not None}, "
            f"processor={self.processor is not None}, "
            f"device={self.device}, "
            f"dtype={self.dtype}, "
            f"video_fps={self.video_fps}, "
            f"masks_by_frame=<{len(self.masks_by_frame)} frames>, "
            f"color_by_obj={self.color_by_obj}, "
            f"clicks_by_frame_obj=<{len(self.clicks_by_frame_obj)} frames>, "
            f"boxes_by_frame_obj=<{len(self.boxes_by_frame_obj)} frames>, "
            f"composited_frames=<{len(self.composited_frames)} cached>, "
            f"current_frame_idx={self.current_frame_idx}, "
            f"current_obj_id={self.current_obj_id}, "
            f"current_label={self.current_label}, "
            f"current_clear_old={self.current_clear_old}, "
            f"current_prompt_type={self.current_prompt_type}, "
            f"pending_box_start={self.pending_box_start}, "
            f"pending_box_start_frame_idx={self.pending_box_start_frame_idx}, "
            f"pending_box_start_obj_id={self.pending_box_start_obj_id}, "
            f"is_switching_model={self.is_switching_model}, "
            f"model_repo_key={self.model_repo_key}, "
            f"model_repo_id={self.model_repo_id}, "
            f"session_repo_id={self.session_repo_id})"
        )

    @property
    def num_frames(self) -> int:
        """Number of frames currently loaded."""
        return len(self.video_frames)
def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoProcessor, str, torch.dtype]:
    """Make sure the model and processor for the selected repo are loaded.

    Reuses the instances stored on ``GLOBAL_STATE`` when they already belong
    to the repo selected by ``GLOBAL_STATE.model_repo_key``; otherwise loads
    them with ``from_pretrained`` and stores them, together with the device,
    dtype and repo id, back on ``GLOBAL_STATE``.

    Returns:
        ``(model, processor, device, dtype)`` on both the cached and the
        freshly loaded path.
    """
    desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
    if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
        if GLOBAL_STATE.model_repo_id == desired_repo:
            return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
        # Different repo requested: drop the current references before loading
        # so the old weights can be collected.
        GLOBAL_STATE.model = None
        GLOBAL_STATE.processor = None

    print(f"Loading model from {desired_repo}")
    device, dtype = get_device_and_dtype()
    model = AutoModel.from_pretrained(desired_repo)
    processor = Sam2VideoProcessor.from_pretrained(desired_repo)
    model.to(device, dtype=dtype)

    GLOBAL_STATE.model = model
    GLOBAL_STATE.processor = processor
    GLOBAL_STATE.device = device
    GLOBAL_STATE.dtype = dtype
    GLOBAL_STATE.model_repo_id = desired_repo
    # The original fell off the end here and returned None despite the
    # declared tuple return type.
    return model, processor, device, dtype
def _reset_video_fields_for_new_session(state: AppState) -> None:
    """Clear every video-derived field on ``state``; model, FX and UI settings are kept."""
    state.video_frames = []
    state.inference_session = None
    state.masks_by_frame = {}
    state.color_by_obj = {}
    # Prompts and cached composites belong to the previous video and session;
    # ensure_session_for_current_model clears the same caches on re-init.
    state.clicks_by_frame_obj = {}
    state.boxes_by_frame_obj = {}
    state.composited_frames = {}
    state.pending_box_start = None
    state.pending_box_start_frame_idx = None
    state.pending_box_start_obj_id = None

    # Trajectory analysis
    state.ball_centers = {}
    state.mask_areas = {}
    state.smoothed_centers = {}
    state.ball_speeds = {}
    state.distance_from_start = {}
    state.direction_change = {}
    state.kick_frame = None
    state.kalman_centers = {}
    state.kalman_speeds = {}
    state.kalman_residuals = {}
    state.kick_debug_kalman_speeds = []
    state.kick_debug_frames = []
    state.kick_debug_speeds = []
    state.kick_debug_threshold = None
    state.kick_debug_baseline = None
    state.kick_debug_speed_std = None
    state.kick_debug_area = []
    state.kick_debug_kick_frame = None
    state.kick_debug_distance = []

    # Impact analysis
    state.impact_frame = None
    state.impact_debug_frames = []
    state.impact_debug_innovation = []
    state.impact_debug_innovation_threshold = None
    state.impact_debug_direction = []
    state.impact_debug_direction_threshold = None
    state.impact_debug_speed_kmh = []
    state.impact_debug_speed_threshold_px = None
    state.impact_meters_per_px = None

    # Goal crossbar overlay
    state.goal_mode = GOAL_MODE_IDLE
    state.goal_points_norm = []
    state.goal_confirmed_points_norm = []
    state.goal_prev_confirmed_points_norm = []
    state.goal_overlay_points = []
    state.goal_status_text = "Goal crossbar inactive."
    state.goal_dragging_idx = None

    # YOLO tracking caches
    state.yolo_ball_centers = {}
    state.yolo_ball_boxes = {}
    state.yolo_ball_conf = {}
    state.yolo_smoothed_centers = {}
    state.yolo_speeds = {}
    state.yolo_distance_from_start = {}
    state.yolo_threshold = None
    state.yolo_baseline_speed = None
    state.yolo_speed_std = None
    state.yolo_kick_frame = None
    state.yolo_status = ""
    state.yolo_kick_frames = []
    state.yolo_kick_speeds = []
    state.yolo_kick_distance = []
    state.yolo_mask_area_proxy = []
    state.yolo_initial_frame = None
    state.sam_window = None

    # Player detection
    state.player_obj_id = None
    state.player_detection_frame = None
    state.player_detection_conf = None

    # Multi-ball candidates were detected on the previous video's frames.
    state.ball_candidates = []
    state.ball_candidates_tracking = {}
    state.selected_ball_idx = 0
    state.ball_selection_confirmed = False
    state.multi_ball_status = ""

    # Pipeline progress flags
    state.is_ball_detected = False
    state.is_yolo_tracked = False
    state.is_sam_tracked = False
    state.is_player_detected = False
    state.is_player_propagated = False


def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
    """Gradio handler: load video, init session, return state, slider bounds, and first frame.

    Args:
        GLOBAL_STATE: The application state object; video-related fields are
            reset, the loaded model is kept.
        video: A file path, or a dict carrying the path under ``"name"``,
            ``"path"`` or ``"data"``.

    Returns:
        ``(state, 0, last_frame_index, first_frame, status_text)``.

    Raises:
        gr.Error: If no path can be extracted from ``video`` or no frame can
            be decoded.
    """
    # Reset ONLY video-related fields, keep model loaded
    _reset_video_fields_for_new_session(GLOBAL_STATE)
    load_model_if_needed(GLOBAL_STATE)

    # Gradio Video may provide a dict with 'name' or a direct file path
    if isinstance(video, dict):
        video_path = video.get("name") or video.get("path") or video.get("data")
    elif isinstance(video, str):
        video_path = video
    else:
        video_path = None
    if not video_path:
        raise gr.Error("Invalid video input.")

    frames, info = try_load_video_frames(video_path)
    if len(frames) == 0:
        raise gr.Error("No frames could be loaded from the video.")

    # Enforce max duration of 8 seconds (trim if longer). The loader reports
    # fps=None when the container has no usable frame rate; the duration is
    # then unknown, so the clip is left untrimmed instead of crashing on
    # arithmetic with None.
    MAX_SECONDS = 8.0
    trimmed_note = ""
    fps_in = info.get("fps")
    if fps_in:
        max_frames_allowed = int(MAX_SECONDS * fps_in)
        if len(frames) > max_frames_allowed:
            frames = frames[:max_frames_allowed]
            trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
    info["num_frames"] = len(frames)

    GLOBAL_STATE.video_frames = frames
    # Capture original FPS if provided by loader
    GLOBAL_STATE.video_fps = float(fps_in) if fps_in else None

    # Initialize session
    GLOBAL_STATE.inference_session = GLOBAL_STATE.processor.init_video_session(
        inference_device=GLOBAL_STATE.device,
        video_storage_device="cpu",
        dtype=GLOBAL_STATE.dtype,
    )

    first_frame = frames[0]
    max_idx = len(frames) - 1
    status = (
        f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
        f"Device: {GLOBAL_STATE.device}, dtype: bfloat16"
    )
    return GLOBAL_STATE, 0, max_idx, first_frame, status
-> float: x1, y1 = v1 x2, y2 = v2 mag1 = math.hypot(x1, y1) mag2 = math.hypot(x2, y2) if mag1 < 1e-6 or mag2 < 1e-6: return 0.0 cos_val = (x1 * x2 + y1 * y2) / (mag1 * mag2) cos_val = max(-1.0, min(1.0, cos_val)) return math.degrees(math.acos(cos_val)) DISPLAY_MIN_WIDTH = 640 DISPLAY_MAX_WIDTH = 1280 FX_GLOW_COLOR = np.array([1.0, 0.1, 0.6], dtype=np.float32) FX_EPS = 1e-6 GHOST_TRAIL_COLOR = np.array([1.0, 0.0, 1.0], dtype=np.float32) GHOST_TRAIL_ALPHA = 0.55 BALL_RING_ALPHA = 3.0 # Increased brightness BALL_RING_THICKNESS_PX = 1.0 # Thinner rings BALL_RING_FEATHER_SIGMA = 0.1 # Softer default blur BALL_RING_INTENSITY_GAMMA = 2.0 # Contrast shaping # Speed range palette (mirrors iOS app) SPEED_COLOR_STOPS = [ (30.0, (0 / 255.0, 191 / 255.0, 255 / 255.0)), # Electric Blue (50.0, (0 / 255.0, 191 / 255.0, 255 / 255.0)), # Electric Blue (same band) (70.0, (92 / 255.0, 124 / 255.0, 250 / 255.0)), # Blue Violet (90.0, (154 / 255.0, 77 / 255.0, 255 / 255.0)), # Intense Violet (110.0, (214 / 255.0, 51 / 255.0, 132 / 255.0)), # Fuchsia (130.0, (255 / 255.0, 77 / 255.0, 109 / 255.0)), # Strong Pink ] SPEED_COLOR_ABOVE_MAX = (255 / 255.0, 162 / 255.0, 0 / 255.0) # Neon Orange RING_RADIUS_CLAMP_RATIO = 0.2 # ±20% RING_SIZE_SCALE_DEFAULT = 125.0 # percent def _maybe_upscale_for_display(image: Image.Image) -> Image.Image: if image is None: return image original_width, original_height = image.size if original_width <= 0 or original_height <= 0: return image target_width = original_width if original_width < DISPLAY_MIN_WIDTH: target_width = DISPLAY_MIN_WIDTH elif original_width > DISPLAY_MAX_WIDTH: target_width = DISPLAY_MAX_WIDTH if target_width == original_width: return image scale = target_width / float(original_width) target_height = int(round(original_height * scale)) return image.resize((target_width, target_height), Image.BILINEAR) def _annotate_frame_index(image: Image.Image, frame_idx: int) -> Image.Image: if image is None: return image annotated = image.copy() draw = 
def _apply_cutout_fx(state: "AppState", frame_np: np.ndarray, combined_mask: np.ndarray) -> np.ndarray:
    """Composite the masked foreground over a blurred, darkened copy of the frame.

    The mask is clipped to [0, 1], optionally eroded and feathered (soft
    matte), and used as alpha between ``frame_np`` and the processed
    background. Optional light-wrap and glow passes are then added. When the
    mask is empty, only the processed background is returned.

    Args:
        state: Source of the ``fx_*`` settings.
        frame_np: Frame as floats; presumably H x W x 3 in [0, 1], since the
            result is scaled by 255 - confirm against callers.
        combined_mask: Foreground mask, presumably H x W - confirm.

    Returns:
        The composited frame as a ``uint8`` array.
    """
    def _gauss(src: np.ndarray, sigma: float) -> np.ndarray:
        return cv2.GaussianBlur(src, (0, 0), sigmaX=sigma, sigmaY=sigma)

    def _blurred_plate() -> np.ndarray:
        # Copy of the frame, blurred only when background blur is switched on.
        plate_img = frame_np.copy()
        if state.fx_blur_enabled and state.fx_blur_sigma > FX_EPS:
            plate_img = _gauss(plate_img, state.fx_blur_sigma)
        return plate_img

    def _as_uint8(image: np.ndarray) -> np.ndarray:
        return np.clip(image * 255.0, 0, 255).astype(np.uint8)

    hard_mask = np.clip(combined_mask.astype(np.float32), 0.0, 1.0)
    if hard_mask.max() <= FX_EPS:
        # No foreground detected; fall back to darkened background choice
        dimmed_only = _blurred_plate() * (1.0 - np.clip(state.fx_bg_darkening, 0.0, 1.0))
        return _as_uint8(dimmed_only)

    # --- Soft matte ---------------------------------------------------------
    # NOTE(review): nesting reconstructed from whitespace-collapsed source; the
    # 1.05 boost is assumed to belong to the soft-matte branch - confirm.
    matte = hard_mask.copy()
    if state.fx_soft_matte_enabled:
        shrink_px = max(0.0, float(state.fx_soft_matte_erode))
        if shrink_px > FX_EPS:
            side = int(round(shrink_px * 2 + 1))
            element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
            matte = cv2.erode(matte, element)
        feather_sigma = max(0.0, float(state.fx_soft_matte_feather))
        if feather_sigma > FX_EPS:
            matte = _gauss(matte, feather_sigma)
        matte = np.clip(matte * 1.05, 0.0, 1.0)

    # --- Background and base composite --------------------------------------
    plate = _blurred_plate()
    dim_amount = np.clip(state.fx_bg_darkening, 0.0, 1.0)
    dark_plate = plate * (1.0 - dim_amount)
    coverage = matte[..., None]
    composite = frame_np * coverage + dark_plate * (1.0 - coverage)

    # --- Light wrap: background light added along the matte's inner edge ----
    wrap_gain = float(state.fx_light_wrap_strength)
    wrap_sigma = max(0.0, float(state.fx_light_wrap_width))
    if state.fx_light_wrap_enabled and wrap_gain > FX_EPS and wrap_sigma > FX_EPS:
        rim = np.clip(matte - _gauss(matte, wrap_sigma), 0.0, 1.0)
        if rim.max() > FX_EPS:
            rim /= (rim.max() + FX_EPS)
        spill = _gauss(plate, wrap_sigma * 1.5)
        composite = np.clip(composite + rim[..., None] * spill * wrap_gain, 0.0, 1.0)

    # --- Glow: coloured halo just outside the matte -------------------------
    halo_gain = float(state.fx_glow_strength)
    halo_sigma = max(0.0, float(state.fx_glow_radius))
    if state.fx_glow_enabled and halo_gain > FX_EPS and halo_sigma > FX_EPS:
        halo = np.clip(_gauss(matte, halo_sigma) - matte, 0.0, 1.0)
        if halo.max() > FX_EPS:
            halo /= (halo.max() + FX_EPS)
        tint = FX_GLOW_COLOR[None, None, :]
        composite = np.clip(composite + halo[..., None] * tint * halo_gain, 0.0, 1.0)

    return _as_uint8(composite)
dtype=np.float32) current_union_mask = np.maximum(current_union_mask, mask_np) if obj_id in (BALL_OBJECT_ID, PLAYER_OBJECT_ID): if focus_mask is None: focus_mask = np.zeros_like(mask_np, dtype=np.float32) focus_mask = np.maximum(focus_mask, mask_np) if obj_id == BALL_OBJECT_ID: ball_mask_main = mask ghost_mask = _build_ball_trail_mask(state, frame_idx) ring_result = _build_ball_ring_mask(state, frame_idx) if len(masks) != 0: if remove_bg: # Remove background - show only tracked objects frame_np = np.array(frame).astype(np.float32) / 255.0 combined_mask = current_union_mask if combined_mask is None: combined_mask = np.zeros((frame_np.shape[0], frame_np.shape[1]), dtype=np.float32) # Apply falloff to ball component when rendering foreground if BALL_OBJECT_ID in masks: ball_mask = masks[BALL_OBJECT_ID] if ball_mask is not None: combined_mask = np.maximum( combined_mask, _apply_radial_falloff(np.clip(ball_mask.astype(np.float32), 0.0, 1.0), strength=1.0, solid_ratio=0.8), ) result_np = _apply_cutout_fx(state, frame_np, combined_mask) out_img = Image.fromarray(result_np) else: if masks: out_img = overlay_masks_on_frame(out_img, masks, state.color_by_obj, alpha=0.65) # Overlay feathered ball on top if BALL_OBJECT_ID in masks: ball_mask = masks[BALL_OBJECT_ID] if ball_mask is not None: ball_alpha = _apply_radial_falloff(ball_mask, strength=1.0, solid_ratio=0.8) if ball_alpha is not None and ball_alpha.max() > FX_EPS: base_np = np.array(out_img).astype(np.float32) / 255.0 color = np.array(state.color_by_obj.get(BALL_OBJECT_ID, (255, 255, 0)), dtype=np.float32) / 255.0 alpha = np.clip(ball_alpha[..., None], 0.0, 1.0) base_np = (1.0 - alpha) * base_np + alpha * color out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8)) if ghost_mask is not None: ghost_np = np.clip(ghost_mask.astype(np.float32), 0.0, 1.0) if current_union_mask is not None: ghost_np = ghost_np * np.clip(1.0 - current_union_mask, 0.0, 1.0) if ghost_np.max() > FX_EPS: base_np = 
np.array(out_img).astype(np.float32) / 255.0 ghost_alpha = ghost_np[..., None] base_np = (1.0 - GHOST_TRAIL_ALPHA * ghost_alpha) * base_np + ( GHOST_TRAIL_ALPHA * ghost_alpha ) * GHOST_TRAIL_COLOR if focus_mask is not None: focus_alpha = np.clip(focus_mask, 0.0, 1.0)[..., None] orig_np = np.array(frame).astype(np.float32) / 255.0 base_np = focus_alpha * orig_np + (1.0 - focus_alpha) * base_np out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8)) if ring_result is not None: ring_presence, ring_color_map = ring_result ring_presence = np.clip(ring_presence.astype(np.float32), 0.0, 1.0) ring_color_map = np.clip(ring_color_map.astype(np.float32), 0.0, 1.0) if current_union_mask is not None: if ball_mask_main is not None: ball_np = np.clip(ball_mask_main.astype(np.float32), 0.0, 1.0) mask_block = np.maximum(current_union_mask - ball_np, 0.0) else: mask_block = current_union_mask mask_keep = np.clip(1.0 - mask_block, 0.0, 1.0) ring_presence = ring_presence * mask_keep ring_color_map = ring_color_map * mask_keep[..., None] if ring_presence.max() > FX_EPS and ring_color_map.max() > FX_EPS: base_np = np.array(out_img).astype(np.float32) / 255.0 alpha_val = getattr(state, "fx_ring_alpha", BALL_RING_ALPHA) added_light = np.clip(ring_color_map * alpha_val, 0.0, 1.0) base_np = np.clip(base_np + added_light, 0.0, 1.0) if focus_mask is not None: focus_alpha = np.clip(focus_mask, 0.0, 1.0)[..., None] orig_np = np.array(frame).astype(np.float32) / 255.0 base_np = focus_alpha * orig_np + (1.0 - focus_alpha) * base_np out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8)) _draw_goal_overlay(state, frame_idx, out_img) # Draw crosses for conditioning frames only (frames with recorded clicks) clicks_map = state.clicks_by_frame_obj.get(frame_idx) if state.show_click_marks and clicks_map: draw = ImageDraw.Draw(out_img) cross_half = 6 for obj_id, pts in clicks_map.items(): for x, y, lbl in pts: color = (0, 255, 0) if int(lbl) == 1 else (255, 0, 
0) # horizontal draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2) # vertical draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2) # Draw temporary cross for first corner in box mode if ( state.show_click_marks and state.pending_box_start is not None and state.pending_box_start_frame_idx == frame_idx and state.pending_box_start_obj_id is not None ): draw = ImageDraw.Draw(out_img) x, y = state.pending_box_start cross_half = 6 color = state.color_by_obj.get(state.pending_box_start_obj_id, (255, 255, 255)) draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2) draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2) # Draw boxes for conditioning frames box_map = state.boxes_by_frame_obj.get(frame_idx) if state.show_click_marks and box_map: draw = ImageDraw.Draw(out_img) for obj_id, boxes in box_map.items(): color = state.color_by_obj.get(obj_id, (255, 255, 255)) for x1, y1, x2, y2 in boxes: draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2) # Draw trajectory centers (all frames) if state.show_click_marks and state.ball_centers: draw = ImageDraw.Draw(out_img) cross_half = 4 for obj_id, centers in state.ball_centers.items(): if not centers: continue raw_items = sorted(centers.items()) for _, (rx, ry) in raw_items: draw.line([(rx - cross_half, ry), (rx + cross_half, ry)], fill=(160, 160, 160), width=1) draw.line([(rx, ry - cross_half), (rx, ry + cross_half)], fill=(160, 160, 160), width=1) smooth_dict = state.smoothed_centers.get(obj_id, {}) if not smooth_dict: continue smooth_items = sorted(smooth_dict.items()) distances: list[float] = [] prev_center = None for _, (sx, sy) in smooth_items: if prev_center is None: distances.append(0.0) else: dx = sx - prev_center[0] dy = sy - prev_center[1] distances.append(float(np.hypot(dx, dy))) prev_center = (sx, sy) max_dist = max(distances[1:], default=0.0) color_by_frame: dict[int, tuple[int, int, int]] = {} for (f_idx, _), dist in 
def _update_fx_controls(
    state: AppState,
    soft_enabled: bool,
    soft_feather: float,
    soft_erode: float,
    blur_enabled: bool,
    blur_sigma: float,
    bg_darkening: float,
    wrap_enabled: bool,
    wrap_strength: float,
    wrap_width: float,
    glow_enabled: bool,
    glow_strength: float,
    glow_radius: float,
    # Ring FX parameters
    ring_thickness: float,
    ring_alpha: float,
    ring_feather: float,
    ring_gamma: float,
    ring_scale_pct: float,
    ring_duration: float,
) -> Image.Image:
    """Store sanitized FX settings on ``state`` and re-render the current frame.

    Every numeric control is coerced with ``float`` and bounded before being
    written to the matching ``state.fx_*`` attribute:

    * feather / erode / sigma / strength / width / radius / ring alpha /
      ring feather: lower bound 0.0
    * ring thickness and ring gamma: lower bound 0.1
    * background darkening: clipped to [0.0, 1.0]
    * ring scale (percent): clipped to [10.0, 200.0]
    * ring duration: lower bound 0, truncated to ``int``

    The composited-frame cache is cleared afterwards because cached frames
    were rendered with the previous settings.

    Returns:
        The result of ``update_frame_display`` for ``state.current_frame_idx``
        (0 when the attribute is missing), or ``None`` when ``state`` is None.
    """
    if state is None:
        return None

    def at_least(floor, raw):
        # Same expression the controls always used: max(<floor>, float(<raw>)).
        return max(floor, float(raw))

    def clamped(raw, low, high):
        return float(np.clip(raw, low, high))

    # Soft matte
    state.fx_soft_matte_enabled = bool(soft_enabled)
    state.fx_soft_matte_feather = at_least(0.0, soft_feather)
    state.fx_soft_matte_erode = at_least(0.0, soft_erode)

    # Background blur / darkening
    state.fx_blur_enabled = bool(blur_enabled)
    state.fx_blur_sigma = at_least(0.0, blur_sigma)
    state.fx_bg_darkening = clamped(bg_darkening, 0.0, 1.0)

    # Light wrap
    state.fx_light_wrap_enabled = bool(wrap_enabled)
    state.fx_light_wrap_strength = at_least(0.0, wrap_strength)
    state.fx_light_wrap_width = at_least(0.0, wrap_width)

    # Glow
    state.fx_glow_enabled = bool(glow_enabled)
    state.fx_glow_strength = at_least(0.0, glow_strength)
    state.fx_glow_radius = at_least(0.0, glow_radius)

    # Ring FX
    state.fx_ring_thickness = at_least(0.1, ring_thickness)
    state.fx_ring_alpha = at_least(0.0, ring_alpha)
    state.fx_ring_feather = at_least(0.0, ring_feather)
    state.fx_ring_gamma = at_least(0.1, ring_gamma)
    state.fx_ring_duration = int(at_least(0, ring_duration))
    state.fx_ring_scale_pct = clamped(ring_scale_pct, 10.0, 200.0)

    # Anything already composited is stale now.
    state.composited_frames.clear()
    return update_frame_display(state, int(getattr(state, "current_frame_idx", 0)))
def _build_ball_ring_mask(
    state: AppState, frame_idx: int
) -> tuple[np.ndarray, np.ndarray] | None:
    """Render the post-kick ball "ring" trail that applies to ``frame_idx``.

    A ring is drawn around the ball mask of every frame from
    ``max(kick_frame + 1, frame_idx)`` up to (but excluding)
    ``min(state.num_frames, kick_frame + 1 + fx_ring_duration)``.  Each ring
    is centred on the mask centroid, sized from the mask's bounding box
    (median-smoothed and clamped by the module helpers), and coloured by
    ``_speed_to_ring_color`` using the speed implied by
    ``state.goal_distance_m`` and the time elapsed since the kick.

    Args:
        state: Application state; needs ``fx_ball_ring_enabled``,
            ``masks_by_frame``, ``num_frames``, ``video_fps`` and
            ``goal_distance_m``.  The ``fx_ring_*`` attributes are optional
            and fall back to the module-level defaults.
        frame_idx: Frame currently being composited.

    Returns:
        ``(ring_presence, ring_color_map)``: an ``(h, w)`` intensity map and
        an ``(h, w, 3)`` colour map, both clipped to ``[0, 1]``.  ``None``
        when rings are disabled, no kick frame is available, the frame
        window is empty, or nothing visible was drawn.
    """
    if (
        state is None
        or not state.fx_ball_ring_enabled
        or state.masks_by_frame is None
    ):
        return None

    kick_candidate = _get_prioritized_kick_frame(state)
    if kick_candidate is None:
        return None
    kick_idx = int(kick_candidate)

    # Rings only exist after the kick; never start before the current frame.
    start_idx = max(kick_idx + 1, int(frame_idx))
    # Limit how long after the kick rings are drawn.
    duration = getattr(state, "fx_ring_duration", 16)
    end_idx = min(state.num_frames, kick_idx + 1 + duration)
    if start_idx >= end_idx:
        return None

    fps = state.video_fps if state.video_fps and state.video_fps > 0 else 25.0
    distance_m = (
        state.goal_distance_m
        if state.goal_distance_m and state.goal_distance_m > 0
        else 16.5
    )
    # Use dynamic parameters from state if available, else defaults.
    thick_val = getattr(state, "fx_ring_thickness", BALL_RING_THICKNESS_PX)

    # (center, raw radius, colour) per usable frame.
    ring_entries: list[tuple[tuple[int, int], float, np.ndarray]] = []
    canvas_shape: tuple[int, int] | None = None

    # Iterate in REVERSE order: entries for later frames come first, so the
    # first entry (used below as the reference radius) is the latest frame.
    for idx in range(end_idx - 1, start_idx - 1, -1):
        frame_masks = state.masks_by_frame.get(idx)
        if not frame_masks:
            continue
        mask = frame_masks.get(BALL_OBJECT_ID)
        if mask is None:
            continue
        mask_np = mask.astype(np.float32)
        if mask_np.ndim == 3:
            mask_np = mask_np.squeeze()
        # The canvas is unpacked as (h, w) below, so only 2-D masks are usable.
        if mask_np.ndim != 2 or mask_np.size == 0:
            continue
        mask_np = np.clip(mask_np, 0.0, 1.0)
        if mask_np.max() <= FX_EPS:
            continue
        if canvas_shape is None:
            canvas_shape = mask_np.shape
        if canvas_shape != mask_np.shape:
            continue

        centroid = _compute_mask_centroid(mask_np)
        if centroid is None:
            continue
        cx, cy = centroid

        ys, xs = np.nonzero(mask_np > 0.05)
        if xs.size == 0 or ys.size == 0:
            continue
        radius_x = (xs.max() - xs.min() + 1) / 2.0
        radius_y = (ys.max() - ys.min() + 1) / 2.0
        radius = float(max(radius_x, radius_y))
        if radius <= 1.5:
            # Too small to draw a meaningful ring.
            continue

        center = (int(round(cx)), int(round(cy)))
        delta_frames = max(1, idx - kick_idx)
        time_s = max(delta_frames / fps, 1.0 / fps)
        speed_kmh = max(0.0, (distance_m / time_s) * 3.6)
        color_vec = np.array(_speed_to_ring_color(speed_kmh), dtype=np.float32)
        ring_entries.append((center, radius, color_vec))

    if not ring_entries or canvas_shape is None:
        return None

    smoothed = _clamp_radii(_median_smooth_radii([entry[1] for entry in ring_entries]))
    base_radius = smoothed[0] if smoothed else 1.0
    if base_radius <= FX_EPS:
        base_radius = 1.0

    h, w = canvas_shape
    ring_presence = np.zeros((h, w), dtype=np.float32)
    ring_color_map = np.zeros((h, w, 3), dtype=np.float32)

    base_feather = getattr(state, "fx_ring_feather", BALL_RING_FEATHER_SIGMA)
    base_gamma = getattr(state, "fx_ring_gamma", BALL_RING_INTENSITY_GAMMA)
    scale_factor = float(getattr(state, "fx_ring_scale_pct", RING_SIZE_SCALE_DEFAULT)) / 100.0
    scale_factor = np.clip(scale_factor, 0.1, 2.0)

    for (center, _, color_vec), smooth_radius in zip(ring_entries, smoothed):
        radius_ratio = smooth_radius / base_radius if base_radius > FX_EPS else 1.0
        radius_ratio = float(np.clip(radius_ratio, 0.05, 1.0))
        radius_int = max(1, int(round(max(1.0, smooth_radius * scale_factor))))

        # Three concentric strokes (wide/faint -> narrow/bright) build the glow.
        ring_local = np.zeros((h, w), dtype=np.float32)
        thickness_scale = max(0.1, radius_ratio)
        for intensity, thickness_mult in ((0.3, 4.0), (0.6, 2.0), (1.0, 1.0)):
            stroke = max(1, int(round(thick_val * thickness_mult * thickness_scale)))
            cv2.circle(ring_local, center, radius_int, intensity, thickness=stroke)

        effective_feather = max(0.0, base_feather * radius_ratio)
        # With ksize=(0, 0) OpenCV derives the kernel size from sigma and
        # rejects sigma == 0, so a zero feather must skip the blur entirely.
        if effective_feather > 0.0:
            ring_local = cv2.GaussianBlur(
                ring_local, (0, 0), sigmaX=effective_feather, sigmaY=effective_feather
            )
        if ring_local.max() <= FX_EPS:
            continue

        effective_gamma = max(0.1, base_gamma * radius_ratio)
        if abs(effective_gamma - 1.0) > 1e-6:
            ring_local = np.power(np.clip(ring_local, 0.0, 1.0), effective_gamma)
        ring_local = np.clip(ring_local * radius_ratio, 0.0, 1.0)

        ring_presence = np.maximum(ring_presence, ring_local)
        ring_color_map += ring_local[..., None] * color_vec

    if ring_presence.max() <= FX_EPS or ring_color_map.max() <= FX_EPS:
        return None
    return np.clip(ring_presence, 0.0, 1.0), np.clip(ring_color_map, 0.0, 1.0)
def _apply_radial_falloff(mask: np.ndarray, strength: float = 1.0, solid_ratio: float = 0.8) -> np.ndarray:
    """Fade a mask towards its outer edge while keeping its core solid.

    Distances are measured from the mask centroid and normalised by the
    distance of the farthest mask pixel.  Pixels closer than ``solid_ratio``
    of that reach keep their weight; beyond it the weight is multiplied by
    ``1 - t ** strength`` where ``t`` ramps linearly from 0 (at
    ``solid_ratio``) to 1 (at the farthest mask pixel).

    Args:
        mask: Mask array; values are clipped to ``[0, 1]`` as float32.  A
            3-D input is squeezed.
        strength: Exponent applied to the ramp.
        solid_ratio: Fraction of the reach left untouched; ``>= 1.0``
            disables the falloff.

    Returns:
        The attenuated mask clipped to ``[0, 1]``.  The clipped input is
        returned unchanged when it is empty, has no centroid, has no reach,
        or ``solid_ratio >= 1.0``.  ``None`` is returned for a ``None`` mask.
    """
    if mask is None:
        return None

    weights = np.clip(mask.astype(np.float32), 0.0, 1.0)
    if weights.ndim == 3:
        weights = weights.squeeze()
    if weights.max() <= FX_EPS:
        return weights

    center = _compute_mask_centroid(weights)
    if center is None:
        return weights
    center_x, center_y = center

    rows, cols = weights.shape
    row_idx, col_idx = np.ogrid[:rows, :cols]
    distance = np.sqrt((col_idx - center_x) ** 2 + (row_idx - center_y) ** 2)

    # Reach = distance from the centroid to the farthest occupied pixel.
    occupied = weights > FX_EPS
    reach = distance[occupied].max() if np.any(occupied) else 0.0
    if reach <= FX_EPS or solid_ratio >= 1.0:
        return weights

    ramp = np.clip((distance / reach - solid_ratio) / (1.0 - solid_ratio), 0.0, 1.0)
    attenuation = 1.0 - np.power(ramp, strength)
    return np.clip(weights * attenuation, 0.0, 1.0)
def _run_kalman_filter(
    ordered_items: list[tuple[int, tuple[float, float]]],
    base_dt: float,
) -> tuple[dict[int, tuple[float, float]], dict[int, float], dict[int, float]]:
    """Run a constant-velocity Kalman filter over 2-D centre measurements.

    The state vector is ``[x, y, vx, vy]``; only the position is observed.
    The filter starts at the first measurement with zero velocity and runs a
    predict + update step for every item (including the first one).  Gaps in
    the frame index stretch the prediction step: ``dt = frame_gap * base_dt``
    with a minimum gap of one frame.

    Args:
        ordered_items: ``(frame_idx, (cx, cy))`` pairs, iterated in the
            given order.
        base_dt: Time between two consecutive frames.

    Returns:
        Three dicts keyed by frame index: filtered ``(x, y)`` position,
        speed (norm of the filtered velocity) and innovation magnitude
        (distance between the measurement and the predicted position).
        All three are empty when ``ordered_items`` is empty.
    """
    positions: dict[int, tuple[float, float]] = {}
    speeds: dict[int, float] = {}
    residuals: dict[int, float] = {}
    if not ordered_items:
        return positions, speeds, residuals

    observe = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=float)
    measurement_noise = np.eye(2, dtype=float) * 25.0
    identity = np.eye(4)

    first_frame, first_center = ordered_items[0]
    estimate = np.array([first_center[0], first_center[1], 0.0, 0.0], dtype=float)
    covariance = np.eye(4, dtype=float) * 50.0

    last_frame = first_frame
    for frame_idx, (cx, cy) in ordered_items:
        # A repeated or first frame still advances by one nominal step.
        dt = max(1, frame_idx - last_frame) * base_dt

        transition = np.array(
            [
                [1, 0, dt, 0],
                [0, 1, 0, dt],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ],
            dtype=float,
        )
        q = 0.5 * dt**2
        # NOTE(review): each position/velocity block is [[q, dt], [dt, 1]] with
        # q = dt**2 / 2, whose determinant is negative, so this matrix is not
        # positive semi-definite.  Kept as-is because downstream thresholds
        # are presumably tuned against it; confirm before changing.
        process_noise = np.array(
            [
                [q, 0, dt, 0],
                [0, q, 0, dt],
                [dt, 0, 1, 0],
                [0, dt, 0, 1],
            ],
            dtype=float,
        ) * 0.05

        # Predict.
        estimate = transition @ estimate
        covariance = transition @ covariance @ transition.T + process_noise

        # Update.
        measurement = np.array([cx, cy], dtype=float)
        innovation = measurement - observe @ estimate
        innovation_cov = observe @ covariance @ observe.T + measurement_noise
        gain = covariance @ observe.T @ np.linalg.inv(innovation_cov)
        estimate = estimate + gain @ innovation
        covariance = (identity - gain @ observe) @ covariance

        positions[frame_idx] = (estimate[0], estimate[1])
        speeds[frame_idx] = float(math.hypot(estimate[2], estimate[3]))
        residuals[frame_idx] = float(math.hypot(innovation[0], innovation[1]))
        last_frame = frame_idx

    return positions, speeds, residuals
state.kick_debug_baseline or 0.0 kick_frame = state.kick_debug_kick_frame distance = state.kick_debug_distance if state.kick_debug_distance else [0.0] * len(frames) impact_frames = state.impact_debug_frames if state.impact_debug_frames else frames fig.add_trace( go.Scatter( x=frames, y=speeds, mode="lines+markers", name="Speed (px/s)", line=dict(color="#1f77b4"), ) ) fig.add_trace( go.Scatter( x=[frames[0], frames[-1]], y=[threshold, threshold], mode="lines", name="Adaptive threshold", line=dict(color="#d62728", dash="dash"), ) ) fig.add_trace( go.Scatter( x=[frames[0], frames[-1]], y=[baseline, baseline], mode="lines", name="Baseline speed", line=dict(color="#ff7f0e", dash="dot"), ) ) fig.add_trace( go.Scatter( x=frames, y=areas, mode="lines", name="Mask area", line=dict(color="#2ca02c"), yaxis="y2", ) ) max_primary = max( max(speeds) if speeds else 0.0, threshold, baseline, max(state.kick_debug_kalman_speeds) if state.kick_debug_kalman_speeds else 0.0, state.impact_debug_innovation_threshold or 0.0, state.impact_debug_direction_threshold or 0.0, state.impact_debug_speed_threshold_px or 0.0, 1.0, ) max_distance = max(distance) if distance else 0.0 if max_distance > 0 and max_primary > 0: distance_scaled = [d * (max_primary / max_distance) for d in distance] else: distance_scaled = distance fig.add_trace( go.Scatter( x=frames, y=distance_scaled, mode="lines", name="Distance from start (scaled)", line=dict(color="#9467bd"), ) ) if state.kick_debug_kalman_speeds: fig.add_trace( go.Scatter( x=frames, y=state.kick_debug_kalman_speeds, mode="lines", name="Kalman speed", line=dict(color="#8c564b"), ) ) if state.impact_debug_innovation: fig.add_trace( go.Scatter( x=impact_frames, y=state.impact_debug_innovation, mode="lines", name="Kalman innovation", line=dict(color="#17becf"), ) ) max_primary = max(max_primary, max(state.impact_debug_innovation)) if ( state.impact_debug_innovation_threshold is not None and impact_frames and len(impact_frames) >= 2 ): fig.add_trace( 
go.Scatter( x=[impact_frames[0], impact_frames[-1]], y=[ state.impact_debug_innovation_threshold, state.impact_debug_innovation_threshold, ], mode="lines", name="Innovation threshold", line=dict(color="#17becf", dash="dash"), ) ) max_primary = max(max_primary, state.impact_debug_innovation_threshold or 0.0) if state.impact_debug_direction: fig.add_trace( go.Scatter( x=impact_frames, y=state.impact_debug_direction, mode="lines", name="Direction change (deg)", line=dict(color="#bcbd22"), ) ) max_primary = max(max_primary, max(state.impact_debug_direction)) if ( state.impact_debug_direction_threshold is not None and impact_frames and len(impact_frames) >= 2 ): fig.add_trace( go.Scatter( x=[impact_frames[0], impact_frames[-1]], y=[ state.impact_debug_direction_threshold, state.impact_debug_direction_threshold, ], mode="lines", name="Direction threshold (deg)", line=dict(color="#bcbd22", dash="dot"), ) ) max_primary = max(max_primary, state.impact_debug_direction_threshold or 0.0) if state.impact_debug_speed_threshold_px: fig.add_trace( go.Scatter( x=[frames[0], frames[-1]], y=[state.impact_debug_speed_threshold_px] * 2, mode="lines", name="Min impact speed (px/s)", line=dict(color="#b82e2e", dash="dot"), ) ) max_primary = max(max_primary, state.impact_debug_speed_threshold_px or 0.0) if kick_frame is not None: fig.add_trace( go.Scatter( x=[kick_frame, kick_frame], y=[0, max_primary * 1.05], mode="lines", name="Detected kick", line=dict(color="#ff00ff", dash="solid", width=3), ) ) impact_frame = state.impact_frame if impact_frame is not None: fig.add_trace( go.Scatter( x=[impact_frame, impact_frame], y=[0, max_primary * 1.05], mode="lines", name="Detected impact", line=dict(color="#ff1493", width=3), ) ) fig.update_layout( title="Kick & impact diagnostics", xaxis_title="Frame", yaxis_title="Speed (px/s)", yaxis=dict(side="left"), yaxis2=dict( title="Mask area / Distance / Direction", overlaying="y", side="right", showgrid=False, ), legend=dict(orientation="h"), 
def _ensure_ball_prompt_from_yolo(state: AppState):
    """Seed a positive ball click from the YOLO detections when none exists.

    Does nothing unless an inference session and YOLO ball centres are
    available, and stops as soon as any frame already holds clicks for
    ``BALL_OBJECT_ID``.  Otherwise the anchor frame is
    ``state.yolo_initial_frame`` (falling back to the earliest frame with a
    YOLO centre); its centre is rounded, clamped to the frame bounds and
    passed to ``on_image_click`` as a synthetic click event.

    Side effects: sets ``state.current_obj_id``, ``state.current_label`` and
    ``state.current_frame_idx`` before dispatching the click.
    """
    if state is None or state.inference_session is None:
        return
    if not state.yolo_ball_centers:
        return

    # Respect any ball click that is already recorded.
    if any(
        per_frame.get(BALL_OBJECT_ID)
        for per_frame in state.clicks_by_frame_obj.values()
    ):
        return

    anchor_frame = state.yolo_initial_frame
    if anchor_frame is None and state.yolo_ball_centers:
        anchor_frame = min(state.yolo_ball_centers.keys())
    if anchor_frame is None or anchor_frame >= state.num_frames:
        return

    center = state.yolo_ball_centers.get(anchor_frame)
    if center is None:
        return

    raw_x, raw_y = center
    width, height = state.video_frames[anchor_frame].size
    click_x = int(np.clip(round(raw_x), 0, width - 1))
    click_y = int(np.clip(round(raw_y), 0, height - 1))

    # Minimal stand-in for the click event object consumed by on_image_click.
    synthetic_event = SimpleNamespace(
        index=(click_x, click_y),
        value={"x": click_x, "y": click_y},
    )

    state.current_obj_id = BALL_OBJECT_ID
    state.current_label = "positive"
    state.current_frame_idx = anchor_frame

    rendered = update_frame_display(state, anchor_frame)
    on_image_click(
        rendered,
        state,
        anchor_frame,
        BALL_OBJECT_ID,
        "positive",
        False,
        synthetic_event,
    )
def _build_multi_ball_chart(state: AppState):
    """
    Build a combined speed chart showing all ball candidates.

    One speed trace is added per candidate that has tracking data:
    the selected ball is drawn in green, other candidates flagged with
    ``has_kick`` in orange, and the rest in gray.  Each candidate's
    ``kick_frame`` (when present) is marked with a vertical line; only the
    selected ball's line is annotated.

    Returns a Plotly figure; when there is no state or no candidates the
    figure contains only the titles.
    """
    fig = go.Figure()
    if state is None or not state.ball_candidates:
        fig.update_layout(
            title="Ball Candidates Speed Comparison",
            xaxis_title="Frame",
            yaxis_title="Speed (px/s)",
        )
        return fig

    selected_idx = state.selected_ball_idx
    max_speed = 0.0
    kick_frames_to_mark = []

    for i, candidate in enumerate(state.ball_candidates):
        # "tracking" may be present but None; treat that like missing data.
        tracking = candidate.get("tracking") or {}
        frames = tracking.get("frames_ordered", [])
        speeds = tracking.get("speed_series", [])
        if not frames or not speeds:
            continue

        max_speed = max(max_speed, max(speeds))
        is_selected = (i == selected_idx)
        is_kicked = candidate.get("has_kick", False)

        # Determine color and style
        if is_selected:
            color, width, opacity = "#4caf50", 3, 1.0  # Green
        elif is_kicked:
            color, width, opacity = "#ff9800", 2, 0.7  # Orange for other kicked balls
        else:
            color, width, opacity = "#9e9e9e", 1, 0.5  # Gray

        # Build label
        label_parts = [f"Ball {i+1}"]
        if is_kicked:
            label_parts.append("⚽")
        if is_selected:
            label_parts.append("✓")
        label = " ".join(label_parts)

        fig.add_trace(
            go.Scatter(
                x=frames,
                y=speeds,
                mode="lines",
                name=label,
                line=dict(color=color, width=width),
                opacity=opacity,
            )
        )

        # Mark kick frame
        kick_frame = candidate.get("kick_frame")
        if kick_frame is not None:
            kick_frames_to_mark.append((kick_frame, i, is_selected))

    # Add vertical lines for kick frames
    for kick_frame, ball_idx, is_selected in kick_frames_to_mark:
        if is_selected:
            line_style = dict(color="#e91e63", width=3, dash="solid")
            # Only the selected ball gets a text annotation; unselected lines
            # carry no annotation instead of an empty-text one.
            annotation_kwargs = dict(
                annotation_text=f"Ball {ball_idx+1} kick",
                annotation_position="top right",
            )
        else:
            line_style = dict(color="#ffcdd2", width=1, dash="dot")
            annotation_kwargs = {}
        fig.add_vline(x=kick_frame, line=line_style, **annotation_kwargs)

    fig.update_layout(
        title="Ball Candidates Speed Comparison",
        xaxis=dict(title="Frame"),
        yaxis=dict(
            title="Speed (px/s)",
            range=[0, max_speed * 1.1] if max_speed > 0 else None,
        ),
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
        margin=dict(t=60, l=40, r=40, b=40),
        hovermode="x unified",
    )
    return fig
def _format_impact_text(state: AppState) -> str:
    """Return a short label for the SAM impact frame.

    Uses ``state.impact_frame`` when set; otherwise falls back to the last
    entry of ``state.impact_debug_frames``.  Returns ``"Impact: n/a"`` when
    there is no state or neither source yields a frame.
    """
    not_available = "Impact: n/a"
    if state is None:
        return not_available

    impact = getattr(state, "impact_frame", None)
    if impact is None:
        debug_frames = getattr(state, "impact_debug_frames", [])
        impact = debug_frames[-1] if debug_frames else None

    if impact is None:
        return not_available
    return f"SAM 🚩 {impact}"
def _mark_impact_frame(state: AppState, frame_value: float):
    """Manually set the impact frame and return the refreshed UI updates.

    ``frame_value`` is truncated to an int and clamped to the valid frame
    range, then stored as both ``state.impact_frame`` and
    ``state.manual_impact_frame``.  When no video is loaded nothing is
    stored and a "Load a video first." message is returned instead.

    Returns:
        A tuple of: frame image (or a no-op update), status message update,
        frame slider update, kick plot, the three button updates from
        ``_button_updates`` and all updates from ``_ui_status_updates``.
    """
    video_loaded = not (state is None or state.num_frames == 0)

    if video_loaded:
        idx = int(np.clip(int(frame_value), 0, state.num_frames - 1))
        state.impact_frame = idx
        state.manual_impact_frame = idx
        message = f"🚩 Impact frame manually set to {idx}"
    else:
        message = "Load a video first."

    # Button/status updates are computed after the state change so they
    # reflect the new impact frame.
    propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state)
    status_updates = _ui_status_updates(state)

    leading_updates = (
        update_frame_display(state, idx) if video_loaded else gr.update(),
        gr.update(value=message, visible=True),
        gr.update(value=idx) if video_loaded else gr.update(),
        _build_kick_plot(state),
    )
    return (
        *leading_updates,
        propagate_main_update,
        detect_btn_update,
        propagate_player_update,
        *status_updates,
    )
target_obj_id: int = BALL_OBJECT_ID) -> bool: if state is None: return False for masks in state.masks_by_frame.values(): if target_obj_id in masks: return True return False def _player_has_masks(state: AppState) -> bool: if state is None or state.player_obj_id is None: return False player_id = state.player_obj_id for masks in state.masks_by_frame.values(): if player_id in masks: return True return False def _button_updates(state: AppState) -> tuple[Any, Any, Any]: yolo_ready = isinstance(state, AppState) and state.yolo_kick_frame is not None propagate_main_enabled = _ball_has_masks(state) or yolo_ready detect_player_enabled = yolo_ready propagate_player_enabled = _player_has_masks(state) sam_tracked = isinstance(state, AppState) and getattr(state, "is_sam_tracked", False) player_detected = isinstance(state, AppState) and getattr(state, "is_player_detected", False) player_propagated = isinstance(state, AppState) and getattr(state, "is_player_propagated", False) sam_variant = "secondary" if sam_tracked else "stop" detect_variant = "secondary" if player_detected else "stop" propagate_variant = "secondary" if player_propagated else "stop" return ( gr.update(interactive=propagate_main_enabled, variant=sam_variant), gr.update(interactive=detect_player_enabled, variant=detect_variant), gr.update(interactive=propagate_player_enabled, variant=propagate_variant), ) def _ball_button_updates(state: AppState) -> tuple[Any, Any]: def variant(flag: bool) -> str: return "secondary" if flag else "stop" ball_detected = isinstance(state, AppState) and getattr(state, "is_ball_detected", False) yolo_tracked = isinstance(state, AppState) and getattr(state, "is_yolo_tracked", False) return ( gr.update(variant=variant(ball_detected)), gr.update(variant=variant(yolo_tracked)), ) def _ui_status_updates(state: AppState) -> tuple[Any, ...]: return _kick_button_updates(state) + _ball_button_updates(state) + _goal_button_updates(state) def _recompute_motion_metrics(state: AppState, 
target_obj_id: int = 1): centers = state.ball_centers.get(target_obj_id) if not centers or len(centers) < 3: state.smoothed_centers[target_obj_id] = {} state.ball_speeds[target_obj_id] = {} state.kick_frame = None state.kick_debug_frames = [] state.kick_debug_speeds = [] state.kick_debug_threshold = None state.kick_debug_baseline = None state.kick_debug_speed_std = None state.kick_debug_area = [] state.kick_debug_kick_frame = None state.kick_debug_distance = [] state.kalman_centers[target_obj_id] = {} state.kalman_speeds[target_obj_id] = {} state.kalman_residuals[target_obj_id] = {} state.kick_debug_kalman_speeds = [] state.distance_from_start[target_obj_id] = {} state.direction_change[target_obj_id] = {} state.impact_frame = None state.impact_debug_frames = [] state.impact_debug_innovation = [] state.impact_debug_innovation_threshold = None state.impact_debug_direction = [] state.impact_debug_direction_threshold = None state.impact_debug_speed_kmh = [] state.impact_debug_speed_threshold_px = None state.impact_meters_per_px = None return items = sorted(centers.items()) dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0 alpha = 0.35 smoothed: dict[int, tuple[float, float]] = {} speeds: dict[int, float] = {} prev_frame = None prev_smooth = None for frame_idx, (cx, cy) in items: if prev_smooth is None: smooth_x, smooth_y = float(cx), float(cy) else: smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0]) smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1]) smoothed[frame_idx] = (smooth_x, smooth_y) if prev_smooth is None or prev_frame is None: speeds[frame_idx] = 0.0 else: frame_delta = max(1, frame_idx - prev_frame) time_delta = frame_delta * dt dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1]) speed = dist / time_delta if time_delta > 0 else dist speeds[frame_idx] = speed prev_smooth = (smooth_x, smooth_y) prev_frame = frame_idx state.smoothed_centers[target_obj_id] = smoothed 
def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None:
    """Find the first frame where the tracked object is convincingly kicked.

    Reads the smoothed centers, per-frame speeds, mask areas and Kalman speeds
    stored on ``state`` for ``target_obj_id``. With fewer than five smoothed
    samples it returns ``None`` without touching ``state``. Otherwise it fills
    the ``state.kick_debug_*`` series and scans frames after the baseline
    window. A frame is accepted when all of the following hold:

    * its speed reaches the adaptive threshold;
    * the next (up to) three samples stay at or above 70% of that threshold;
    * the mask area dropped relative to the recent median, or the speed is at
      least 120% of the threshold;
    * within the next eight samples the object gets far enough away from its
      initial position.

    Returns:
        The accepted frame index (also stored in
        ``state.kick_debug_kick_frame``), or ``None`` if no frame qualifies.
    """
    centers = state.smoothed_centers.get(target_obj_id, {})
    speed_by_frame = state.ball_speeds.get(target_obj_id, {})
    if len(centers) < 5:
        return None

    frames = sorted(centers.keys())
    n_frames = len(frames)
    speed_series = [speed_by_frame.get(f, 0.0) for f in frames]

    # Baseline statistics come from the first third of the track (at most 10 samples).
    baseline_window = min(10, n_frames // 3 or 1)
    calm = speed_series[:baseline_window]
    baseline_speed = statistics.median(calm) if calm else 0.0
    speed_std = statistics.pstdev(calm) if len(calm) > 1 else 0.0
    speed_threshold = max(
        max(baseline_speed + 4.0 * speed_std, baseline_speed * 3.0),
        15.0,
    )

    sustain_frames = 3
    holdout_frames = 8
    area_window = 4
    area_drop_ratio = 0.75

    areas = state.mask_areas.get(target_obj_id, {})
    origin_x, origin_y = centers[frames[0]]
    initial_area = areas.get(frames[0], 1.0) or 1.0
    # Treat the first mask as a disc to derive a radius-scaled travel distance.
    radius_estimate = math.sqrt(initial_area / math.pi)
    adaptive_return_distance = max(8.0, min(radius_estimate * 1.5, 40.0))

    kalman_speeds = state.kalman_speeds.get(target_obj_id, {})
    state.kick_debug_frames = frames
    state.kick_debug_speeds = speed_series
    state.kick_debug_threshold = speed_threshold
    state.kick_debug_baseline = baseline_speed
    state.kick_debug_speed_std = speed_std
    state.kick_debug_area = [areas.get(f, 0.0) for f in frames]
    state.kick_debug_distance = [
        math.hypot(centers[f][0] - origin_x, centers[f][1] - origin_y)
        for f in frames
    ]
    state.kick_debug_kalman_speeds = [kalman_speeds.get(f, 0.0) for f in frames]
    state.kick_debug_kick_frame = None

    sustain_floor = speed_threshold * 0.7
    override_speed = speed_threshold * 1.2

    for idx, frame in enumerate(frames[baseline_window:], start=baseline_window):
        speed = speed_series[idx]
        if speed < speed_threshold:
            continue

        # The following samples (as many as exist) must not fall below the floor.
        upcoming = speed_series[idx + 1:idx + 1 + sustain_frames]
        if any(later < sustain_floor for later in upcoming):
            continue

        # Area veto: the mask did not shrink relative to the recent median.
        area_now = areas.get(frame)
        area_veto = False
        if area_now:
            lookback = frames[max(0, idx - area_window):idx]
            recent = [areas.get(f) for f in lookback if areas.get(f) is not None]
            if recent:
                typical = statistics.median(recent)
                area_veto = typical > 0 and area_now / typical > area_drop_ratio
        if area_veto and speed < override_speed:
            continue

        # The object must get away from its starting point soon after this frame.
        horizon = frames[idx:min(n_frames, idx + holdout_frames)]
        reach = max(
            [0.0]
            + [
                math.hypot(centers[f][0] - origin_x, centers[f][1] - origin_y)
                for f in horizon
            ]
        )
        if reach < adaptive_return_distance:
            continue

        state.kick_debug_kick_frame = frame
        return frame

    state.kick_debug_kick_frame = None
    return None
state.impact_debug_speed_kmh = [] state.impact_debug_speed_threshold_px = None state.impact_meters_per_px = None if not frames or state.kick_frame is None: state.impact_frame = None return None kalman_positions = state.kalman_centers.get(target_obj_id, {}) direction_dict: dict[int, float] = {} prev_pos: tuple[float, float] | None = None prev_vec: tuple[float, float] | None = None for frame in frames: pos = kalman_positions.get(frame) if pos is None: direction_dict[frame] = 0.0 continue if prev_pos is None: direction_dict[frame] = 0.0 prev_vec = (0.0, 0.0) else: vec = (pos[0] - prev_pos[0], pos[1] - prev_pos[1]) if prev_vec is None: direction_dict[frame] = 0.0 else: direction_dict[frame] = _angle_between(prev_vec, vec) prev_vec = vec prev_pos = pos state.direction_change[target_obj_id] = direction_dict state.impact_debug_direction = [direction_dict.get(f, 0.0) for f in frames] distance_dict = state.distance_from_start.get(target_obj_id, {}) max_distance_px = max(distance_dict.values()) if distance_dict else 0.0 goal_distance_m = max(state.goal_distance_m, 0.0) meters_per_px = goal_distance_m / max_distance_px if goal_distance_m > 0 and max_distance_px > 1e-6 else None state.impact_meters_per_px = meters_per_px kalman_speed_dict = state.kalman_speeds.get(target_obj_id, {}) if meters_per_px: state.impact_debug_speed_kmh = [ kalman_speed_dict.get(f, 0.0) * meters_per_px * 3.6 for f in frames ] if state.min_impact_speed_kmh > 0: state.impact_debug_speed_threshold_px = (state.min_impact_speed_kmh / 3.6) / meters_per_px else: state.impact_debug_speed_kmh = [0.0 for _ in frames] state.impact_debug_speed_threshold_px = None baseline_frames = [f for f in frames if f <= state.kick_frame] if not baseline_frames: baseline_frames = frames[: max(1, min(len(frames), 10))] baseline_vals = [residuals.get(f, 0.0) for f in baseline_frames] baseline_median = statistics.median(baseline_vals) if baseline_vals else 0.0 baseline_std = statistics.pstdev(baseline_vals) if len(baseline_vals) 
def on_image_click(
    img: Image.Image | np.ndarray,
    state: AppState,
    frame_idx: int,
    obj_id: int,
    label: str,
    clear_old: bool,
    evt: gr.SelectData,
) -> Image.Image:
    """Turn a click on the preview into a SAM2 prompt and return the new preview.

    In "Boxes" prompt mode two consecutive clicks define a box: the first click
    only stores the corner and redraws the frame, the second submits the box.
    In any other mode each click is submitted as a single point whose label is
    positive when ``label`` starts with "pos" (case-insensitive), else negative.
    After a prompt is submitted the model is run on that frame, the resulting
    masks are stored in ``state.masks_by_frame`` and centroids are refreshed.

    Args:
        img: Current preview; returned unchanged when no session is ready.
        state: Application state holding processor, model and inference session.
        frame_idx: Frame the click refers to (coerced to ``int``).
        obj_id: Object id the prompt belongs to (coerced to ``int``).
        label: Point label text; only consulted in points mode.
        clear_old: In points mode, whether earlier prompts for this object on
            this frame are discarded first. Boxes always discard them.
        evt: Click event; coordinates are read from ``evt.index`` (2-sequence)
            or from ``evt.value["x"]`` / ``evt.value["y"]``.

    Returns:
        The refreshed preview from ``update_frame_display``, or ``img`` when
        ``state`` or its inference session is missing.

    Raises:
        gr.Error: If no click coordinates can be read from ``evt``.
    """
    if state is None or state.inference_session is None:
        return img  # no-op preview when not ready

    # Normalise once: the index is used as a list index and dict key below, and
    # a float coming from a UI component would break ``state.video_frames[...]``.
    frame_idx = int(frame_idx)

    if state.is_switching_model:
        # Gracefully ignore input during model switch; return current preview unchanged
        return update_frame_display(state, frame_idx)

    # Parse click coordinates from the event, accepting either payload shape.
    x = y = None
    if evt is not None:
        try:
            if hasattr(evt, "index") and isinstance(evt.index, (list, tuple)) and len(evt.index) == 2:
                x, y = int(evt.index[0]), int(evt.index[1])
            elif hasattr(evt, "value") and isinstance(evt.value, dict) and "x" in evt.value and "y" in evt.value:
                x, y = int(evt.value["x"]), int(evt.value["y"])
        except (TypeError, ValueError, OverflowError):
            # Non-numeric coordinates are reported through the gr.Error below.
            x = y = None
    if x is None or y is None:
        raise gr.Error("Could not read click coordinates.")

    obj_id = int(obj_id)
    _ensure_color_for_obj(state, obj_id)

    processor = state.processor
    model = state.model
    inference_session = state.inference_session

    # Only preprocess the frame when the session has not seen it yet; otherwise
    # the model call below receives ``frame=None`` together with the frame index.
    original_size = None
    pixel_values = None
    if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
        inputs = processor(images=state.video_frames[frame_idx], device=state.device, return_tensors="pt")
        original_size = inputs.original_sizes[0]
        pixel_values = inputs.pixel_values[0]

    if state.current_prompt_type == "Boxes":
        # Two-click box input
        if state.pending_box_start is None:
            # For boxes, always clear old inputs (points) for this object on this frame
            frame_clicks = state.clicks_by_frame_obj.setdefault(frame_idx, {})
            frame_clicks[obj_id] = []
            state.pending_box_start = (x, y)
            state.pending_box_start_frame_idx = frame_idx
            state.pending_box_start_obj_id = obj_id
            # Invalidate cache so temporary cross is drawn
            state.composited_frames.pop(frame_idx, None)
            return update_frame_display(state, frame_idx)

        x1, y1 = state.pending_box_start
        x2, y2 = x, y
        # Clear temporary state and invalidate cache
        state.pending_box_start = None
        state.pending_box_start_frame_idx = None
        state.pending_box_start_obj_id = None
        state.composited_frames.pop(frame_idx, None)

        x_min, y_min = min(x1, x2), min(y1, y2)
        x_max, y_max = max(x1, x2), max(y1, y2)
        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=frame_idx,
            obj_ids=obj_id,
            input_boxes=[[[x_min, y_min, x_max, y_max]]],
            clear_old_inputs=True,  # For boxes, always clear old inputs
            original_size=original_size,
        )
        frame_boxes = state.boxes_by_frame_obj.setdefault(frame_idx, {})
        obj_boxes = frame_boxes.setdefault(obj_id, [])
        # For boxes, always clear old inputs
        obj_boxes.clear()
        obj_boxes.append((x_min, y_min, x_max, y_max))
        state.composited_frames.pop(frame_idx, None)
    else:
        # Points mode
        label_int = 1 if str(label).lower().startswith("pos") else 0
        clear_old = bool(clear_old)
        # If clear_old is enabled, clear prior boxes for this object on this frame
        if clear_old:
            frame_boxes = state.boxes_by_frame_obj.setdefault(frame_idx, {})
            frame_boxes[obj_id] = []
            state.composited_frames.pop(frame_idx, None)
        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=frame_idx,
            obj_ids=obj_id,
            input_points=[[[[x, y]]]],
            input_labels=[[[label_int]]],
            original_size=original_size,
            clear_old_inputs=clear_old,
        )
        frame_clicks = state.clicks_by_frame_obj.setdefault(frame_idx, {})
        obj_clicks = frame_clicks.setdefault(obj_id, [])
        if clear_old:
            obj_clicks.clear()
        obj_clicks.append((x, y, label_int))
        state.composited_frames.pop(frame_idx, None)

    # Forward on that frame
    with torch.inference_mode():
        outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)

    H = inference_session.video_height
    W = inference_session.video_width
    # Detach and move off GPU as early as possible to reduce GPU memory pressure
    pred_masks = outputs.pred_masks.detach().cpu()
    video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]

    # Map returned masks to object ids by iterating the session's obj_ids order.
    masks_for_frame: dict[int, np.ndarray] = {}
    for i, oid in enumerate(list(inference_session.obj_ids)):
        # Each mask may carry singleton leading dims; squeeze to 2D.
        masks_for_frame[int(oid)] = video_res_masks[i].cpu().numpy().squeeze()
    state.masks_by_frame[frame_idx] = masks_for_frame
    _update_centroids_for_frame(state, frame_idx)

    # Invalidate cache for this frame to force recomposition
    state.composited_frames.pop(frame_idx, None)

    # Return updated preview
    return update_frame_display(state, frame_idx)
_ensure_ball_prompt_from_yolo(GLOBAL_STATE) processor = deepcopy(GLOBAL_STATE.processor) model = deepcopy(GLOBAL_STATE.model) inference_session = deepcopy(GLOBAL_STATE.inference_session) # set inference device to cuda to use zero gpu inference_session.inference_device = "cuda" inference_session.cache.inference_device = "cuda" model.to("cuda") if not GLOBAL_STATE.sam_window: _compute_sam_window_from_kick( GLOBAL_STATE, _get_prioritized_kick_frame(GLOBAL_STATE), ) start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames) start_idx = max(0, int(start_idx)) end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx))) total = max(1, end_idx - start_idx) processed = 0 _ensure_ball_prompt_from_yolo(GLOBAL_STATE) # Initial status; no slider change yet GLOBAL_STATE.is_sam_tracked = False propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) status_updates = _ui_status_updates(GLOBAL_STATE) yield ( GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(), _build_kick_plot(GLOBAL_STATE), _build_yolo_plot(GLOBAL_STATE), _impact_status_update(GLOBAL_STATE), gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), propagate_main_update, detect_btn_update, propagate_player_update, *status_updates, ) last_frame_idx = start_idx with torch.inference_mode(): for frame_idx in range(start_idx, end_idx): frame = GLOBAL_STATE.video_frames[frame_idx] pixel_values = None if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames: pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0] sam2_video_output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx) H = inference_session.video_height W = inference_session.video_width pred_masks = sam2_video_output.pred_masks.detach().cpu() video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0] last_frame_idx = 
frame_idx masks_for_frame: dict[int, np.ndarray] = {} obj_ids_order = list(inference_session.obj_ids) for i, oid in enumerate(obj_ids_order): mask_2d = video_res_masks[i].cpu().numpy().squeeze() masks_for_frame[int(oid)] = mask_2d GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame _update_centroids_for_frame(GLOBAL_STATE, frame_idx) # Invalidate cache for that frame to force recomposition GLOBAL_STATE.composited_frames.pop(frame_idx, None) processed += 1 # Every 15th frame (or last), move slider to current frame to update preview via slider binding if processed % 30 == 0 or processed == total: propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) status_updates = _ui_status_updates(GLOBAL_STATE) yield ( GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx), _build_kick_plot(GLOBAL_STATE), _build_yolo_plot(GLOBAL_STATE), _impact_status_update(GLOBAL_STATE), gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), propagate_main_update, detect_btn_update, propagate_player_update, *status_updates, ) text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects." 
def reset_session(GLOBAL_STATE: gr.State) -> tuple[Any, ...]:
    """Reset per-session tracking state while keeping the loaded video and model.

    Clears prompts, masks, cached composites, motion metrics and the kick/impact
    debug series, marks every pipeline stage flag as not done, resets and
    discards the inference session and asks
    ``ensure_session_for_current_model`` to rebuild it.

    Both branches return values in the same order so they fit one outputs list:
    state, preview image, slider range, slider value, status text, a hidden
    empty update, kick plot, YOLO plot, impact status, the three button updates
    from ``_button_updates`` and then everything from ``_ui_status_updates``.

    NOTE(review): ``yolo_kick_frame``, ``sam_window`` and the manual kick/impact
    frames are not cleared here; confirm whether a caller handles them.
    """
    # Every path ends with all pipeline stages marked "not done". Do it first so
    # the button/status updates computed below reflect the reset state.
    for flag in (
        "is_ball_detected",
        "is_yolo_tracked",
        "is_sam_tracked",
        "is_player_detected",
        "is_player_propagated",
    ):
        setattr(GLOBAL_STATE, flag, False)

    if not GLOBAL_STATE.video_frames:
        # Nothing loaded: report the reset without touching any session data.
        propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
        status_updates = _ui_status_updates(GLOBAL_STATE)
        return (
            GLOBAL_STATE,
            None,
            0,
            0,
            "Session reset. Load a new video.",
            gr.update(visible=False, value=""),
            _build_kick_plot(GLOBAL_STATE),
            _build_yolo_plot(GLOBAL_STATE),
            _impact_status_update(GLOBAL_STATE),
            propagate_main_update,
            detect_btn_update,
            propagate_player_update,
            *status_updates,
        )

    # Clear prompts, caches and per-object metric stores.
    for store in (
        GLOBAL_STATE.masks_by_frame,
        GLOBAL_STATE.clicks_by_frame_obj,
        GLOBAL_STATE.boxes_by_frame_obj,
        GLOBAL_STATE.composited_frames,
        GLOBAL_STATE.ball_centers,
        GLOBAL_STATE.mask_areas,
        GLOBAL_STATE.smoothed_centers,
        GLOBAL_STATE.ball_speeds,
        GLOBAL_STATE.distance_from_start,
        GLOBAL_STATE.direction_change,
        GLOBAL_STATE.kalman_centers,
        GLOBAL_STATE.kalman_speeds,
        GLOBAL_STATE.kalman_residuals,
    ):
        store.clear()

    GLOBAL_STATE.pending_box_start = None
    GLOBAL_STATE.pending_box_start_frame_idx = None
    GLOBAL_STATE.pending_box_start_obj_id = None

    # Kick detection results and their debug series.
    GLOBAL_STATE.kick_frame = None
    GLOBAL_STATE.kick_debug_frames = []
    GLOBAL_STATE.kick_debug_speeds = []
    GLOBAL_STATE.kick_debug_threshold = None
    GLOBAL_STATE.kick_debug_baseline = None
    GLOBAL_STATE.kick_debug_speed_std = None
    GLOBAL_STATE.kick_debug_area = []
    GLOBAL_STATE.kick_debug_kick_frame = None
    GLOBAL_STATE.kick_debug_distance = []
    GLOBAL_STATE.kick_debug_kalman_speeds = []

    # Impact detection results and their debug series.
    GLOBAL_STATE.impact_frame = None
    GLOBAL_STATE.impact_debug_frames = []
    GLOBAL_STATE.impact_debug_innovation = []
    GLOBAL_STATE.impact_debug_innovation_threshold = None
    GLOBAL_STATE.impact_debug_direction = []
    GLOBAL_STATE.impact_debug_direction_threshold = None
    GLOBAL_STATE.impact_debug_speed_kmh = []
    GLOBAL_STATE.impact_debug_speed_threshold_px = None
    GLOBAL_STATE.impact_meters_per_px = None

    # Dispose and re-init inference session for current model with existing frames
    try:
        if GLOBAL_STATE.inference_session is not None:
            GLOBAL_STATE.inference_session.reset_inference_session()
    except Exception:
        # Deliberately best-effort: the session object is dropped right below,
        # so a failure while resetting it must not block the UI reset.
        pass
    GLOBAL_STATE.inference_session = None
    gc.collect()
    ensure_session_for_current_model(GLOBAL_STATE)

    # Keep current slider index if possible
    current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
    current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
    preview_img = update_frame_display(GLOBAL_STATE, current_idx)
    slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
    slider_value = gr.update(value=current_idx)
    status = "Session reset. Prompts cleared; video preserved."
    propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
    status_updates = _ui_status_updates(GLOBAL_STATE)
    return (
        GLOBAL_STATE,
        preview_img,
        slider_minmax,
        slider_value,
        status,
        gr.update(visible=False, value=""),
        _build_kick_plot(GLOBAL_STATE),
        _build_yolo_plot(GLOBAL_STATE),
        _impact_status_update(GLOBAL_STATE),
        propagate_main_update,
        detect_btn_update,
        propagate_player_update,
        *status_updates,
    )
@spaces.GPU(duration=120)  # Allocate GPU for up to 2 minutes
def process_video_api(
    video_file,
    annotations_json_str: str,
    checkpoint: str = "base_plus",
    remove_background: bool = True,
):
    """
    Single-endpoint API for programmatic video processing.

    Pipeline: parse the optional annotations, load the video into a fresh
    ``AppState``, replay each annotation as a point click, detect and track the
    ball with YOLO, propagate SAM2 ball masks (re-running over the whole clip
    when the YOLO and SAM2 kick frames disagree), optionally detect and
    propagate the player, then render the result video.

    Args:
        video_file: Uploaded video file
        annotations_json_str: Optional JSON string containing helper annotations.
            When non-empty it must encode a JSON object; its ``"annotations"``
            list and ``"fps"`` value are read.
        checkpoint: SAM2 model checkpoint (tiny, small, base_plus, large)
        remove_background: Whether to remove the background in the render

    Returns:
        Tuple of (preview_image, processed_video_path, progress_log)

    Raises:
        gr.Error: Any failure is logged and re-raised as
            ``gr.Error("Processing failed: ...")`` chained to the original error.
    """
    try:
        log_entries: list[str] = []

        def log_msg(message: str):
            text = f"[API] {message}"
            print(text)
            log_entries.append(text)

        # Parse annotations (optional)
        annotations_payload = annotations_json_str or ""
        annotations_data = json.loads(annotations_payload) if annotations_payload.strip() else {}
        if not isinstance(annotations_data, dict):
            # A bare list/number would otherwise fail later with an opaque AttributeError.
            raise ValueError("annotations_json_str must encode a JSON object.")
        # ``or []`` also covers an explicit ``"annotations": null``.
        annotations = annotations_data.get("annotations") or []
        client_fps = annotations_data.get("fps", None)

        log_msg(f"Received {len(annotations)} annotations")
        log_msg(f"Checkpoint: {checkpoint} | Remove background: {remove_background}")

        preview_img = create_annotation_preview(video_file, annotations) if annotations else None

        # Create a temporary state for this API call
        api_state = AppState()
        api_state.model_repo_key = checkpoint

        # Step 1: Initialize session with video
        log_msg("Loading video...")
        api_state, min_idx, max_idx, first_frame, status = init_video_session(api_state, video_file)
        space_fps = api_state.video_fps
        log_msg(status)
        log_msg(f"Client FPS={client_fps} | Space FPS={space_fps}")

        # If FPS mismatch, warn about potential frame offset
        if client_fps and space_fps and abs(client_fps - space_fps) > 0.5:
            offset_estimate = abs(int((client_fps - space_fps) * (api_state.num_frames / client_fps)))
            log_msg(f"⚠️ FPS mismatch detected. Frame indices may be off by ~{offset_estimate} frames.")
            log_msg("ℹ️ Recommendation: Use timestamps instead of frame indices for accuracy.")

        # Step 2: Apply each annotation
        last_valid_frame = max(0, api_state.num_frames - 1)
        for i, ann in enumerate(annotations):
            object_id = ann.get("object_id", 1)
            timestamp_ms = ann.get("timestamp_ms", None)
            frame_idx = ann.get("frame", None)
            x = ann.get("x", 0)
            y = ann.get("y", 0)
            label = ann.get("label", "positive")

            # Calculate frame from timestamp using Space's FPS (more accurate)
            if timestamp_ms is not None and space_fps and space_fps > 0:
                calculated_frame = int((timestamp_ms / 1000.0) * space_fps)
                if frame_idx is not None and calculated_frame != frame_idx:
                    log_msg(f"Annotation {i+1}: using timestamp {timestamp_ms}ms → Frame {calculated_frame} (client sent {frame_idx})")
                else:
                    log_msg(f"Annotation {i+1}: timestamp {timestamp_ms}ms → Frame {calculated_frame}")
                frame_idx = calculated_frame
            elif frame_idx is None:
                log_msg(f"Annotation {i+1}: ⚠️ No timestamp/frame provided, defaulting to frame 0")
                frame_idx = 0

            # A timestamp at/after the clip end (or a bad client index) would
            # index past the loaded frames, so clamp into the valid range.
            requested_frame = int(frame_idx)
            frame_idx = min(max(requested_frame, 0), last_valid_frame)
            if frame_idx != requested_frame:
                log_msg(f"Annotation {i+1}: frame {requested_frame} is outside the clip; clamped to {frame_idx}")

            log_msg(f"Adding annotation {i+1}/{len(annotations)} | Obj {object_id} | Frame {frame_idx}")

            # Sync state
            api_state.current_frame_idx = frame_idx
            api_state.current_obj_id = int(object_id)
            api_state.current_label = str(label)

            # Minimal stand-in for the click event: only ``index`` is read.
            mock_evt = SimpleNamespace(index=(x, y))

            # Add the point annotation
            preview_img = on_image_click(
                first_frame,
                api_state,
                frame_idx,
                object_id,
                label,
                clear_old=False,
                evt=mock_evt,
            )

        if preview_img is None:
            preview_img = first_frame

        # Helper to consume generator-based steps and capture log messages
        def _run_generator(gen, label: str):
            final_state = None
            for outputs in gen:
                if not outputs:
                    continue
                final_state = outputs[0]
                status_msg = outputs[1] if len(outputs) > 1 else ""
                if status_msg:
                    log_msg(f"{label}: {status_msg}")
            if final_state is not None:
                return final_state
            raise gr.Error(f"{label} did not produce any output.")

        # Step 3: YOLO13 detect ball
        api_state.current_obj_id = BALL_OBJECT_ID
        api_state.current_label = "positive"
        log_msg("YOLO13 · Detect ball (single-frame search)")
        _auto_detect_ball(api_state, BALL_OBJECT_ID, "positive", False)
        if not api_state.is_ball_detected:
            raise gr.Error("YOLO13 could not detect the ball automatically.")

        # Step 4: YOLO13 track ball
        log_msg("YOLO13 · Track ball across clip")
        _track_ball_yolo(api_state)
        if not api_state.is_yolo_tracked:
            raise gr.Error("YOLO13 tracking failed.")

        # Step 5: SAM2 track ball around kick window
        log_msg("SAM2 · Track ball around kick window")
        api_state = _run_generator(propagate_masks(api_state), "SAM2 · Ball")
        sam_kick = _get_prioritized_kick_frame(api_state)
        yolo_kick = api_state.yolo_kick_frame
        if sam_kick is not None:
            log_msg(f"SAM2 kick frame ≈ {sam_kick}")
        if yolo_kick is not None:
            log_msg(f"YOLO kick frame ≈ {yolo_kick}")

        # Fallback: re-run SAM2 on entire video if kicks disagree
        if (
            yolo_kick is not None
            and sam_kick is not None
            and int(yolo_kick) != int(sam_kick)
        ):
            log_msg("Kick disagreement detected → re-running SAM2 across entire video.")
            api_state.sam_window = (0, api_state.num_frames)
            api_state = _run_generator(propagate_masks(api_state), "SAM2 · Full sweep")
            sam_kick = _get_prioritized_kick_frame(api_state)
            log_msg(f"SAM2 full sweep kick frame ≈ {sam_kick}")
        else:
            log_msg("Kick frames aligned. No full sweep required.")

        # Step 6: YOLO detect player on SAM2 kick frame
        log_msg("YOLO13 · Detect player on SAM2 kick frame")
        _auto_detect_player(api_state)
        if api_state.is_player_detected:
            log_msg("YOLO13 · Player detected successfully.")
        else:
            log_msg("YOLO13 · Player detection failed; continuing without player propagation.")

        # Step 7: SAM2 track player if detection succeeded
        if api_state.is_player_detected:
            log_msg("SAM2 · Track player around kick window")
            try:
                api_state = _run_generator(propagate_player_masks(api_state), "SAM2 · Player")
            except gr.Error as player_error:
                # Player propagation is optional; keep going with the ball only.
                log_msg(f"SAM2 player propagation warning: {player_error}")

        # Step 8: Render the final video
        log_msg(f"Rendering video (remove_background={remove_background})")
        result_video_path = _render_video(api_state, remove_background, log_fn=log_msg)
        log_msg("Processing complete 🎉")
        return preview_img, result_video_path, "\n".join(log_entries)
    except Exception as e:
        # Top-level API boundary: log the full traceback, then surface one
        # user-facing error that stays chained to the original cause.
        print(f"[API] ❌ Error: {str(e)}")
        import traceback

        traceback.print_exc()
        raise gr.Error(f"Processing failed: {str(e)}") from e
0.7rem; } .model-status { flex: 0 0 auto; display: flex; gap: 0.25rem; margin-left: auto; } .model-status .gr-button { min-width: unset; width: fit-content; padding: 0.25rem 0.55rem; } """ BUTTON_TOOLTIPS = { "btn-reset-session": ( "Clears the entire workspace: YOLO detections, SAM2 masks, manual kick/impact overrides, and FX settings all " "return to defaults so you can load a new clip without leftover state." ), "btn-mark-kick": ( "Stores the current frame as the definitive kick moment. We override YOLO or SAM guesses immediately so SAM2 " "propagation, ring rendering, impact-speed math, and player workflows all pivot around this human-confirmed " "timestamp until it is cleared." ), "btn-mark-impact": ( "Declares the current frame as the impact (goal crossing or contact). Automatic impact detection is still in " "progress, so this manual anchor feeds the diagnostics plot and tells the renderer when to fade rings or ghost trails." ), "btn-detect-ball": ( "Runs YOLO13 over the entire video to find the stationary ball before it moves. We keep only the single best " "candidate per frame so the rest of the pipeline has one anchor. This yields the first kick guess, an initial radius, " "and enables the tracking and player steps." ), "btn-track-ball-yolo": ( "Sweeps YOLO13 tracking across every frame while forcing exactly one plausible ball trajectory. We smooth detections, " "locate the velocity spike that marks the kick, and cache the radius there. This fast scout tells SAM2 where to focus " "later and populates the future ring / ghost trail data." ), "btn-detect-player": ( "Samples the prioritized kick frame (manual > SAM > YOLO) and runs player detection there. Aligning the player mask " "with the kick ensures SAM2 can later propagate the athlete through the same window, unlocking the Track Player step." ), "btn-track-ball-sam": ( "Runs the SAM2 transformer on a tight window centered on the prioritized kick frame. 
We seed it with YOLO’s latest " "ball mask so SAM2 delivers high-fidelity segmentation only where it matters, refreshing ring radii without scanning " "the full clip." ), "btn-track-player-sam": ( "After a player mask exists, SAM2 propagates it within the same kick-centric window. This keeps athlete and ball masks " "time-synced, enabling combined overlays, exports, and analytics comparing foot position to ball contact." ), "btn-goal-start": ( "Enters goal mapping mode so you can click the two crossbar corners. After the first click a handle appears; the second " "click closes the bar and exposes draggable anchors before you confirm." ), "btn-goal-confirm": ( "Locks the currently placed crossbar across the entire video. The line and handles stay visible on every frame and can " "be re-edited later by tapping Map Goal again." ), "btn-goal-clear": ( "Removes the current crossbar (and any in-progress points) so you can restart the goal alignment workflow from scratch." ), "btn-goal-back": ( "Restores the previously confirmed crossbar if the latest edits missed the mark. Useful when you want to compare two " "placements without re-clicking both corners." ), } def _build_tooltip_script() -> str: data = json.dumps(BUTTON_TOOLTIPS) return f""" """ TOOLTIP_SCRIPT = _build_tooltip_script() with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme, css=CUSTOM_CSS) as demo: GLOBAL_STATE = gr.State(AppState()) gr.Markdown( """ ### KickTrimmer Lab · Ball-Speed Video Twin This Space acts as a desktop twin of the KickTrimmer mobile app: load a football clip, detect the kick, and estimate the ball speed frame-by-frame. It previews future ball rings color-coded by hypothetical impact velocity as the ball travels toward the goal, letting you experiment with FX settings before shipping them to the phone build. 
⚠️ **Work in progress:** we are still closing the gap with the mobile feature set (automatic horizon & goal finding, diagonal speed correction, etc.), so the numbers you see here are prototypes—not final certified speeds. """ ) with gr.Row(): with gr.Column(): gr.Markdown( """ **Quick start** - **Load a video**: Upload your own or pick an example below. - **Checkpoint**: Tiny / Small / Base+ / Large (trade speed vs. accuracy). - **Points mode**: Select an Object ID and point label (positive/negative), then click the frame to add guidance. You can add **multiple points per object** and define **multiple objects** across frames. - **Boxes mode**: Click two opposite corners to draw a box. Old inputs for that object are cleared automatically. """ ) with gr.Column(): gr.Markdown( """ **Working with results** - **Preview**: Use the slider to navigate frames and see the current masks. - **Track**: Click “Track ball (SAM2)” to track all defined objects across the selected window. The preview follows progress periodically to keep things responsive. - **Export**: Render an MP4 for smooth playback using the original video FPS. - **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video). 
""" ) with gr.Row(equal_height=True): with gr.Column(scale=1): video_in = gr.Video( label="Upload video", sources=["upload", "webcam"], interactive=True, elem_id="video-pane", ) ckpt_radio = gr.Radio( choices=["tiny", "small", "base_plus", "large"], value="tiny", label="SAM2.1 checkpoint", ) ckpt_progress = gr.Markdown(visible=False) load_status = gr.Markdown(visible=True) reset_btn = gr.Button("Reset Session", variant="secondary", elem_id="btn-reset-session") with gr.Column(scale=1): gr.Markdown("**Preview**") preview = gr.Image( interactive=True, elem_id="preview-pane", container=False, show_label=False, ) frame_slider = gr.Slider( label="Frame", minimum=0, maximum=0, step=1, value=0, interactive=True, elem_id="frame-slider", ) with gr.Column(): with gr.Column(elem_classes=["model-section"]): with gr.Row(elem_classes=["model-row"]): gr.Markdown("Manual", elem_classes=["model-label"]) with gr.Row(elem_classes=["model-actions"]): mark_kick_btn = gr.Button("⚽ Mark Kick", variant="primary", elem_id="btn-mark-kick") mark_impact_btn = gr.Button("🚩 Mark Impact", variant="primary", elem_id="btn-mark-impact") with gr.Row(elem_classes=["model-status"]): manual_kick_btn = gr.Button("⚽: N/A", interactive=False) manual_impact_btn = gr.Button("🚩: N/A", interactive=False) with gr.Row(elem_classes=["model-actions"]): goal_start_btn = gr.Button( "Map Goal", variant="secondary", elem_id="btn-goal-start", ) goal_confirm_btn = gr.Button( "Confirm", variant="primary", interactive=False, elem_id="btn-goal-confirm", ) goal_clear_btn = gr.Button( "Clear", variant="secondary", interactive=False, elem_id="btn-goal-clear", ) goal_back_btn = gr.Button( "Back", variant="secondary", interactive=False, elem_id="btn-goal-back", ) goal_status = gr.Markdown("Goal crossbar inactive.", elem_id="goal-status-text") with gr.Column(elem_classes=["model-section"]): with gr.Row(elem_classes=["model-row"]): gr.Markdown("YOLO13", elem_classes=["model-label"]) with gr.Row(elem_classes=["model-actions"]): 
detect_ball_btn = gr.Button("Detect Ball", variant="stop", elem_id="btn-detect-ball") track_ball_yolo_btn = gr.Button("Track Ball", variant="stop", elem_id="btn-track-ball-yolo") detect_player_btn = gr.Button( "Detect Player", variant="stop", interactive=False, elem_id="btn-detect-player", ) with gr.Row(elem_classes=["model-status"]): yolo_kick_btn = gr.Button("⚽: N/A", interactive=False) yolo_impact_btn = gr.Button("🚩: N/A", interactive=False) # Multi-ball candidate selection UI with gr.Column(visible=False) as multi_ball_selection_col: multi_ball_status_md = gr.Markdown("", visible=True) ball_candidate_radio = gr.Radio( choices=[], value=None, label="Select Ball Candidate", interactive=True, ) with gr.Row(): confirm_ball_btn = gr.Button("Confirm Selection", variant="primary") multi_ball_chart = gr.Plot(label="Ball Candidates Speed Comparison", show_label=True) yolo_plot = gr.Plot(label="YOLO kick diagnostics", show_label=True) with gr.Column(elem_classes=["model-section"]): with gr.Row(elem_classes=["model-row"]): gr.Markdown("SAM2", elem_classes=["model-label"]) with gr.Row(elem_classes=["model-actions"]): propagate_btn = gr.Button( "Track Ball", variant="stop", interactive=False, elem_id="btn-track-ball-sam" ) propagate_player_btn = gr.Button( "Track Player", variant="stop", interactive=False, elem_id="btn-track-player-sam", ) with gr.Row(elem_classes=["model-status"]): sam_kick_btn = gr.Button("⚽: N/A", interactive=False) sam_impact_btn = gr.Button("🚩: N/A", interactive=False) kick_plot = gr.Plot(label="Kick & impact diagnostics", show_label=True) gr.HTML(value=TOOLTIP_SCRIPT, visible=False) with gr.Row(): min_impact_speed_slider = gr.Slider( label="Min impact speed (km/h)", minimum=0, maximum=120, step=1, value=20, interactive=True, ) goal_distance_slider = gr.Slider( label="Distance to goal (m)", minimum=1, maximum=60, step=0.5, value=18, interactive=True, ) ball_status = gr.Markdown(visible=False) propagate_status = gr.Markdown(visible=True) impact_status = 
def _on_video_change(GLOBAL_STATE: gr.State, video):
    """Start a fresh session for a newly selected video and refresh the UI.

    Delegates to ``init_video_session`` and then assembles one value per
    component wired to ``video_in.change``: the state, the frame slider
    range, the first frame, the load status, a hidden/cleared ball status,
    both diagnostic plots, the status indicator updates, and finally the
    three propagate/detect button updates.
    """
    session = init_video_session(GLOBAL_STATE, video)
    GLOBAL_STATE, first_idx, last_idx, first_frame, load_message = session

    track_ball_upd, detect_player_upd, track_player_upd = _button_updates(GLOBAL_STATE)
    indicator_updates = _ui_status_updates(GLOBAL_STATE)

    # Slider spans the whole clip and is parked on its first frame.
    slider_update = gr.update(
        minimum=first_idx,
        maximum=last_idx,
        value=first_idx,
        interactive=True,
    )
    # Any ball status left over from a previous clip is hidden and emptied.
    cleared_ball_status = gr.update(visible=False, value="")
    kick_figure = _build_kick_plot(GLOBAL_STATE)
    yolo_figure = _build_yolo_plot(GLOBAL_STATE)

    leading = (
        GLOBAL_STATE,
        slider_update,
        first_frame,
        load_message,
        cleared_ball_status,
        kick_figure,
        yolo_figure,
    )
    trailing = (track_ball_upd, detect_player_upd, track_player_upd)
    return leading + tuple(indicator_updates) + trailing
propagate_btn, detect_player_btn, propagate_player_btn, ], label="Examples", cache_examples=False, examples_per_page=5, ) with gr.Row(): with gr.Column(scale=1): remove_bg_checkbox = gr.Checkbox( label="Remove Background", value=True, info="If checked, shows only tracked objects on black background. If unchecked, overlays colored masks on original video.", ) with gr.Column(scale=1): ghost_trail_chk = gr.Checkbox( label="Ghost trail (ball)", value=False, info="Overlay post-kick SAM2 ball masks in magenta to visualize trajectory.", ) with gr.Column(scale=1): ball_ring_chk = gr.Checkbox( label="Ball rings (future)", value=True, info="Replace the ghost trail fill with magenta rings at future ball positions.", ) with gr.Column(scale=1): click_marks_chk = gr.Checkbox( label="Show annotation '+'", value=False, info="If unchecked, hides the '+' markers from clicks in preview and renders.", ) with gr.Accordion("Cutout FX", open=False): gr.Markdown("These options apply when rendering with background removal.") with gr.Row(): with gr.Column(scale=1): soft_matte_chk = gr.Checkbox(label="Soft matte", value=True) with gr.Column(scale=2): soft_matte_feather = gr.Slider( label="Feather radius (px)", minimum=0.0, maximum=12.0, step=0.5, value=4.0, ) with gr.Column(scale=2): soft_matte_erode = gr.Slider( label="Edge shrink (px)", minimum=0.0, maximum=5.0, step=0.5, value=0.5, ) with gr.Row(): with gr.Column(scale=1): blur_bg_chk = gr.Checkbox(label="Blur background", value=True) with gr.Column(scale=2): blur_radius = gr.Slider( label="Background blur (px)", minimum=0.0, maximum=45.0, step=1.0, value=0.0, ) with gr.Column(scale=2): bg_darkening = gr.Slider( label="Darken background", minimum=0.0, maximum=1.0, step=0.05, value=0.75, info="0 keeps original brightness, 1 turns the background black.", ) with gr.Row(): with gr.Column(scale=1): light_wrap_chk = gr.Checkbox(label="Light wrap", value=False) with gr.Column(scale=2): light_wrap_strength = gr.Slider( label="Wrap strength", 
minimum=0.0, maximum=1.0, step=0.05, value=0.6, ) with gr.Column(scale=2): light_wrap_width = gr.Slider( label="Wrap width (px)", minimum=0.0, maximum=25.0, step=0.5, value=15.0, ) with gr.Row(): with gr.Column(scale=1): glow_chk = gr.Checkbox(label="Neon glow", value=False) with gr.Column(scale=2): glow_strength = gr.Slider( label="Glow strength", minimum=0.0, maximum=1.0, step=0.05, value=0.4, ) with gr.Column(scale=2): glow_radius = gr.Slider( label="Glow radius (px)", minimum=0.0, maximum=35.0, step=0.5, value=10.0, ) # New Ring FX Controls gr.Markdown("### Ring FX Settings") with gr.Row(): with gr.Column(scale=1): ring_thickness = gr.Slider( label="Ring Thickness", minimum=0.1, maximum=2.0, step=0.1, value=1.0, ) with gr.Column(scale=1): ring_alpha = gr.Slider( label="Ring Intensity (Alpha)", minimum=0.1, maximum=3.0, step=0.1, value=3.0, ) with gr.Column(scale=1): ring_feather = gr.Slider( label="Ring Softness (Blur)", minimum=0.0, maximum=5.0, step=0.1, value=0.1, ) with gr.Column(scale=1): ring_gamma = gr.Slider( label="Ring Gamma (Contrast)", minimum=0.1, maximum=2.0, step=0.05, value=2.0, info="Lower values = higher contrast/sharper falloff" ) with gr.Column(scale=1): ring_duration = gr.Slider( label="Rings Duration (frames)", minimum=0, maximum=120, step=1, value=30, info="Limit how many frames after the kick to show rings (approx 0-4s)" ) with gr.Column(scale=1): ring_scale_pct = gr.Slider( label="Ring Size Scale (%)", minimum=10, maximum=200, step=5, value=125, info="Adjust overall ring size relative to detected ball radius." 
def _on_ckpt_change(s: AppState, key: str):
    """Switch the SAM2 checkpoint and stream a loading message while it loads.

    This is a generator handler: the first yielded update shows the
    "Loading checkpoint" text, the second hides it again once
    ``ensure_session_for_current_model`` has returned.

    Args:
        s: Application state; may be ``None``, in which case no state
            fields are touched.
        key: Checkpoint key chosen in the radio (e.g. ``"tiny"``).

    Yields:
        ``gr.update`` objects for the checkpoint progress markdown.
    """
    if s is not None and key:
        key = str(key)
        if key != s.model_repo_key:
            # Update and drop the current model so it is reloaded lazily.
            s.is_switching_model = True
            s.model_repo_key = key
            s.model_repo_id = None
            s.model = None
            s.processor = None
    try:
        # First yield shows the progress text while the model loads.
        yield gr.update(visible=True, value=f"Loading checkpoint: {key}...")
        ensure_session_for_current_model(s)
    finally:
        # Always clear the flag: if loading raises (or the generator is
        # closed early) the state must not stay stuck in "switching" mode.
        if s is not None:
            s.is_switching_model = False
    # Final yield hides the text.
    yield gr.update(visible=False, value="")
def _sync_prompt_type(s: AppState, val: str):
    """Record the chosen prompt type and adapt the point-only controls.

    Stores the prompt type on the state (also discarding any half-drawn
    box) when both arguments are present, then returns two updates: one
    for the point-label radio and one for the "clear old inputs" checkbox.
    """
    if not (s is None or val is None):
        s.current_prompt_type = str(val)
        s.pending_box_start = None

    points_mode = str(val).lower() == "points"

    # The label radio is only meaningful for point prompts.
    label_update = gr.update(visible=points_mode)
    if points_mode:
        clear_old_update = gr.update(interactive=points_mode)
    else:
        # In box mode the checkbox is forced on and locked.
        clear_old_update = gr.update(value=True, interactive=False)

    return [label_update, clear_old_update]
def _impact_slider_outputs(s: AppState):
    """Build the six outputs shared by the impact-speed and goal-distance sliders."""
    propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(s)
    return (
        _build_kick_plot(s),
        _impact_status_update(s),
        gr.update(value=_format_kick_status(s), visible=True),
        propagate_main_update,
        detect_btn_update,
        propagate_player_update,
    )


def _update_min_impact_speed(s: AppState, val: float):
    """Store a new minimum impact speed (km/h) and refresh dependent outputs.

    When both the state and the value are present, the value is written to
    ``s.min_impact_speed_kmh`` and the motion metrics are recomputed.

    Returns:
        Tuple of (kick plot, impact status update, ball status update,
        and the three propagate/detect button updates).
    """
    if s is not None and val is not None:
        s.min_impact_speed_kmh = float(val)
        _recompute_motion_metrics(s)
    return _impact_slider_outputs(s)


def _update_goal_distance(s: AppState, val: float):
    """Store a new distance-to-goal (m) and refresh dependent outputs.

    When both the state and the value are present, the value is written to
    ``s.goal_distance_m`` and the motion metrics are recomputed.

    Returns:
        Same six-element tuple as ``_update_min_impact_speed``.
    """
    if s is not None and val is not None:
        s.goal_distance_m = float(val)
        _recompute_motion_metrics(s)
    return _impact_slider_outputs(s)
= gr.update(choices=[], value=None, visible=False) multi_ball_chart_update = gr.update(value=None) if not candidates: # Fallback to single-ball detection print("[_auto_detect_ball] No candidates from detect_all_balls, trying detect_ball_center...") detection = detect_ball_center(frame) print(f"[_auto_detect_ball] detect_ball_center returned: {detection}") if detection is None: propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in) status_updates = _ui_status_updates(state_in) return ( update_frame_display(state_in, frame_idx), gr.update( value="❌ Unable to auto-detect the ball. Please add a point manually.", visible=True, ), gr.update(value=frame_idx), _build_kick_plot(state_in), propagate_main_update, detect_btn_update, propagate_player_update, *status_updates, multi_ball_col_update, multi_ball_status_update, multi_ball_radio_update, multi_ball_chart_update, ) x_center, y_center, _, _, conf = detection state_in.ball_candidates = [] else: # Store all candidates state_in.ball_candidates = candidates state_in.selected_ball_idx = 0 # Use the best candidate (first one after sorting by confidence) best = candidates[0] x_center, y_center = best["center"] conf = best["conf"] if len(candidates) > 1: state_in.multi_ball_status = f"⚠️ {len(candidates)} balls detected in frame. Click 'Track Ball' to analyze which one is kicked." # Show multi-ball UI with candidate list multi_ball_col_update = gr.update(visible=True) multi_ball_status_update = gr.update( value=f"**{len(candidates)} balls detected!** YOLO found multiple balls in the first frame.\n\n" f"The best candidate (highest confidence) is auto-selected.\n" f"Click **Track Ball** to analyze all candidates and find the one being kicked." 
) # Don't show radio yet - will show after tracking multi_ball_radio_update = gr.update(choices=[], value=None, visible=False) else: state_in.multi_ball_status = "" frame_width, frame_height = frame.size x_center = max(0, min(frame_width - 1, int(x_center))) y_center = max(0, min(frame_height - 1, int(y_center))) obj_id_int = int(obj_id) if obj_id is not None else state_in.current_obj_id label_str = label_value if label_value else state_in.current_label clear_old_flag = bool(clear_old_value) # Build a synthetic click event to reuse existing handler synthetic_evt = SimpleNamespace( index=(x_center, y_center), value={"x": x_center, "y": y_center}, ) state_in.current_frame_idx = frame_idx preview_img = on_image_click( update_frame_display(state_in, frame_idx), state_in, frame_idx, obj_id_int, label_str, clear_old_flag, synthetic_evt, ) state_in.is_ball_detected = True num_candidates = len(getattr(state_in, 'ball_candidates', [])) # Draw YOLO bounding boxes on preview if we have candidates if num_candidates > 0 and isinstance(preview_img, Image.Image): preview_img = draw_yolo_detections_on_frame( preview_img, state_in.ball_candidates, selected_idx=0, ) if num_candidates > 1: status_text = f"⚠️ {num_candidates} balls found! Best at ({x_center}, {y_center}) (conf={conf:.2f}). Click 'Track Ball' to analyze all." 
def _track_ball_yolo(state_in: AppState):
    """Run YOLO ball tracking and build the updates for the YOLO "Track Ball" button.

    If more than one ball candidate is stored on the state, the work is
    delegated to ``_detect_and_track_all_ball_candidates`` and the selected
    candidate is applied to the YOLO state; otherwise
    ``_perform_yolo_ball_tracking`` runs. The preview then jumps to the YOLO
    kick frame, falling back to the initial YOLO frame and then to frame 0.

    Raises:
        gr.Error: If no state or no video frames are loaded.

    Returns:
        Tuple of (preview image, status text update, frame slider update,
        kick plot, YOLO plot, three propagate/detect button updates, the
        status indicator updates, and four multi-ball UI updates).
    """
    if state_in is None or state_in.num_frames == 0:
        raise gr.Error("Load a video first, then track the ball with YOLO.")

    # NOTE(review): the Progress object is created here rather than declared as
    # a default argument; confirm Gradio still reports progress in that case.
    progress = gr.Progress(track_tqdm=False)
    # Reset the flag up front; it is only set again after tracking finished.
    state_in.is_yolo_tracked = False

    # Check if we have multiple ball candidates
    num_candidates = len(getattr(state_in, 'ball_candidates', []))

    # Default multi-ball UI updates
    multi_ball_col_update = gr.update(visible=False)
    multi_ball_status_update = gr.update(value="")
    multi_ball_radio_update = gr.update(choices=[], value=None, visible=False)
    multi_ball_chart_update = gr.update(value=None)

    if num_candidates > 1:
        # Multi-ball mode: track all candidates and show comparison
        _detect_and_track_all_ball_candidates(state_in, progress=progress)

        # Apply the best candidate to YOLO state
        _apply_selected_ball_to_yolo_state(state_in)

        base_msg = state_in.multi_ball_status or state_in.yolo_status or ""

        # Build the multi-ball UI
        # Re-read the list: the tracking helper above may have changed it.
        candidates = state_in.ball_candidates
        if len(candidates) > 1:
            radio_choices = _format_ball_candidates_for_radio(candidates)
            selected_value = radio_choices[state_in.selected_ball_idx] if radio_choices else None

            multi_ball_col_update = gr.update(visible=True)
            multi_ball_status_update = gr.update(
                value=_format_ball_candidates_markdown(candidates, state_in.selected_ball_idx)
            )
            multi_ball_radio_update = gr.update(
                choices=radio_choices,
                value=selected_value,
                visible=True,
            )
            multi_ball_chart_update = gr.update(value=_build_multi_ball_chart(state_in))
    else:
        # Single ball mode: use original tracking
        _perform_yolo_ball_tracking(state_in, progress=progress)
        base_msg = state_in.yolo_status or ""

    # Preview target: kick frame first, then the initial YOLO frame, then 0.
    target_frame = (
        state_in.yolo_kick_frame
        if state_in.yolo_kick_frame is not None
        else state_in.yolo_initial_frame
        if state_in.yolo_initial_frame is not None
        else 0
    )
    if state_in.num_frames:
        # Keep the target inside the valid frame index range.
        target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
    state_in.current_frame_idx = target_frame
    preview_img = update_frame_display(state_in, target_frame)

    # Draw YOLO bounding boxes on preview if we have candidates (after tracking, with kick info)
    candidates = getattr(state_in, 'ball_candidates', [])
    if len(candidates) > 0 and isinstance(preview_img, Image.Image):
        preview_img = draw_yolo_detections_on_frame(
            preview_img,
            candidates,
            selected_idx=state_in.selected_ball_idx,
        )

    kick_msg = _format_kick_status(state_in)
    # Prefix the tracker's own message only when there is one.
    status_text = f"{base_msg} | {kick_msg}" if base_msg else kick_msg
    state_in.is_yolo_tracked = True
    propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
    status_updates = _ui_status_updates(state_in)
    # Element order must match the outputs list wired to track_ball_yolo_btn.click.
    return (
        preview_img,
        gr.update(value=status_text, visible=True),
        gr.update(value=target_frame),
        _build_kick_plot(state_in),
        _build_yolo_plot(state_in),
        propagate_main_update,
        detect_btn_update,
        propagate_player_update,
        *status_updates,
        multi_ball_col_update,
        multi_ball_status_update,
        multi_ball_radio_update,
        multi_ball_chart_update,
    )
def _auto_detect_player(state_in: AppState):
    """Detect the player on the kick frame and seed SAM2 with the detected box.

    Candidate frames are tried in order: ``state_in.kick_frame`` (falling
    back to ``kick_debug_kick_frame`` when it is ``None``), then the YOLO
    kick frame. The first frame on which ``detect_person_box`` returns a
    detection is used. On success the box is registered with the inference
    session as ``PLAYER_OBJECT_ID``, the model is run on that single frame,
    and the resulting masks are stored on the state.

    Raises:
        gr.Error: If no video is loaded, the model session is not ready, or
            no kick frame is known yet.

    Returns:
        Tuple of (preview image, status text update, frame slider update,
        kick plot, three propagate/detect button updates, a no-op update,
        the impact status update, and the status indicator updates).
    """
    if state_in is None or state_in.num_frames == 0:
        raise gr.Error("Load a video first, then try auto-detect.")
    if state_in.inference_session is None or state_in.processor is None or state_in.model is None:
        raise gr.Error("Model session is not ready. Load a video and propagate masks first.")

    state_in.is_player_detected = False

    priority_frames: list[int] = []
    # Explicit None check: frame index 0 is a valid kick frame and must not
    # be treated as "missing" the way a truthiness test would.
    sam_frame = state_in.kick_frame
    if sam_frame is None:
        sam_frame = getattr(state_in, "kick_debug_kick_frame", None)
    if sam_frame is not None:
        priority_frames.append(int(sam_frame))
    yolo_frame = getattr(state_in, "yolo_kick_frame", None)
    if yolo_frame is not None:
        yolo_int = int(yolo_frame)
        if yolo_int not in priority_frames:
            priority_frames.append(yolo_int)
    if not priority_frames:
        raise gr.Error("Detect the kick frame first by propagating the ball masks.")

    detection = None
    used_frame_idx = None
    for candidate in priority_frames:
        frame_idx = int(np.clip(candidate, 0, state_in.num_frames - 1))
        frame = state_in.video_frames[frame_idx]
        detection = detect_person_box(frame)
        if detection is not None:
            used_frame_idx = frame_idx
            break

    # Without a detection, report against the highest-priority frame.
    frame_idx = used_frame_idx if detection is not None else priority_frames[0]
    frame_idx = int(np.clip(frame_idx, 0, state_in.num_frames - 1))
    state_in.current_frame_idx = frame_idx

    def _result(preview_img, status_text):
        # Element order must match the outputs wired to detect_player_btn.click.
        propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state_in)
        status_updates = _ui_status_updates(state_in)
        return (
            preview_img,
            gr.update(value=status_text, visible=True),
            gr.update(value=frame_idx),
            _build_kick_plot(state_in),
            propagate_main_update,
            detect_btn_update,
            propagate_player_update,
            gr.update(),
            _impact_status_update(state_in),
            *status_updates,
        )

    if detection is None:
        state_in.is_player_detected = False
        status_text = (
            f"{_format_kick_status(state_in)} | ⚠️ Unable to auto-detect the player on frame {frame_idx}. "
            "Please add a box manually."
        )
        return _result(update_frame_display(state_in, frame_idx), status_text)

    x_min, y_min, x_max, y_max, conf = detection
    state_in.player_obj_id = PLAYER_OBJECT_ID
    state_in.player_detection_frame = frame_idx
    state_in.player_detection_conf = conf
    state_in.current_obj_id = PLAYER_OBJECT_ID
    state_in.is_player_detected = True

    # Clear previous player-specific prompts/masks
    for frame_boxes in state_in.boxes_by_frame_obj.values():
        frame_boxes.pop(PLAYER_OBJECT_ID, None)
    for frame_clicks in state_in.clicks_by_frame_obj.values():
        frame_clicks.pop(PLAYER_OBJECT_ID, None)
    for frame_masks in state_in.masks_by_frame.values():
        frame_masks.pop(PLAYER_OBJECT_ID, None)

    _ensure_color_for_obj(state_in, PLAYER_OBJECT_ID)

    processor = state_in.processor
    model = state_in.model
    inference_session = state_in.inference_session

    # `frame` is the image of the frame on which the detection succeeded.
    inputs = processor(images=frame, device=state_in.device, return_tensors="pt")
    original_size = inputs.original_sizes[0]
    pixel_values = inputs.pixel_values[0]

    processor.add_inputs_to_inference_session(
        inference_session=inference_session,
        frame_idx=frame_idx,
        obj_ids=PLAYER_OBJECT_ID,
        input_boxes=[[[x_min, y_min, x_max, y_max]]],
        clear_old_inputs=True,
        original_size=original_size,
    )

    frame_boxes = state_in.boxes_by_frame_obj.setdefault(frame_idx, {})
    frame_boxes[PLAYER_OBJECT_ID] = [(x_min, y_min, x_max, y_max)]
    # Drop the cached composite so the preview is redrawn with the new box.
    state_in.composited_frames.pop(frame_idx, None)

    with torch.inference_mode():
        outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)

    H = inference_session.video_height
    W = inference_session.video_width
    pred_masks = outputs.pred_masks.detach().cpu()
    video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]

    # Work on a copy so masks of other objects on this frame are preserved.
    masks_for_frame = state_in.masks_by_frame.get(frame_idx, {}).copy()
    obj_ids_order = list(inference_session.obj_ids)
    for i, oid in enumerate(obj_ids_order):
        mask_i = video_res_masks[i].cpu().numpy().squeeze()
        masks_for_frame[int(oid)] = mask_i
    state_in.masks_by_frame[frame_idx] = masks_for_frame
    _update_centroids_for_frame(state_in, frame_idx)

    state_in.composited_frames.pop(frame_idx, None)
    state_in.current_frame_idx = frame_idx

    status_text = (
        f"{_format_kick_status(state_in)} | ✅ Player auto-detected on frame {frame_idx} (conf={conf:.2f})"
    )
    return _result(update_frame_display(state_in, frame_idx), status_text)
inference_session = deepcopy(GLOBAL_STATE.inference_session) inference_session.inference_device = "cuda" inference_session.cache.inference_device = "cuda" model.to("cuda") if not GLOBAL_STATE.sam_window: _compute_sam_window_from_kick( GLOBAL_STATE, _get_prioritized_kick_frame(GLOBAL_STATE), ) start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames) start_idx = max(0, int(start_idx)) end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx))) total = max(1, end_idx - start_idx) processed = 0 last_frame_idx = start_idx GLOBAL_STATE.is_player_propagated = False propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) status_updates = _ui_status_updates(GLOBAL_STATE) yield ( GLOBAL_STATE, f"Propagating player: {processed}/{total}", gr.update(), _build_kick_plot(GLOBAL_STATE), _build_yolo_plot(GLOBAL_STATE), _impact_status_update(GLOBAL_STATE), gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), propagate_main_update, detect_btn_update, propagate_player_update, *status_updates, ) player_id = GLOBAL_STATE.player_obj_id or PLAYER_OBJECT_ID with torch.inference_mode(): for frame_idx in range(start_idx, end_idx): frame = GLOBAL_STATE.video_frames[frame_idx] pixel_values = None if ( inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames ): pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0] sam2_video_output = model( inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx ) H = inference_session.video_height W = inference_session.video_width pred_masks = sam2_video_output.pred_masks.detach().cpu() video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0] masks_for_frame = GLOBAL_STATE.masks_by_frame.get(frame_idx, {}).copy() obj_ids_order = list(inference_session.obj_ids) for i, oid in enumerate(obj_ids_order): mask_2d = 
video_res_masks[i].cpu().numpy().squeeze() if int(oid) == int(player_id): masks_for_frame[int(player_id)] = mask_2d GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame _update_centroids_for_frame(GLOBAL_STATE, frame_idx) GLOBAL_STATE.composited_frames.pop(frame_idx, None) processed += 1 last_frame_idx = frame_idx if processed % 30 == 0 or processed == total: propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) status_updates = _ui_status_updates(GLOBAL_STATE) yield ( GLOBAL_STATE, f"Propagating player: {processed}/{total}", gr.update(value=frame_idx), _build_kick_plot(GLOBAL_STATE), _build_yolo_plot(GLOBAL_STATE), _impact_status_update(GLOBAL_STATE), gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), propagate_main_update, detect_btn_update, propagate_player_update, *status_updates, ) text = f"Propagated player across {processed} frames." target_frame = GLOBAL_STATE.player_detection_frame if target_frame is None: target_frame = _get_prioritized_kick_frame(GLOBAL_STATE) if target_frame is None: target_frame = last_frame_idx target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1))) GLOBAL_STATE.current_frame_idx = target_frame GLOBAL_STATE.is_player_propagated = True propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE) status_updates = _ui_status_updates(GLOBAL_STATE) yield ( GLOBAL_STATE, text, gr.update(value=target_frame), _build_kick_plot(GLOBAL_STATE), _build_yolo_plot(GLOBAL_STATE), _impact_status_update(GLOBAL_STATE), gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True), propagate_main_update, detect_btn_update, propagate_player_update, *status_updates, ) propagate_player_btn.click( propagate_player_masks, inputs=[GLOBAL_STATE], outputs=[ GLOBAL_STATE, propagate_status, frame_slider, kick_plot, yolo_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn, yolo_kick_btn, 
yolo_impact_btn, sam_kick_btn, sam_impact_btn, manual_kick_btn, manual_impact_btn, detect_ball_btn, track_ball_yolo_btn, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status, ], ) # Image click to add a point and run forward on that frame preview.select( _on_image_click_with_updates, [preview, GLOBAL_STATE, frame_slider, obj_id_inp, label_radio, clear_old_chk], [ preview, propagate_btn, detect_player_btn, propagate_player_btn, yolo_kick_btn, yolo_impact_btn, sam_kick_btn, sam_impact_btn, manual_kick_btn, manual_impact_btn, detect_ball_btn, track_ball_yolo_btn, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status, ], ) goal_start_btn.click( _goal_start_mapping, inputs=[GLOBAL_STATE], outputs=[preview, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status], ) goal_confirm_btn.click( _goal_confirm_mapping, inputs=[GLOBAL_STATE], outputs=[preview, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status], ) goal_clear_btn.click( _goal_clear_mapping, inputs=[GLOBAL_STATE], outputs=[preview, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status], ) goal_back_btn.click( _goal_back_mapping, inputs=[GLOBAL_STATE], outputs=[preview, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status], ) # Playback via MP4 rendering only # Render a smooth MP4 using imageio/pyav (fallbacks to imageio v2 / OpenCV) def _render_video(s: AppState, remove_bg: bool = False, log_fn=None): if s is None or s.num_frames == 0: raise gr.Error("Load a video first.") fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12 trim_duration_sec = 4.0 target_window_frames = max(1, int(round(fps * trim_duration_sec))) half_window = target_window_frames // 2 kick_frame = s.kick_frame or getattr(s, "kick_debug_kick_frame", None) start_idx = 0 end_idx = min(s.num_frames, target_window_frames) if kick_frame is not None: start_idx = max(0, int(kick_frame) - half_window) 
end_idx = start_idx + target_window_frames if end_idx > s.num_frames: end_idx = s.num_frames start_idx = max(0, end_idx - target_window_frames) else: end_idx = min(s.num_frames, start_idx + target_window_frames) if end_idx <= start_idx: end_idx = min(s.num_frames, start_idx + 1) # Compose all frames in trimmed window frames_np = [] first = compose_frame(s, start_idx, remove_bg=remove_bg) h, w = first.size[1], first.size[0] total_frames = max(1, end_idx - start_idx) for idx in range(start_idx, end_idx): # Don't use cache when remove_bg changes behavior if remove_bg: img = compose_frame(s, idx, remove_bg=True) else: img = s.composited_frames.get(idx) if img is None: img = compose_frame(s, idx, remove_bg=False) img_with_idx = _annotate_frame_index(img, idx) frames_np.append(np.array(img_with_idx)[:, :, ::-1]) # BGR for cv2 # Periodically release CPU mem to reduce pressure if (idx + 1) % 60 == 0: gc.collect() processed = idx - start_idx + 1 if log_fn and (processed % 20 == 0 or processed == total_frames): log_fn(f"Rendering frames {processed}/{total_frames}") import tempfile out_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") out_path = out_file.name out_file.close() def _write_with_opencv(): fourcc = cv2.VideoWriter_fourcc(*"mp4v") writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h)) if not writer.isOpened(): writer.release() raise RuntimeError("OpenCV VideoWriter failed to open (missing codec?).") for fr_bgr in frames_np: writer.write(fr_bgr) writer.release() def _write_with_imageio(): import imageio with imageio.get_writer(out_path, fps=fps, codec="libx264", mode="I", quality=8) as writer: for fr_bgr in frames_np: writer.append_data(fr_bgr[:, :, ::-1]) # convert back to RGB try: _write_with_opencv() except Exception as cv_err: print(f"OpenCV writer failed: {cv_err}") try: if log_fn: log_fn("OpenCV writer unavailable, falling back to imageio/pyav.") _write_with_imageio() except Exception as io_err: print(f"Failed to render video: {io_err}") raise 
gr.Error(f"Failed to render video: {io_err}") return out_path render_btn.click(_render_video, inputs=[GLOBAL_STATE, remove_bg_checkbox], outputs=[playback_video]) # While propagating, we stream two outputs: status text and slider value updates propagate_btn.click( propagate_masks, inputs=[GLOBAL_STATE], outputs=[ GLOBAL_STATE, propagate_status, frame_slider, kick_plot, yolo_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn, yolo_kick_btn, yolo_impact_btn, sam_kick_btn, sam_impact_btn, manual_kick_btn, manual_impact_btn, detect_ball_btn, track_ball_yolo_btn, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status, ], ) reset_btn.click( reset_session, inputs=GLOBAL_STATE, outputs=[ GLOBAL_STATE, preview, frame_slider, frame_slider, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn, yolo_kick_btn, yolo_impact_btn, sam_kick_btn, sam_impact_btn, manual_kick_btn, manual_impact_btn, detect_ball_btn, track_ball_yolo_btn, goal_start_btn, goal_confirm_btn, goal_clear_btn, goal_back_btn, goal_status, ], ) # ============================================================================ # COMBINED INTERFACE WITH EXPLICIT API ENDPOINT # ============================================================================ # Create API interface with explicit endpoint api_interface = gr.Interface( fn=process_video_api, inputs=[ gr.Video(label="Video File"), gr.Textbox( label="Annotations JSON (optional)", placeholder='{"annotations": [{"object_id": 1, "frame": 139, "x": 369, "y": 652, "label": "positive"}]}', lines=5 ), gr.Radio( choices=["tiny", "small", "base_plus", "large"], value="base_plus", label="SAM2 Checkpoint" ), gr.Checkbox(label="Remove Background", value=True) ], outputs=[ gr.Image(label="Annotation Preview / First Frame"), gr.Video(label="Processed Video"), gr.Textbox(label="Processing Log", lines=12) ], title="SAM2 API", description=""" ## 
Programmatic KickTrimmer Pipeline Submitting a video here runs the same automated workflow as the Interactive UI: 1. **Upload** the raw MP4. 2. `YOLO13` **detects** and **tracks** the ball to get the first kick estimate. 3. `SAM2` **tracks the ball** around that kick window. If SAM2's kick disagrees with YOLO's, it automatically re-tracks **the entire clip** for better accuracy. 4. `YOLO13` **detects the player** on the SAM2 kick frame, then `SAM2` propagates the player masks around that window. 5. The Space **renders a default cutout video** and returns it together with the processing log below. ### Optional annotations You can still send helper points via JSON: ```json { "annotations": [ {"object_id": 1, "frame": 0, "x": 363, "y": 631, "label": "positive"}, {"object_id": 1, "frame": 187, "x": 296, "y": 485, "label": "positive"}, {"object_id": 2, "frame": 187, "x": 296, "y": 412, "label": "positive"} ] } ``` - **Object 1** = ball, **Object 2** = player. Use timestamps when possible; the API will reconcile timestamps and frame indices for you. 
""" ) # Use gr.Blocks to combine both with proper API exposure with gr.Blocks(title="SAM2 Video Tracking") as combined_demo: gr.Markdown("# SAM2 Video Tracking") with gr.Tabs(): with gr.TabItem("Interactive UI"): demo.render() with gr.TabItem("API"): api_interface.render() # Explicitly expose the API function at root level for external API calls # This creates the /api/predict endpoint api_video_input_hidden = gr.Video(visible=False) api_annotations_input_hidden = gr.Textbox(visible=False) api_checkpoint_input_hidden = gr.Radio(choices=["tiny", "small", "base_plus", "large"], visible=False) api_remove_bg_input_hidden = gr.Checkbox(visible=False) api_preview_output_hidden = gr.Image(visible=False) api_video_output_hidden = gr.Video(visible=False) api_logs_output_hidden = gr.Textbox(visible=False) # This dummy component creates the external API endpoint api_dummy_btn = gr.Button("API", visible=False) api_dummy_btn.click( fn=process_video_api, inputs=[api_video_input_hidden, api_annotations_input_hidden, api_checkpoint_input_hidden, api_remove_bg_input_hidden], outputs=[api_preview_output_hidden, api_video_output_hidden, api_logs_output_hidden], api_name="predict" # This creates /api/predict for external calls ) # Launch with API enabled if __name__ == "__main__": combined_demo.queue(api_open=True).launch()