Mirko Trasciatti
commited on
Commit
·
7e8596f
1
Parent(s):
a34bbd4
Refine ghost trail to be anticipatory
Browse files
app.py
CHANGED
|
@@ -893,7 +893,8 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
|
|
| 893 |
frame = state.video_frames[frame_idx]
|
| 894 |
masks = state.masks_by_frame.get(frame_idx, {})
|
| 895 |
out_img = frame
|
| 896 |
-
|
|
|
|
| 897 |
if len(masks) != 0:
|
| 898 |
if remove_bg:
|
| 899 |
# Remove background - show only tracked objects
|
|
@@ -910,20 +911,19 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
|
|
| 910 |
out_img = Image.fromarray(result_np)
|
| 911 |
else:
|
| 912 |
overlay_masks = masks
|
| 913 |
-
if
|
| 914 |
overlay_masks = {oid: mask for oid, mask in masks.items() if oid != BALL_OBJECT_ID}
|
| 915 |
if overlay_masks:
|
| 916 |
out_img = overlay_masks_on_frame(out_img, overlay_masks, state.color_by_obj, alpha=0.65)
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
|
| 927 |
|
| 928 |
# Draw crosses for conditioning frames only (frames with recorded clicks)
|
| 929 |
clicks_map = state.clicks_by_frame_obj.get(frame_idx)
|
|
@@ -1064,11 +1064,13 @@ def _build_ball_trail_mask(state: AppState, frame_idx: int) -> np.ndarray | None
|
|
| 1064 |
):
|
| 1065 |
return None
|
| 1066 |
kick_candidate = state.kick_frame if state.kick_frame is not None else state.kick_debug_kick_frame
|
| 1067 |
-
if kick_candidate is None
|
| 1068 |
return None
|
| 1069 |
|
| 1070 |
-
start_idx = max(int(kick_candidate) + 1,
|
| 1071 |
-
end_idx =
|
|
|
|
|
|
|
| 1072 |
trail_mask: np.ndarray | None = None
|
| 1073 |
|
| 1074 |
for idx in range(start_idx, end_idx):
|
|
|
|
| 893 |
frame = state.video_frames[frame_idx]
|
| 894 |
masks = state.masks_by_frame.get(frame_idx, {})
|
| 895 |
out_img = frame
|
| 896 |
+
ghost_mask = _build_ball_trail_mask(state, frame_idx)
|
| 897 |
+
|
| 898 |
if len(masks) != 0:
|
| 899 |
if remove_bg:
|
| 900 |
# Remove background - show only tracked objects
|
|
|
|
| 911 |
out_img = Image.fromarray(result_np)
|
| 912 |
else:
|
| 913 |
overlay_masks = masks
|
| 914 |
+
if ghost_mask is not None:
|
| 915 |
overlay_masks = {oid: mask for oid, mask in masks.items() if oid != BALL_OBJECT_ID}
|
| 916 |
if overlay_masks:
|
| 917 |
out_img = overlay_masks_on_frame(out_img, overlay_masks, state.color_by_obj, alpha=0.65)
|
| 918 |
+
|
| 919 |
+
if ghost_mask is not None:
|
| 920 |
+
base_np = np.array(out_img).astype(np.float32) / 255.0
|
| 921 |
+
mask_np = np.clip(ghost_mask.astype(np.float32), 0.0, 1.0)
|
| 922 |
+
if mask_np.ndim == 3:
|
| 923 |
+
mask_np = mask_np.squeeze()
|
| 924 |
+
mask_np = mask_np[..., None]
|
| 925 |
+
base_np = (1.0 - GHOST_TRAIL_ALPHA * mask_np) * base_np + (GHOST_TRAIL_ALPHA * mask_np) * GHOST_TRAIL_COLOR
|
| 926 |
+
out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
|
|
|
|
| 927 |
|
| 928 |
# Draw crosses for conditioning frames only (frames with recorded clicks)
|
| 929 |
clicks_map = state.clicks_by_frame_obj.get(frame_idx)
|
|
|
|
| 1064 |
):
|
| 1065 |
return None
|
| 1066 |
kick_candidate = state.kick_frame if state.kick_frame is not None else state.kick_debug_kick_frame
|
| 1067 |
+
if kick_candidate is None:
|
| 1068 |
return None
|
| 1069 |
|
| 1070 |
+
start_idx = max(int(kick_candidate) + 1, int(frame_idx) + 1)
|
| 1071 |
+
end_idx = state.num_frames
|
| 1072 |
+
if start_idx >= end_idx:
|
| 1073 |
+
return None
|
| 1074 |
trail_mask: np.ndarray | None = None
|
| 1075 |
|
| 1076 |
for idx in range(start_idx, end_idx):
|