Mirko Trasciatti commited on
Commit
c12e0a2
·
1 Parent(s): 853fa4b

Render ball rings as geometric circles

Browse files
Files changed (1) hide show
  1. app.py +72 -43
app.py CHANGED
@@ -796,6 +796,7 @@ GHOST_TRAIL_COLOR = np.array([1.0, 0.0, 1.0], dtype=np.float32)
796
  GHOST_TRAIL_ALPHA = 0.55
797
  BALL_RING_COLOR = np.array([1.0, 0.0, 1.0], dtype=np.float32)
798
  BALL_RING_THICKNESS_PX = 2
 
799
 
800
 
801
  def _maybe_upscale_for_display(image: Image.Image) -> Image.Image:
@@ -897,9 +898,6 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
897
  frame = state.video_frames[frame_idx]
898
  masks = state.masks_by_frame.get(frame_idx, {})
899
  ball_mask_raw = masks.get(BALL_OBJECT_ID)
900
- ball_ring_mask: np.ndarray | None = None
901
- if state.fx_ball_ring_enabled and ball_mask_raw is not None:
902
- ball_ring_mask = _mask_to_ring(ball_mask_raw, BALL_RING_THICKNESS_PX)
903
  out_img: Image.Image | None = state.composited_frames.get(frame_idx)
904
  if out_img is None:
905
  out_img = frame
@@ -922,7 +920,12 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
922
  focus_mask = np.zeros_like(mask_np, dtype=np.float32)
923
  focus_mask = np.maximum(focus_mask, mask_np)
924
 
925
- ghost_mask = _build_ball_trail_mask(state, frame_idx)
 
 
 
 
 
926
 
927
  if len(masks) != 0:
928
  if remove_bg:
@@ -958,26 +961,13 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
958
  base_np = (1.0 - alpha) * base_np + alpha * color
959
  out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
960
 
961
- if state.fx_ball_ring_enabled and ball_ring_mask is not None and np.max(ball_ring_mask) > FX_EPS:
962
- ring_mask = np.clip(ball_ring_mask, 0.0, 1.0)
963
- if ring_mask.ndim == 3:
964
- ring_mask = ring_mask.squeeze()
965
- ring_pixels = ring_mask > FX_EPS
966
- if np.any(ring_pixels):
967
- base_np = np.array(out_img).astype(np.float32) / 255.0
968
- base_np[ring_pixels] = BALL_RING_COLOR
969
- out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
970
-
971
  if ghost_mask is not None:
972
  ghost_np = np.clip(ghost_mask.astype(np.float32), 0.0, 1.0)
973
  if current_union_mask is not None:
974
  ghost_np = ghost_np * np.clip(1.0 - current_union_mask, 0.0, 1.0)
975
- ghost_display = ghost_np
976
- if state.fx_ball_ring_enabled:
977
- ghost_display = _mask_to_ring(ghost_np, BALL_RING_THICKNESS_PX)
978
- if ghost_display is not None and ghost_display.max() > FX_EPS:
979
  base_np = np.array(out_img).astype(np.float32) / 255.0
980
- ghost_alpha = np.clip(ghost_display, 0.0, 1.0)[..., None] * GHOST_TRAIL_ALPHA
981
  base_np = (1.0 - ghost_alpha) * base_np + ghost_alpha * GHOST_TRAIL_COLOR
982
  if focus_mask is not None:
983
  focus_alpha = np.clip(focus_mask, 0.0, 1.0)[..., None]
@@ -985,6 +975,18 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
985
  base_np = focus_alpha * orig_np + (1.0 - focus_alpha) * base_np
986
  out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
987
 
 
 
 
 
 
 
 
 
 
 
 
 
988
  # Draw crosses for conditioning frames only (frames with recorded clicks)
989
  clicks_map = state.clicks_by_frame_obj.get(frame_idx)
990
  if state.show_click_marks and clicks_map:
@@ -1173,6 +1175,57 @@ def _build_ball_trail_mask(state: AppState, frame_idx: int) -> np.ndarray | None
1173
  return trail_mask
1174
 
1175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1176
  def _ensure_color_for_obj(state: AppState, obj_id: int):
1177
  if obj_id not in state.color_by_obj:
1178
  state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
@@ -1222,30 +1275,6 @@ def _apply_radial_falloff(mask: np.ndarray, strength: float = 1.0, solid_ratio:
1222
  return np.clip(mask_np * falloff, 0.0, 1.0)
1223
 
1224
 
1225
- def _mask_to_ring(mask: np.ndarray, thickness_px: int = BALL_RING_THICKNESS_PX) -> np.ndarray | None:
1226
- if mask is None:
1227
- return None
1228
- mask_np = np.array(mask, dtype=np.float32)
1229
- if mask_np.ndim == 3:
1230
- mask_np = mask_np.squeeze()
1231
- if mask_np.size == 0:
1232
- return None
1233
- mask_np = np.clip(mask_np, 0.0, 1.0)
1234
- if mask_np.max() <= FX_EPS:
1235
- return np.zeros_like(mask_np, dtype=np.float32)
1236
- binary = (mask_np > 0.05).astype(np.uint8)
1237
- thickness = max(1, int(round(thickness_px)))
1238
- kernel_size = thickness * 2 + 1
1239
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
1240
- dilated = cv2.dilate(binary, kernel)
1241
- eroded = cv2.erode(binary, kernel)
1242
- ring = cv2.subtract(dilated, eroded).astype(np.float32)
1243
- ring = np.clip(ring, 0.0, 1.0)
1244
- if np.max(ring) <= FX_EPS:
1245
- return np.zeros_like(mask_np, dtype=np.float32)
1246
- return ring
1247
-
1248
-
1249
  def _update_centroids_for_frame(state: AppState, frame_idx: int):
1250
  if state is None:
1251
  return
 
796
  GHOST_TRAIL_ALPHA = 0.55
797
  BALL_RING_COLOR = np.array([1.0, 0.0, 1.0], dtype=np.float32)
798
  BALL_RING_THICKNESS_PX = 2
799
+ BALL_RING_COLOR_RGB = tuple(int(max(0, min(255, round(c * 255.0)))) for c in BALL_RING_COLOR.tolist())
800
 
801
 
802
  def _maybe_upscale_for_display(image: Image.Image) -> Image.Image:
 
898
  frame = state.video_frames[frame_idx]
899
  masks = state.masks_by_frame.get(frame_idx, {})
900
  ball_mask_raw = masks.get(BALL_OBJECT_ID)
 
 
 
901
  out_img: Image.Image | None = state.composited_frames.get(frame_idx)
902
  if out_img is None:
903
  out_img = frame
 
920
  focus_mask = np.zeros_like(mask_np, dtype=np.float32)
921
  focus_mask = np.maximum(focus_mask, mask_np)
922
 
923
+ ghost_mask: np.ndarray | None = None
924
+ circle_trail: list[tuple[float, float, float]] = []
925
+ if state.fx_ball_ring_enabled:
926
+ circle_trail = _collect_ball_trail_circles(state, frame_idx)
927
+ else:
928
+ ghost_mask = _build_ball_trail_mask(state, frame_idx)
929
 
930
  if len(masks) != 0:
931
  if remove_bg:
 
961
  base_np = (1.0 - alpha) * base_np + alpha * color
962
  out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
963
 
 
 
 
 
 
 
 
 
 
 
964
  if ghost_mask is not None:
965
  ghost_np = np.clip(ghost_mask.astype(np.float32), 0.0, 1.0)
966
  if current_union_mask is not None:
967
  ghost_np = ghost_np * np.clip(1.0 - current_union_mask, 0.0, 1.0)
968
+ if ghost_np.max() > FX_EPS:
 
 
 
969
  base_np = np.array(out_img).astype(np.float32) / 255.0
970
+ ghost_alpha = ghost_np[..., None] * GHOST_TRAIL_ALPHA
971
  base_np = (1.0 - ghost_alpha) * base_np + ghost_alpha * GHOST_TRAIL_COLOR
972
  if focus_mask is not None:
973
  focus_alpha = np.clip(focus_mask, 0.0, 1.0)[..., None]
 
975
  base_np = focus_alpha * orig_np + (1.0 - focus_alpha) * base_np
976
  out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))
977
 
978
+ if state.fx_ball_ring_enabled and circle_trail:
979
+ draw = ImageDraw.Draw(out_img)
980
+ ring_width = max(1, int(round(BALL_RING_THICKNESS_PX)))
981
+ for cx, cy, radius in circle_trail:
982
+ if radius <= 1.0:
983
+ continue
984
+ left = cx - radius
985
+ top = cy - radius
986
+ right = cx + radius
987
+ bottom = cy + radius
988
+ draw.ellipse((left, top, right, bottom), outline=BALL_RING_COLOR_RGB, width=ring_width)
989
+
990
  # Draw crosses for conditioning frames only (frames with recorded clicks)
991
  clicks_map = state.clicks_by_frame_obj.get(frame_idx)
992
  if state.show_click_marks and clicks_map:
 
1175
  return trail_mask
1176
 
1177
 
1178
+ def _collect_ball_trail_circles(state: AppState, frame_idx: int) -> list[tuple[float, float, float]]:
1179
+ if (
1180
+ state is None
1181
+ or not state.fx_ghost_trail_enabled
1182
+ or state.masks_by_frame is None
1183
+ ):
1184
+ return []
1185
+ kick_candidate = state.kick_frame if state.kick_frame is not None else state.kick_debug_kick_frame
1186
+ if kick_candidate is None:
1187
+ return []
1188
+
1189
+ start_idx = max(int(kick_candidate) + 1, int(frame_idx) + 1)
1190
+ end_idx = state.num_frames
1191
+ if start_idx >= end_idx:
1192
+ return []
1193
+
1194
+ circles: list[tuple[float, float, float]] = []
1195
+ for idx in range(start_idx, end_idx):
1196
+ frame_masks = state.masks_by_frame.get(idx)
1197
+ if not frame_masks:
1198
+ continue
1199
+ mask = frame_masks.get(BALL_OBJECT_ID)
1200
+ if mask is None:
1201
+ continue
1202
+ mask_np = np.array(mask, dtype=np.float32)
1203
+ if mask_np.ndim == 3:
1204
+ mask_np = mask_np.squeeze()
1205
+ if mask_np.size == 0:
1206
+ continue
1207
+ mask_np = np.clip(mask_np, 0.0, 1.0)
1208
+ if mask_np.max() <= FX_EPS:
1209
+ continue
1210
+ centroid = _compute_mask_centroid(mask_np)
1211
+ if centroid is None:
1212
+ continue
1213
+ cx, cy = centroid
1214
+ ys, xs = np.nonzero(mask_np > 0.05)
1215
+ if xs.size == 0 or ys.size == 0:
1216
+ continue
1217
+ min_x, max_x = xs.min(), xs.max()
1218
+ min_y, max_y = ys.min(), ys.max()
1219
+ radius_x = (max_x - min_x + 1) / 2.0
1220
+ radius_y = (max_y - min_y + 1) / 2.0
1221
+ radius = float(max(radius_x, radius_y))
1222
+ if radius <= 1.0:
1223
+ continue
1224
+ circles.append((float(cx), float(cy), radius))
1225
+
1226
+ return circles
1227
+
1228
+
1229
  def _ensure_color_for_obj(state: AppState, obj_id: int):
1230
  if obj_id not in state.color_by_obj:
1231
  state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
 
1275
  return np.clip(mask_np * falloff, 0.0, 1.0)
1276
 
1277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1278
  def _update_centroids_for_frame(state: AppState, frame_idx: int):
1279
  if state is None:
1280
  return