Mirko Trasciatti committed
Commit 996c8dd · Parent: b230fd1

Sync API tab with YOLO→SAM2 workflow

Files changed (1):
  1. app.py +164 -99
app.py CHANGED
@@ -3230,123 +3230,182 @@ def process_video_api(
     video_file,
     annotations_json_str: str,
     checkpoint: str = "base_plus",
-    remove_background: bool = True
+    remove_background: bool = True,
 ):
     """
     Single-endpoint API for programmatic video processing.
-
+
     Args:
         video_file: Uploaded video file
-        annotations_json_str: JSON string with format:
-            {
-                "annotations": [
-                    {"object_id": 1, "frame": 139, "x": 369, "y": 652, "label": "positive"},
-                    {"object_id": 1, "frame": 156, "x": 374, "y": 513, "label": "positive"},
-                    {"object_id": 2, "frame": 156, "x": 374, "y": 257, "label": "positive"}
-                ]
-            }
+        annotations_json_str: Optional JSON string containing helper annotations
         checkpoint: SAM2 model checkpoint (tiny, small, base_plus, large)
-        remove_background: Whether to remove background (default: True)
-
+        remove_background: Whether to remove the background in the render
+
     Returns:
-        Tuple of (preview_image, processed_video_path)
+        Tuple of (preview_image, processed_video_path, progress_log)
     """
     import json
 
     try:
-        # Parse annotations
-        annotations_data = json.loads(annotations_json_str)
+        log_entries: list[str] = []
+
+        def log_msg(message: str):
+            text = f"[API] {message}"
+            print(text)
+            log_entries.append(text)
+
+        # Parse annotations (optional)
+        annotations_payload = annotations_json_str or ""
+        annotations_data = json.loads(annotations_payload) if annotations_payload.strip() else {}
         annotations = annotations_data.get("annotations", [])
-        client_fps = annotations_data.get("fps", None)  # FPS used by iOS app to calculate frame indices
-
-        print(f"[API] Processing video with {len(annotations)} annotations")
-        print(f"[API] Client FPS: {client_fps}")
-        print(f"[API] Checkpoint: {checkpoint}")
-        print(f"[API] Remove background: {remove_background}")
-
-        # Create preview of annotation points
-        preview_img = create_annotation_preview(video_file, annotations)
-
+        client_fps = annotations_data.get("fps", None)
+
+        log_msg(f"Received {len(annotations)} annotations")
+        log_msg(f"Checkpoint: {checkpoint} | Remove background: {remove_background}")
+
+        preview_img = create_annotation_preview(video_file, annotations) if annotations else None
+
         # Create a temporary state for this API call
         api_state = AppState()
         api_state.model_repo_key = checkpoint
 
         # Step 1: Initialize session with video
+        log_msg("Loading video...")
         api_state, min_idx, max_idx, first_frame, status = init_video_session(api_state, video_file)
         space_fps = api_state.video_fps
-        print(f"[API] Video loaded: {status}")
-        print(f"[API] ⚠️ FPS mismatch check: Client={client_fps}, Space={space_fps}")
+        log_msg(status)
+        log_msg(f"Client FPS={client_fps} | Space FPS={space_fps}")
 
         # If FPS mismatch, warn about potential frame offset
         if client_fps and space_fps and abs(client_fps - space_fps) > 0.5:
             offset_estimate = abs(int((client_fps - space_fps) * (api_state.num_frames / client_fps)))
-            print(f"[API] ⚠️ FPS mismatch detected! Frame indices may be off by ~{offset_estimate} frames")
-            print(f"[API] ℹ️ Recommendation: Use timestamps instead of frame indices for accuracy")
+            log_msg(f"⚠️ FPS mismatch detected. Frame indices may be off by ~{offset_estimate} frames.")
+            log_msg("ℹ️ Recommendation: Use timestamps instead of frame indices for accuracy.")
 
         # Step 2: Apply each annotation
-        for i, ann in enumerate(annotations):
-            object_id = ann.get("object_id", 1)
-            timestamp_ms = ann.get("timestamp_ms", None)
-            frame_idx = ann.get("frame", None)
-            x = ann.get("x", 0)
-            y = ann.get("y", 0)
-            label = ann.get("label", "positive")
-
-            # Calculate frame from timestamp using Space's FPS (more accurate)
-            if timestamp_ms is not None and space_fps and space_fps > 0:
-                calculated_frame = int((timestamp_ms / 1000.0) * space_fps)
-                if frame_idx is not None and calculated_frame != frame_idx:
-                    print(f"[API] Using timestamp: {timestamp_ms}ms → Frame {calculated_frame} (client sent frame {frame_idx})")
-                else:
-                    print(f"[API] ✅ Calculated frame from timestamp: {timestamp_ms}ms → Frame {calculated_frame}")
-                frame_idx = calculated_frame
-            elif frame_idx is None:
-                print(f"[API] ⚠️ Warning: No timestamp or frame provided, using frame 0")
-                frame_idx = 0
-
-            print(f"[API] Adding annotation {i+1}/{len(annotations)}: "
-                  f"Object {object_id}, Frame {frame_idx}, ({x}, {y}), {label}")
-
-            # Sync state
-            api_state.current_frame_idx = int(frame_idx)
-            api_state.current_obj_id = int(object_id)
-            api_state.current_label = str(label)
-
-            # Create a mock event with coordinates
-            class MockEvent:
-                def __init__(self, x, y):
-                    self.index = (x, y)
-
-            mock_evt = MockEvent(x, y)
-
-            # Add the point annotation
-            preview_img = on_image_click(
-                first_frame,
-                api_state,
-                frame_idx,
-                object_id,
-                label,
-                clear_old=False,
-                evt=mock_evt
-            )
+        if annotations:
+            for i, ann in enumerate(annotations):
+                object_id = ann.get("object_id", 1)
+                timestamp_ms = ann.get("timestamp_ms", None)
+                frame_idx = ann.get("frame", None)
+                x = ann.get("x", 0)
+                y = ann.get("y", 0)
+                label = ann.get("label", "positive")
+
+                # Calculate frame from timestamp using Space's FPS (more accurate)
+                if timestamp_ms is not None and space_fps and space_fps > 0:
+                    calculated_frame = int((timestamp_ms / 1000.0) * space_fps)
+                    if frame_idx is not None and calculated_frame != frame_idx:
+                        log_msg(f"Annotation {i+1}: using timestamp {timestamp_ms}ms → Frame {calculated_frame} (client sent {frame_idx})")
+                    else:
+                        log_msg(f"Annotation {i+1}: timestamp {timestamp_ms}ms → Frame {calculated_frame}")
+                    frame_idx = calculated_frame
+                elif frame_idx is None:
+                    log_msg(f"Annotation {i+1}: ⚠️ No timestamp/frame provided, defaulting to frame 0")
+                    frame_idx = 0
+
+                log_msg(f"Adding annotation {i+1}/{len(annotations)} | Obj {object_id} | Frame {frame_idx}")
+
+                # Sync state
+                api_state.current_frame_idx = int(frame_idx)
+                api_state.current_obj_id = int(object_id)
+                api_state.current_label = str(label)
+
+                # Create a mock event with coordinates
+                class MockEvent:
+                    def __init__(self, x, y):
+                        self.index = (x, y)
+
+                mock_evt = MockEvent(x, y)
+
+                # Add the point annotation
+                preview_img = on_image_click(
+                    first_frame,
+                    api_state,
+                    frame_idx,
+                    object_id,
+                    label,
+                    clear_old=False,
+                    evt=mock_evt
+                )
 
-        # Step 3: Propagate masks across all frames
-        print("[API] Propagating masks across video...")
-        # We need to consume the generator
-        for outputs in propagate_masks(api_state):
-            if not outputs:
-                continue
-            api_state = outputs[0]
-            status_msg = outputs[1] if len(outputs) > 1 else ""
-            if status_msg:
-                print(f"[API] Progress: {status_msg}")
-
-        # Step 4: Render the final video
-        print(f"[API] Rendering video with remove_background={remove_background}...")
+        if preview_img is None:
+            preview_img = first_frame
+
+        # Helper to consume generator-based steps and capture log messages
+        def _run_generator(gen, label: str):
+            final_state = None
+            for outputs in gen:
+                if not outputs:
+                    continue
+                final_state = outputs[0]
+                status_msg = outputs[1] if len(outputs) > 1 else ""
+                if status_msg:
+                    log_msg(f"{label}: {status_msg}")
+            if final_state is not None:
+                return final_state
+            raise gr.Error(f"{label} did not produce any output.")
+
+        # Step 3: YOLO13 detect ball
+        api_state.current_obj_id = BALL_OBJECT_ID
+        api_state.current_label = "positive"
+        log_msg("YOLO13 · Detect ball (single-frame search)")
+        _auto_detect_ball(api_state, BALL_OBJECT_ID, "positive", False)
+        if not api_state.is_ball_detected:
+            raise gr.Error("YOLO13 could not detect the ball automatically.")
+
+        # Step 4: YOLO13 track ball
+        log_msg("YOLO13 · Track ball across clip")
+        _track_ball_yolo(api_state)
+        if not api_state.is_yolo_tracked:
+            raise gr.Error("YOLO13 tracking failed.")
+
+        # Step 5: SAM2 track ball around kick window
+        log_msg("SAM2 · Track ball around kick window")
+        api_state = _run_generator(propagate_masks(api_state), "SAM2 · Ball")
+        sam_kick = _get_prioritized_kick_frame(api_state)
+        yolo_kick = api_state.yolo_kick_frame
+        if sam_kick is not None:
+            log_msg(f"SAM2 kick frame ≈ {sam_kick}")
+        if yolo_kick is not None:
+            log_msg(f"YOLO kick frame ≈ {yolo_kick}")
+
+        # Fallback: re-run SAM2 on entire video if kicks disagree
+        if (
+            yolo_kick is not None
+            and sam_kick is not None
+            and int(yolo_kick) != int(sam_kick)
+        ):
+            log_msg("Kick disagreement detected → re-running SAM2 across entire video.")
+            api_state.sam_window = (0, api_state.num_frames)
+            api_state = _run_generator(propagate_masks(api_state), "SAM2 · Full sweep")
+            sam_kick = _get_prioritized_kick_frame(api_state)
+            log_msg(f"SAM2 full sweep kick frame ≈ {sam_kick}")
+        else:
+            log_msg("Kick frames aligned. No full sweep required.")
+
+        # Step 6: YOLO detect player on SAM2 kick frame
+        log_msg("YOLO13 · Detect player on SAM2 kick frame")
+        _auto_detect_player(api_state)
+        if api_state.is_player_detected:
+            log_msg("YOLO13 · Player detected successfully.")
+        else:
+            log_msg("YOLO13 · Player detection failed; continuing without player propagation.")
+
+        # Step 7: SAM2 track player if detection succeeded
+        if api_state.is_player_detected:
+            log_msg("SAM2 · Track player around kick window")
+            try:
+                api_state = _run_generator(propagate_player_masks(api_state), "SAM2 · Player")
+            except gr.Error as player_error:
+                log_msg(f"SAM2 player propagation warning: {player_error}")
+
+        # Step 8: Render the final video
+        log_msg(f"Rendering video (remove_background={remove_background})")
         result_video_path = _render_video(api_state, remove_background)
 
-        print(f"[API] ✅ Processing complete: {result_video_path}")
-        return preview_img, result_video_path
+        log_msg("Processing complete 🎉")
+        return preview_img, result_video_path, "\n".join(log_entries)
 
     except Exception as e:
         print(f"[API] ❌ Error: {str(e)}")
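The reconciliation rule in Step 2 deserves a quick illustration: the handler recomputes every frame index from `timestamp_ms` using the FPS the Space actually decodes at, and the client's `frame` field is only a fallback. A minimal sketch with made-up numbers:

```python
# Made-up numbers: the same instant (5.0 s) maps to different frame
# indices when the client and the Space disagree on FPS, which is why
# process_video_api prefers timestamp_ms over the client's "frame".
timestamp_ms = 5000
client_fps = 30.0   # client annotated at frame 150
space_fps = 25.0    # Space decoded the clip at 25 fps

client_frame = int((timestamp_ms / 1000.0) * client_fps)  # 150
space_frame = int((timestamp_ms / 1000.0) * space_fps)    # 125
assert (client_frame, space_frame) == (150, 125)
```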
@@ -4797,7 +4856,7 @@ api_interface = gr.Interface(
     inputs=[
         gr.Video(label="Video File"),
         gr.Textbox(
-            label="Annotations JSON",
+            label="Annotations JSON (optional)",
            placeholder='{"annotations": [{"object_id": 1, "frame": 139, "x": 369, "y": 652, "label": "positive"}]}',
             lines=5
         ),
@@ -4809,16 +4868,24 @@ api_interface = gr.Interface(
         gr.Checkbox(label="Remove Background", value=True)
     ],
     outputs=[
-        gr.Image(label="Annotation Preview (shows where points are placed)"),
-        gr.Video(label="Processed Video")
+        gr.Image(label="Annotation Preview / First Frame"),
+        gr.Video(label="Processed Video"),
+        gr.Textbox(label="Processing Log", lines=12)
     ],
     title="SAM2 API",
     description="""
-    ## Programmatic API for Video Background Removal
+    ## Programmatic KickTrimmer Pipeline
 
-    **The preview image shows where your annotation points are placed on the video frames.**
+    Submitting a video here runs the same automated workflow as the Interactive UI:
 
-    **Annotations JSON Format:**
+    1. **Upload** the raw MP4.
+    2. `YOLO13` **detects** and **tracks** the ball to get the first kick estimate.
+    3. `SAM2` **tracks the ball** around that kick window. If SAM2's kick disagrees with YOLO's, it automatically re-tracks **the entire clip** for better accuracy.
+    4. `YOLO13` **detects the player** on the SAM2 kick frame, then `SAM2` propagates the player masks around that window.
+    5. The Space **renders a default cutout video** and returns it together with the processing log below.
+
+    ### Optional annotations
+    You can still send helper points via JSON:
     ```json
     {
         "annotations": [
@@ -4828,10 +4895,7 @@ api_interface = gr.Interface(
         ]
     }
     ```
-
-    - **Object 1** (Ball): Frame 0 + Impact frame
-    - **Object 2** (Player): Impact frame
-    - Colors represent different objects
+    - **Object 1** = ball, **Object 2** = player. Use timestamps when possible; the API will reconcile timestamps and frame indices for you.
     """
 )
 
@@ -4854,13 +4918,14 @@ with gr.Blocks(title="SAM2 Video Tracking") as combined_demo:
     api_remove_bg_input_hidden = gr.Checkbox(visible=False)
     api_preview_output_hidden = gr.Image(visible=False)
     api_video_output_hidden = gr.Video(visible=False)
+    api_logs_output_hidden = gr.Textbox(visible=False)
 
     # This dummy component creates the external API endpoint
     api_dummy_btn = gr.Button("API", visible=False)
     api_dummy_btn.click(
         fn=process_video_api,
         inputs=[api_video_input_hidden, api_annotations_input_hidden, api_checkpoint_input_hidden, api_remove_bg_input_hidden],
-        outputs=[api_preview_output_hidden, api_video_output_hidden],
+        outputs=[api_preview_output_hidden, api_video_output_hidden, api_logs_output_hidden],
         api_name="predict"  # This creates /api/predict for external calls
     )
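Since the endpoint now returns three values (preview, video, log), external callers need to unpack all of them. A sketch of a client call, assuming a recent `gradio_client`; the Space id is a placeholder:

```python
from gradio_client import Client, handle_file

client = Client("<user>/<space>")  # placeholder Space id

preview, video_path, log_text = client.predict(
    handle_file("clip.mp4"),  # video_file
    "",                       # annotations_json_str (now optional)
    "base_plus",              # checkpoint
    True,                     # remove_background
    api_name="/predict",
)
print(log_text)
```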
4931