Mirko Trasciatti committed
Commit 996c8dd · Parent: b230fd1
Sync API tab with YOLO→SAM2 workflow

app.py CHANGED
@@ -3230,123 +3230,182 @@ def process_video_api(
     video_file,
     annotations_json_str: str,
     checkpoint: str = "base_plus",
-    remove_background: bool = True
+    remove_background: bool = True,
 ):
     """
     Single-endpoint API for programmatic video processing.
-
+
     Args:
         video_file: Uploaded video file
-        annotations_json_str: JSON string
-            {
-                "annotations": [
-                    {"object_id": 1, "frame": 139, "x": 369, "y": 652, "label": "positive"},
-                    {"object_id": 1, "frame": 156, "x": 374, "y": 513, "label": "positive"},
-                    {"object_id": 2, "frame": 156, "x": 374, "y": 257, "label": "positive"}
-                ]
-            }
+        annotations_json_str: Optional JSON string containing helper annotations
         checkpoint: SAM2 model checkpoint (tiny, small, base_plus, large)
-        remove_background: Whether to remove background
-
+        remove_background: Whether to remove the background in the render
+
     Returns:
-        Tuple of (preview_image, processed_video_path)
+        Tuple of (preview_image, processed_video_path, progress_log)
     """
     import json

     try:
-
-        annotations_data = json.loads(annotations_json_str)
+        log_entries: list[str] = []
+
+        def log_msg(message: str):
+            text = f"[API] {message}"
+            print(text)
+            log_entries.append(text)
+
+        # Parse annotations (optional)
+        annotations_payload = annotations_json_str or ""
+        annotations_data = json.loads(annotations_payload) if annotations_payload.strip() else {}
         annotations = annotations_data.get("annotations", [])
-        client_fps = annotations_data.get("fps", None)
-        [... 6 deleted lines not recoverable from the extract ...]
-        # Create preview of annotation points
-        preview_img = create_annotation_preview(video_file, annotations)
-
+        client_fps = annotations_data.get("fps", None)
+
+        log_msg(f"Received {len(annotations)} annotations")
+        log_msg(f"Checkpoint: {checkpoint} | Remove background: {remove_background}")
+
+        preview_img = create_annotation_preview(video_file, annotations) if annotations else None
+
         # Create a temporary state for this API call
         api_state = AppState()
         api_state.model_repo_key = checkpoint

         # Step 1: Initialize session with video
+        log_msg("Loading video...")
         api_state, min_idx, max_idx, first_frame, status = init_video_session(api_state, video_file)
         space_fps = api_state.video_fps
-
-
+        log_msg(status)
+        log_msg(f"Client FPS={client_fps} | Space FPS={space_fps}")

         # If FPS mismatch, warn about potential frame offset
         if client_fps and space_fps and abs(client_fps - space_fps) > 0.5:
             offset_estimate = abs(int((client_fps - space_fps) * (api_state.num_frames / client_fps)))
-
-
+            log_msg(f"⚠️ FPS mismatch detected. Frame indices may be off by ~{offset_estimate} frames.")
+            log_msg("ℹ️ Recommendation: Use timestamps instead of frame indices for accuracy.")

         # Step 2: Apply each annotation
-        [... ~45 deleted lines of the old annotation-application block, not recoverable from the extract ...]
+        if annotations:
+            for i, ann in enumerate(annotations):
+                object_id = ann.get("object_id", 1)
+                timestamp_ms = ann.get("timestamp_ms", None)
+                frame_idx = ann.get("frame", None)
+                x = ann.get("x", 0)
+                y = ann.get("y", 0)
+                label = ann.get("label", "positive")
+
+                # Calculate frame from timestamp using Space's FPS (more accurate)
+                if timestamp_ms is not None and space_fps and space_fps > 0:
+                    calculated_frame = int((timestamp_ms / 1000.0) * space_fps)
+                    if frame_idx is not None and calculated_frame != frame_idx:
+                        log_msg(f"Annotation {i+1}: using timestamp {timestamp_ms}ms → Frame {calculated_frame} (client sent {frame_idx})")
+                    else:
+                        log_msg(f"Annotation {i+1}: timestamp {timestamp_ms}ms → Frame {calculated_frame}")
+                    frame_idx = calculated_frame
+                elif frame_idx is None:
+                    log_msg(f"Annotation {i+1}: ⚠️ No timestamp/frame provided, defaulting to frame 0")
+                    frame_idx = 0
+
+                log_msg(f"Adding annotation {i+1}/{len(annotations)} | Obj {object_id} | Frame {frame_idx}")
+
+                # Sync state
+                api_state.current_frame_idx = int(frame_idx)
+                api_state.current_obj_id = int(object_id)
+                api_state.current_label = str(label)
+
+                # Create a mock event with coordinates
+                class MockEvent:
+                    def __init__(self, x, y):
+                        self.index = (x, y)
+
+                mock_evt = MockEvent(x, y)
+
+                # Add the point annotation
+                preview_img = on_image_click(
+                    first_frame,
+                    api_state,
+                    frame_idx,
+                    object_id,
+                    label,
+                    clear_old=False,
+                    evt=mock_evt
+                )

-
-
-        # We need to consume the generator
-        for outputs in propagate_masks(api_state):
-            if not outputs:
-                continue
-            api_state = outputs[0]
-            status_msg = outputs[1] if len(outputs) > 1 else ""
-            if status_msg:
-                print(f"[API] Progress: {status_msg}")
+        if preview_img is None:
+            preview_img = first_frame

-        #
-
+        # Helper to consume generator-based steps and capture log messages
+        def _run_generator(gen, label: str):
+            final_state = None
+            for outputs in gen:
+                if not outputs:
+                    continue
+                final_state = outputs[0]
+                status_msg = outputs[1] if len(outputs) > 1 else ""
+                if status_msg:
+                    log_msg(f"{label}: {status_msg}")
+            if final_state is not None:
+                return final_state
+            raise gr.Error(f"{label} did not produce any output.")
+
+        # Step 3: YOLO13 detect ball
+        api_state.current_obj_id = BALL_OBJECT_ID
+        api_state.current_label = "positive"
+        log_msg("YOLO13 · Detect ball (single-frame search)")
+        _auto_detect_ball(api_state, BALL_OBJECT_ID, "positive", False)
+        if not api_state.is_ball_detected:
+            raise gr.Error("YOLO13 could not detect the ball automatically.")
+
+        # Step 4: YOLO13 track ball
+        log_msg("YOLO13 · Track ball across clip")
+        _track_ball_yolo(api_state)
+        if not api_state.is_yolo_tracked:
+            raise gr.Error("YOLO13 tracking failed.")
+
+        # Step 5: SAM2 track ball around kick window
+        log_msg("SAM2 · Track ball around kick window")
+        api_state = _run_generator(propagate_masks(api_state), "SAM2 · Ball")
+        sam_kick = _get_prioritized_kick_frame(api_state)
+        yolo_kick = api_state.yolo_kick_frame
+        if sam_kick is not None:
+            log_msg(f"SAM2 kick frame ≈ {sam_kick}")
+        if yolo_kick is not None:
+            log_msg(f"YOLO kick frame ≈ {yolo_kick}")
+
+        # Fallback: re-run SAM2 on entire video if kicks disagree
+        if (
+            yolo_kick is not None
+            and sam_kick is not None
+            and int(yolo_kick) != int(sam_kick)
+        ):
+            log_msg("Kick disagreement detected → re-running SAM2 across entire video.")
+            api_state.sam_window = (0, api_state.num_frames)
+            api_state = _run_generator(propagate_masks(api_state), "SAM2 · Full sweep")
+            sam_kick = _get_prioritized_kick_frame(api_state)
+            log_msg(f"SAM2 full sweep kick frame ≈ {sam_kick}")
+        else:
+            log_msg("Kick frames aligned. No full sweep required.")
+
+        # Step 6: YOLO detect player on SAM2 kick frame
+        log_msg("YOLO13 · Detect player on SAM2 kick frame")
+        _auto_detect_player(api_state)
+        if api_state.is_player_detected:
+            log_msg("YOLO13 · Player detected successfully.")
+        else:
+            log_msg("YOLO13 · Player detection failed; continuing without player propagation.")
+
+        # Step 7: SAM2 track player if detection succeeded
+        if api_state.is_player_detected:
+            log_msg("SAM2 · Track player around kick window")
+            try:
+                api_state = _run_generator(propagate_player_masks(api_state), "SAM2 · Player")
+            except gr.Error as player_error:
+                log_msg(f"SAM2 player propagation warning: {player_error}")
+
+        # Step 8: Render the final video
+        log_msg(f"Rendering video (remove_background={remove_background})")
         result_video_path = _render_video(api_state, remove_background)

-
-        return preview_img, result_video_path
+        log_msg("Processing complete 🎉")
+        return preview_img, result_video_path, "\n".join(log_entries)

     except Exception as e:
         print(f"[API] ❌ Error: {str(e)}")
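For orientation (not part of the commit): a minimal client-side sketch of calling the endpoint this function backs, assuming a placeholder Space URL and the stock `gradio_client` package. The payload shape mirrors the docstring above; after this commit the endpoint returns three values.

```python
# Illustrative sketch only — the Space URL and file name are placeholders.
import json
from gradio_client import Client, handle_file

client = Client("https://example-kicktrimmer.hf.space")  # hypothetical Space

payload = {
    "fps": 30.0,  # client-side FPS so the Space can flag rate mismatches
    "annotations": [
        # timestamp_ms is preferred; the API reconciles it with "frame"
        {"object_id": 1, "timestamp_ms": 4633, "x": 369, "y": 652, "label": "positive"}
    ],
}

preview, video_path, log = client.predict(
    handle_file("kick.mp4"),  # video_file
    json.dumps(payload),      # annotations_json_str (may be empty)
    "base_plus",              # checkpoint
    True,                     # remove_background
    api_name="/predict",
)
print(log)
```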
@@ -4797,7 +4856,7 @@ api_interface = gr.Interface(
     inputs=[
         gr.Video(label="Video File"),
         gr.Textbox(
-            label="Annotations JSON",
+            label="Annotations JSON (optional)",
            placeholder='{"annotations": [{"object_id": 1, "frame": 139, "x": 369, "y": 652, "label": "positive"}]}',
             lines=5
         ),
@@ -4809,16 +4868,24 @@ api_interface = gr.Interface(
         gr.Checkbox(label="Remove Background", value=True)
     ],
     outputs=[
-        gr.Image(label="Annotation Preview"),
-        gr.Video(label="Processed Video")
+        gr.Image(label="Annotation Preview / First Frame"),
+        gr.Video(label="Processed Video"),
+        gr.Textbox(label="Processing Log", lines=12)
     ],
     title="SAM2 API",
     description="""
-    ## Programmatic
+    ## Programmatic KickTrimmer Pipeline

-
+    Submitting a video here runs the same automated workflow as the Interactive UI:

-    **
+    1. **Upload** the raw MP4.
+    2. `YOLO13` **detects** and **tracks** the ball to get the first kick estimate.
+    3. `SAM2` **tracks the ball** around that kick window. If SAM2's kick disagrees with YOLO's, it automatically re-tracks **the entire clip** for better accuracy.
+    4. `YOLO13` **detects the player** on the SAM2 kick frame, then `SAM2` propagates the player masks around that window.
+    5. The Space **renders a default cutout video** and returns it together with the processing log below.
+
+    ### Optional annotations
+    You can still send helper points via JSON:
     ```json
     {
       "annotations": [
@@ -4828,10 +4895,7 @@ api_interface = gr.Interface(
       ]
     }
    ```
-
-    - **Object 1** (Ball): Frame 0 + Impact frame
-    - **Object 2** (Player): Impact frame
-    - Colors represent different objects
+    - **Object 1** = ball, **Object 2** = player. Use timestamps when possible; the API will reconcile timestamps and frame indices for you.
     """
 )

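The reconciliation promised in that bullet is the FPS conversion added in the first hunk; a quick worked sketch with invented numbers:

```python
# Worked example (invented numbers): why timestamps beat frame indices.
timestamp_ms = 4633   # annotation placed at ~4.63 s (frame 139 at 30 fps)
space_fps = 25.0      # but the Space decoded the clip at 25 fps
space_frame = int((timestamp_ms / 1000.0) * space_fps)
print(space_frame)    # -> 115: the frame actually used, not the client's 139
```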
@@ -4854,13 +4918,14 @@ with gr.Blocks(title="SAM2 Video Tracking") as combined_demo:
     api_remove_bg_input_hidden = gr.Checkbox(visible=False)
     api_preview_output_hidden = gr.Image(visible=False)
     api_video_output_hidden = gr.Video(visible=False)
+    api_logs_output_hidden = gr.Textbox(visible=False)

     # This dummy component creates the external API endpoint
     api_dummy_btn = gr.Button("API", visible=False)
     api_dummy_btn.click(
         fn=process_video_api,
         inputs=[api_video_input_hidden, api_annotations_input_hidden, api_checkpoint_input_hidden, api_remove_bg_input_hidden],
-        outputs=[api_preview_output_hidden, api_video_output_hidden],
+        outputs=[api_preview_output_hidden, api_video_output_hidden, api_logs_output_hidden],
         api_name="predict"  # This creates /api/predict for external calls
     )

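To sanity-check this wiring, `gradio_client` can dump the endpoint schema; after this commit `/predict` should report four inputs and three outputs (URL again a placeholder):

```python
from gradio_client import Client

# view_api() prints the named endpoints with their input/output signatures.
Client("https://example-kicktrimmer.hf.space").view_api()
```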