Mirko Trasciatti
commited on
Commit
·
c12e0a2
1
Parent(s):
853fa4b
Render ball rings as geometric circles
Browse files
app.py
CHANGED
|
@@ -796,6 +796,7 @@ GHOST_TRAIL_COLOR = np.array([1.0, 0.0, 1.0], dtype=np.float32)
|
|
| 796 |
GHOST_TRAIL_ALPHA = 0.55
|
| 797 |
BALL_RING_COLOR = np.array([1.0, 0.0, 1.0], dtype=np.float32)
|
| 798 |
BALL_RING_THICKNESS_PX = 2
|
|
|
|
| 799 |
|
| 800 |
|
| 801 |
def _maybe_upscale_for_display(image: Image.Image) -> Image.Image:
|
|
@@ -897,9 +898,6 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
|
|
| 897 |
frame = state.video_frames[frame_idx]
|
| 898 |
masks = state.masks_by_frame.get(frame_idx, {})
|
| 899 |
ball_mask_raw = masks.get(BALL_OBJECT_ID)
|
| 900 |
-
ball_ring_mask: np.ndarray | None = None
|
| 901 |
-
if state.fx_ball_ring_enabled and ball_mask_raw is not None:
|
| 902 |
-
ball_ring_mask = _mask_to_ring(ball_mask_raw, BALL_RING_THICKNESS_PX)
|
| 903 |
out_img: Image.Image | None = state.composited_frames.get(frame_idx)
|
| 904 |
if out_img is None:
|
| 905 |
out_img = frame
|
|
@@ -922,7 +920,12 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
|
|
| 922 |
focus_mask = np.zeros_like(mask_np, dtype=np.float32)
|
| 923 |
focus_mask = np.maximum(focus_mask, mask_np)
|
| 924 |
|
| 925 |
-
ghost_mask =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 926 |
|
| 927 |
if len(masks) != 0:
|
| 928 |
if remove_bg:
|
|
@@ -958,26 +961,13 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
|
|
| 958 |
base_np = (1.0 - alpha) * base_np + alpha * color
|
| 959 |
out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
|
| 960 |
|
| 961 |
-
if state.fx_ball_ring_enabled and ball_ring_mask is not None and np.max(ball_ring_mask) > FX_EPS:
|
| 962 |
-
ring_mask = np.clip(ball_ring_mask, 0.0, 1.0)
|
| 963 |
-
if ring_mask.ndim == 3:
|
| 964 |
-
ring_mask = ring_mask.squeeze()
|
| 965 |
-
ring_pixels = ring_mask > FX_EPS
|
| 966 |
-
if np.any(ring_pixels):
|
| 967 |
-
base_np = np.array(out_img).astype(np.float32) / 255.0
|
| 968 |
-
base_np[ring_pixels] = BALL_RING_COLOR
|
| 969 |
-
out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
|
| 970 |
-
|
| 971 |
if ghost_mask is not None:
|
| 972 |
ghost_np = np.clip(ghost_mask.astype(np.float32), 0.0, 1.0)
|
| 973 |
if current_union_mask is not None:
|
| 974 |
ghost_np = ghost_np * np.clip(1.0 - current_union_mask, 0.0, 1.0)
|
| 975 |
-
|
| 976 |
-
if state.fx_ball_ring_enabled:
|
| 977 |
-
ghost_display = _mask_to_ring(ghost_np, BALL_RING_THICKNESS_PX)
|
| 978 |
-
if ghost_display is not None and ghost_display.max() > FX_EPS:
|
| 979 |
base_np = np.array(out_img).astype(np.float32) / 255.0
|
| 980 |
-
ghost_alpha =
|
| 981 |
base_np = (1.0 - ghost_alpha) * base_np + ghost_alpha * GHOST_TRAIL_COLOR
|
| 982 |
if focus_mask is not None:
|
| 983 |
focus_alpha = np.clip(focus_mask, 0.0, 1.0)[..., None]
|
|
@@ -985,6 +975,18 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
|
|
| 985 |
base_np = focus_alpha * orig_np + (1.0 - focus_alpha) * base_np
|
| 986 |
out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
|
| 987 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 988 |
# Draw crosses for conditioning frames only (frames with recorded clicks)
|
| 989 |
clicks_map = state.clicks_by_frame_obj.get(frame_idx)
|
| 990 |
if state.show_click_marks and clicks_map:
|
|
@@ -1173,6 +1175,57 @@ def _build_ball_trail_mask(state: AppState, frame_idx: int) -> np.ndarray | None
|
|
| 1173 |
return trail_mask
|
| 1174 |
|
| 1175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1176 |
def _ensure_color_for_obj(state: AppState, obj_id: int):
|
| 1177 |
if obj_id not in state.color_by_obj:
|
| 1178 |
state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
|
|
@@ -1222,30 +1275,6 @@ def _apply_radial_falloff(mask: np.ndarray, strength: float = 1.0, solid_ratio:
|
|
| 1222 |
return np.clip(mask_np * falloff, 0.0, 1.0)
|
| 1223 |
|
| 1224 |
|
| 1225 |
-
def _mask_to_ring(mask: np.ndarray, thickness_px: int = BALL_RING_THICKNESS_PX) -> np.ndarray | None:
|
| 1226 |
-
if mask is None:
|
| 1227 |
-
return None
|
| 1228 |
-
mask_np = np.array(mask, dtype=np.float32)
|
| 1229 |
-
if mask_np.ndim == 3:
|
| 1230 |
-
mask_np = mask_np.squeeze()
|
| 1231 |
-
if mask_np.size == 0:
|
| 1232 |
-
return None
|
| 1233 |
-
mask_np = np.clip(mask_np, 0.0, 1.0)
|
| 1234 |
-
if mask_np.max() <= FX_EPS:
|
| 1235 |
-
return np.zeros_like(mask_np, dtype=np.float32)
|
| 1236 |
-
binary = (mask_np > 0.05).astype(np.uint8)
|
| 1237 |
-
thickness = max(1, int(round(thickness_px)))
|
| 1238 |
-
kernel_size = thickness * 2 + 1
|
| 1239 |
-
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
| 1240 |
-
dilated = cv2.dilate(binary, kernel)
|
| 1241 |
-
eroded = cv2.erode(binary, kernel)
|
| 1242 |
-
ring = cv2.subtract(dilated, eroded).astype(np.float32)
|
| 1243 |
-
ring = np.clip(ring, 0.0, 1.0)
|
| 1244 |
-
if np.max(ring) <= FX_EPS:
|
| 1245 |
-
return np.zeros_like(mask_np, dtype=np.float32)
|
| 1246 |
-
return ring
|
| 1247 |
-
|
| 1248 |
-
|
| 1249 |
def _update_centroids_for_frame(state: AppState, frame_idx: int):
|
| 1250 |
if state is None:
|
| 1251 |
return
|
|
|
|
| 796 |
GHOST_TRAIL_ALPHA = 0.55
|
| 797 |
BALL_RING_COLOR = np.array([1.0, 0.0, 1.0], dtype=np.float32)
|
| 798 |
BALL_RING_THICKNESS_PX = 2
|
| 799 |
+
BALL_RING_COLOR_RGB = tuple(int(max(0, min(255, round(c * 255.0)))) for c in BALL_RING_COLOR.tolist())
|
| 800 |
|
| 801 |
|
| 802 |
def _maybe_upscale_for_display(image: Image.Image) -> Image.Image:
|
|
|
|
| 898 |
frame = state.video_frames[frame_idx]
|
| 899 |
masks = state.masks_by_frame.get(frame_idx, {})
|
| 900 |
ball_mask_raw = masks.get(BALL_OBJECT_ID)
|
|
|
|
|
|
|
|
|
|
| 901 |
out_img: Image.Image | None = state.composited_frames.get(frame_idx)
|
| 902 |
if out_img is None:
|
| 903 |
out_img = frame
|
|
|
|
| 920 |
focus_mask = np.zeros_like(mask_np, dtype=np.float32)
|
| 921 |
focus_mask = np.maximum(focus_mask, mask_np)
|
| 922 |
|
| 923 |
+
ghost_mask: np.ndarray | None = None
|
| 924 |
+
circle_trail: list[tuple[float, float, float]] = []
|
| 925 |
+
if state.fx_ball_ring_enabled:
|
| 926 |
+
circle_trail = _collect_ball_trail_circles(state, frame_idx)
|
| 927 |
+
else:
|
| 928 |
+
ghost_mask = _build_ball_trail_mask(state, frame_idx)
|
| 929 |
|
| 930 |
if len(masks) != 0:
|
| 931 |
if remove_bg:
|
|
|
|
| 961 |
base_np = (1.0 - alpha) * base_np + alpha * color
|
| 962 |
out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
|
| 963 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 964 |
if ghost_mask is not None:
|
| 965 |
ghost_np = np.clip(ghost_mask.astype(np.float32), 0.0, 1.0)
|
| 966 |
if current_union_mask is not None:
|
| 967 |
ghost_np = ghost_np * np.clip(1.0 - current_union_mask, 0.0, 1.0)
|
| 968 |
+
if ghost_np.max() > FX_EPS:
|
|
|
|
|
|
|
|
|
|
| 969 |
base_np = np.array(out_img).astype(np.float32) / 255.0
|
| 970 |
+
ghost_alpha = ghost_np[..., None] * GHOST_TRAIL_ALPHA
|
| 971 |
base_np = (1.0 - ghost_alpha) * base_np + ghost_alpha * GHOST_TRAIL_COLOR
|
| 972 |
if focus_mask is not None:
|
| 973 |
focus_alpha = np.clip(focus_mask, 0.0, 1.0)[..., None]
|
|
|
|
| 975 |
base_np = focus_alpha * orig_np + (1.0 - focus_alpha) * base_np
|
| 976 |
out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
|
| 977 |
|
| 978 |
+
if state.fx_ball_ring_enabled and circle_trail:
|
| 979 |
+
draw = ImageDraw.Draw(out_img)
|
| 980 |
+
ring_width = max(1, int(round(BALL_RING_THICKNESS_PX)))
|
| 981 |
+
for cx, cy, radius in circle_trail:
|
| 982 |
+
if radius <= 1.0:
|
| 983 |
+
continue
|
| 984 |
+
left = cx - radius
|
| 985 |
+
top = cy - radius
|
| 986 |
+
right = cx + radius
|
| 987 |
+
bottom = cy + radius
|
| 988 |
+
draw.ellipse((left, top, right, bottom), outline=BALL_RING_COLOR_RGB, width=ring_width)
|
| 989 |
+
|
| 990 |
# Draw crosses for conditioning frames only (frames with recorded clicks)
|
| 991 |
clicks_map = state.clicks_by_frame_obj.get(frame_idx)
|
| 992 |
if state.show_click_marks and clicks_map:
|
|
|
|
| 1175 |
return trail_mask
|
| 1176 |
|
| 1177 |
|
| 1178 |
+
def _collect_ball_trail_circles(state: AppState, frame_idx: int) -> list[tuple[float, float, float]]:
|
| 1179 |
+
if (
|
| 1180 |
+
state is None
|
| 1181 |
+
or not state.fx_ghost_trail_enabled
|
| 1182 |
+
or state.masks_by_frame is None
|
| 1183 |
+
):
|
| 1184 |
+
return []
|
| 1185 |
+
kick_candidate = state.kick_frame if state.kick_frame is not None else state.kick_debug_kick_frame
|
| 1186 |
+
if kick_candidate is None:
|
| 1187 |
+
return []
|
| 1188 |
+
|
| 1189 |
+
start_idx = max(int(kick_candidate) + 1, int(frame_idx) + 1)
|
| 1190 |
+
end_idx = state.num_frames
|
| 1191 |
+
if start_idx >= end_idx:
|
| 1192 |
+
return []
|
| 1193 |
+
|
| 1194 |
+
circles: list[tuple[float, float, float]] = []
|
| 1195 |
+
for idx in range(start_idx, end_idx):
|
| 1196 |
+
frame_masks = state.masks_by_frame.get(idx)
|
| 1197 |
+
if not frame_masks:
|
| 1198 |
+
continue
|
| 1199 |
+
mask = frame_masks.get(BALL_OBJECT_ID)
|
| 1200 |
+
if mask is None:
|
| 1201 |
+
continue
|
| 1202 |
+
mask_np = np.array(mask, dtype=np.float32)
|
| 1203 |
+
if mask_np.ndim == 3:
|
| 1204 |
+
mask_np = mask_np.squeeze()
|
| 1205 |
+
if mask_np.size == 0:
|
| 1206 |
+
continue
|
| 1207 |
+
mask_np = np.clip(mask_np, 0.0, 1.0)
|
| 1208 |
+
if mask_np.max() <= FX_EPS:
|
| 1209 |
+
continue
|
| 1210 |
+
centroid = _compute_mask_centroid(mask_np)
|
| 1211 |
+
if centroid is None:
|
| 1212 |
+
continue
|
| 1213 |
+
cx, cy = centroid
|
| 1214 |
+
ys, xs = np.nonzero(mask_np > 0.05)
|
| 1215 |
+
if xs.size == 0 or ys.size == 0:
|
| 1216 |
+
continue
|
| 1217 |
+
min_x, max_x = xs.min(), xs.max()
|
| 1218 |
+
min_y, max_y = ys.min(), ys.max()
|
| 1219 |
+
radius_x = (max_x - min_x + 1) / 2.0
|
| 1220 |
+
radius_y = (max_y - min_y + 1) / 2.0
|
| 1221 |
+
radius = float(max(radius_x, radius_y))
|
| 1222 |
+
if radius <= 1.0:
|
| 1223 |
+
continue
|
| 1224 |
+
circles.append((float(cx), float(cy), radius))
|
| 1225 |
+
|
| 1226 |
+
return circles
|
| 1227 |
+
|
| 1228 |
+
|
| 1229 |
def _ensure_color_for_obj(state: AppState, obj_id: int):
|
| 1230 |
if obj_id not in state.color_by_obj:
|
| 1231 |
state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
|
|
|
|
| 1275 |
return np.clip(mask_np * falloff, 0.0, 1.0)
|
| 1276 |
|
| 1277 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1278 |
def _update_centroids_for_frame(state: AppState, frame_idx: int):
|
| 1279 |
if state is None:
|
| 1280 |
return
|