Mirko Trasciatti commited on
Commit
7e8596f
·
1 Parent(s): a34bbd4

Refine ghost trail to be anticipatory

Browse files
Files changed (1) hide show
  1. app.py +17 -15
app.py CHANGED
@@ -893,7 +893,8 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
893
  frame = state.video_frames[frame_idx]
894
  masks = state.masks_by_frame.get(frame_idx, {})
895
  out_img = frame
896
-
 
897
  if len(masks) != 0:
898
  if remove_bg:
899
  # Remove background - show only tracked objects
@@ -910,20 +911,19 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
910
  out_img = Image.fromarray(result_np)
911
  else:
912
  overlay_masks = masks
913
- if state.fx_ghost_trail_enabled:
914
  overlay_masks = {oid: mask for oid, mask in masks.items() if oid != BALL_OBJECT_ID}
915
  if overlay_masks:
916
  out_img = overlay_masks_on_frame(out_img, overlay_masks, state.color_by_obj, alpha=0.65)
917
- if state.fx_ghost_trail_enabled:
918
- trail_mask = _build_ball_trail_mask(state, frame_idx)
919
- if trail_mask is not None:
920
- base_np = np.array(out_img).astype(np.float32) / 255.0
921
- mask_np = np.clip(trail_mask.astype(np.float32), 0.0, 1.0)
922
- mask_np = mask_np[..., None]
923
- base_np = (1.0 - GHOST_TRAIL_ALPHA * mask_np) * base_np + (
924
- GHOST_TRAIL_ALPHA * mask_np
925
- ) * GHOST_TRAIL_COLOR
926
- out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
927
 
928
  # Draw crosses for conditioning frames only (frames with recorded clicks)
929
  clicks_map = state.clicks_by_frame_obj.get(frame_idx)
@@ -1064,11 +1064,13 @@ def _build_ball_trail_mask(state: AppState, frame_idx: int) -> np.ndarray | None
1064
  ):
1065
  return None
1066
  kick_candidate = state.kick_frame if state.kick_frame is not None else state.kick_debug_kick_frame
1067
- if kick_candidate is None or frame_idx <= kick_candidate + 1:
1068
  return None
1069
 
1070
- start_idx = max(int(kick_candidate) + 1, 0)
1071
- end_idx = int(frame_idx)
 
 
1072
  trail_mask: np.ndarray | None = None
1073
 
1074
  for idx in range(start_idx, end_idx):
 
893
  frame = state.video_frames[frame_idx]
894
  masks = state.masks_by_frame.get(frame_idx, {})
895
  out_img = frame
896
+ ghost_mask = _build_ball_trail_mask(state, frame_idx)
897
+
898
  if len(masks) != 0:
899
  if remove_bg:
900
  # Remove background - show only tracked objects
 
911
  out_img = Image.fromarray(result_np)
912
  else:
913
  overlay_masks = masks
914
+ if ghost_mask is not None:
915
  overlay_masks = {oid: mask for oid, mask in masks.items() if oid != BALL_OBJECT_ID}
916
  if overlay_masks:
917
  out_img = overlay_masks_on_frame(out_img, overlay_masks, state.color_by_obj, alpha=0.65)
918
+
919
+ if ghost_mask is not None:
920
+ base_np = np.array(out_img).astype(np.float32) / 255.0
921
+ mask_np = np.clip(ghost_mask.astype(np.float32), 0.0, 1.0)
922
+ if mask_np.ndim == 3:
923
+ mask_np = mask_np.squeeze()
924
+ mask_np = mask_np[..., None]
925
+ base_np = (1.0 - GHOST_TRAIL_ALPHA * mask_np) * base_np + (GHOST_TRAIL_ALPHA * mask_np) * GHOST_TRAIL_COLOR
926
+ out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
 
927
 
928
  # Draw crosses for conditioning frames only (frames with recorded clicks)
929
  clicks_map = state.clicks_by_frame_obj.get(frame_idx)
 
1064
  ):
1065
  return None
1066
  kick_candidate = state.kick_frame if state.kick_frame is not None else state.kick_debug_kick_frame
1067
+ if kick_candidate is None:
1068
  return None
1069
 
1070
+ start_idx = max(int(kick_candidate) + 1, int(frame_idx) + 1)
1071
+ end_idx = state.num_frames
1072
+ if start_idx >= end_idx:
1073
+ return None
1074
  trail_mask: np.ndarray | None = None
1075
 
1076
  for idx in range(start_idx, end_idx):