ASethi04 committed on
Commit 21f4849 · 1 Parent(s): 17f21ca
app.py CHANGED
@@ -60,206 +60,6 @@ print(
60
  )
61
 
62
 
63
- def format_summary(summary, binary_confidence_threshold=0.8):
64
- """
65
- Format the summary dictionary into a readable markdown string.
66
- Filters binary relations by confidence threshold.
67
- """
68
- if not summary or not isinstance(summary, dict):
69
- return "# Detection Summary\n\nNo events detected or processing in progress..."
70
-
71
- output_lines = ["# Detection Summary\n"]
72
- has_content = False
73
-
74
- # Categorical keywords
75
- if "categorical_keywords" in summary and summary["categorical_keywords"]:
76
- output_lines.append("## Categorical Keywords\n")
77
- cate = summary["categorical_keywords"]
78
- if isinstance(cate, dict) and cate:
79
- has_content = True
80
- for kw, info in cate.items():
81
- output_lines.append(f"**{kw}**")
82
- if isinstance(info, dict):
83
- for key, val in info.items():
84
- output_lines.append(f" - {key}: {val}")
85
- else:
86
- output_lines.append(f" - {info}")
87
- output_lines.append("")
88
- elif isinstance(cate, list) and cate:
89
- has_content = True
90
- for item in cate:
91
- output_lines.append(f"- {item}")
92
- output_lines.append("")
93
-
94
- # Unary keywords
95
- if "unary_keywords" in summary and summary["unary_keywords"]:
96
- output_lines.append("## Unary Keywords\n")
97
- unary = summary["unary_keywords"]
98
- if isinstance(unary, dict) and unary:
99
- has_content = True
100
- for kw, info in unary.items():
101
- output_lines.append(f"**{kw}**")
102
- if isinstance(info, dict):
103
- for key, val in info.items():
104
- output_lines.append(f" - {key}: {val}")
105
- else:
106
- output_lines.append(f" - {info}")
107
- output_lines.append("")
108
- elif isinstance(unary, list) and unary:
109
- has_content = True
110
- for item in unary:
111
- output_lines.append(f"- {item}")
112
- output_lines.append("")
113
-
114
- # Binary keywords - show ALL binary relations for debugging
115
- print(f"DEBUG: Checking binary_keywords...")
116
- print(f" 'binary_keywords' in summary: {'binary_keywords' in summary}")
117
- if 'binary_keywords' in summary:
118
- print(f" summary['binary_keywords'] truthy: {bool(summary['binary_keywords'])}")
119
- print(f" summary['binary_keywords'] type: {type(summary['binary_keywords'])}")
120
- print(f" summary['binary_keywords'] value: {summary['binary_keywords']}")
121
-
122
- if "binary_keywords" in summary and summary["binary_keywords"]:
123
- output_lines.append(f"## Binary Keywords\n")
124
- binary = summary["binary_keywords"]
125
- print(f"DEBUG: Processing binary keywords, type: {type(binary)}, length: {len(binary) if isinstance(binary, (dict, list)) else 'N/A'}")
126
- if isinstance(binary, dict) and binary:
127
- has_content = True
128
- # Show all binary relations, sorted by confidence
129
- binary_items = []
130
- for kw, info in binary.items():
131
- if isinstance(info, dict):
132
- confidence = info.get("confidence", info.get("score", 0))
133
- binary_items.append((kw, info, confidence))
134
- else:
135
- binary_items.append((kw, info, 0))
136
-
137
- # Sort by confidence descending
138
- binary_items.sort(key=lambda x: x[2], reverse=True)
139
-
140
- high_conf_count = 0
141
- low_conf_count = 0
142
-
143
- # Show high confidence items first
144
- output_lines.append(f"### High Confidence (≥ {binary_confidence_threshold})\n")
145
- for kw, info, confidence in binary_items:
146
- if confidence >= binary_confidence_threshold:
147
- high_conf_count += 1
148
- if isinstance(info, dict):
149
- output_lines.append(f"**{kw}** (confidence: {confidence:.2f})")
150
- for key, val in info.items():
151
- if key not in ["confidence", "score"]:
152
- output_lines.append(f" - {key}: {val}")
153
- else:
154
- output_lines.append(f"**{kw}**: {info}")
155
- output_lines.append("")
156
-
157
- if high_conf_count == 0:
158
- output_lines.append(f"*No binary relations found with confidence ≥ {binary_confidence_threshold}*\n")
159
-
160
- # Show lower confidence items for debugging
161
- output_lines.append(f"### Lower Confidence (< {binary_confidence_threshold})\n")
162
- for kw, info, confidence in binary_items:
163
- if confidence < binary_confidence_threshold:
164
- low_conf_count += 1
165
- if isinstance(info, dict):
166
- output_lines.append(f"**{kw}** (confidence: {confidence:.2f})")
167
- for key, val in info.items():
168
- if key not in ["confidence", "score"]:
169
- output_lines.append(f" - {key}: {val}")
170
- else:
171
- output_lines.append(f"**{kw}**: {info}")
172
- output_lines.append("")
173
-
174
- if low_conf_count == 0:
175
- output_lines.append(f"*No binary relations found with confidence < {binary_confidence_threshold}*\n")
176
-
177
- output_lines.append(f"**Total binary relations detected: {len(binary_items)}**\n")
178
- elif isinstance(binary, list) and binary:
179
- has_content = True
180
- for item in binary:
181
- output_lines.append(f"- {item}")
182
- output_lines.append("")
183
-
184
- # Object pairs - show ALL object pair interactions for debugging
185
- print(f"DEBUG: Checking object_pairs...")
186
- print(f" 'object_pairs' in summary: {'object_pairs' in summary}")
187
- if 'object_pairs' in summary:
188
- print(f" summary['object_pairs'] truthy: {bool(summary['object_pairs'])}")
189
- print(f" summary['object_pairs'] type: {type(summary['object_pairs'])}")
190
- print(f" summary['object_pairs'] value: {summary['object_pairs']}")
191
-
192
- if "object_pairs" in summary and summary["object_pairs"]:
193
- output_lines.append(f"## Object Pair Interactions\n")
194
- pairs = summary["object_pairs"]
195
- print(f"DEBUG: Processing object pairs, type: {type(pairs)}, length: {len(pairs) if isinstance(pairs, (dict, list)) else 'N/A'}")
196
- if isinstance(pairs, dict) and pairs:
197
- has_content = True
198
- # Show all object pairs, sorted by confidence
199
- pair_items = []
200
- for pair, info in pairs.items():
201
- if isinstance(info, dict):
202
- confidence = info.get("confidence", info.get("score", 0))
203
- pair_items.append((pair, info, confidence))
204
- else:
205
- pair_items.append((pair, info, 0))
206
-
207
- # Sort by confidence descending
208
- pair_items.sort(key=lambda x: x[2], reverse=True)
209
-
210
- high_conf_count = 0
211
- low_conf_count = 0
212
-
213
- # Show high confidence items first
214
- output_lines.append(f"### High Confidence (≥ {binary_confidence_threshold})\n")
215
- for pair, info, confidence in pair_items:
216
- if confidence >= binary_confidence_threshold:
217
- high_conf_count += 1
218
- if isinstance(info, dict):
219
- output_lines.append(f"**{pair}** (confidence: {confidence:.2f})")
220
- for key, val in info.items():
221
- if key not in ["confidence", "score"]:
222
- output_lines.append(f" - {key}: {val}")
223
- else:
224
- output_lines.append(f"**{pair}**: {info}")
225
- output_lines.append("")
226
-
227
- if high_conf_count == 0:
228
- output_lines.append(f"*No object pairs found with confidence ≥ {binary_confidence_threshold}*\n")
229
-
230
- # Show lower confidence items for debugging
231
- output_lines.append(f"### Lower Confidence (< {binary_confidence_threshold})\n")
232
- for pair, info, confidence in pair_items:
233
- if confidence < binary_confidence_threshold:
234
- low_conf_count += 1
235
- if isinstance(info, dict):
236
- output_lines.append(f"**{pair}** (confidence: {confidence:.2f})")
237
- for key, val in info.items():
238
- if key not in ["confidence", "score"]:
239
- output_lines.append(f" - {key}: {val}")
240
- else:
241
- output_lines.append(f"**{pair}**: {info}")
242
- output_lines.append("")
243
-
244
- if low_conf_count == 0:
245
- output_lines.append(f"*No object pairs found with confidence < {binary_confidence_threshold}*\n")
246
-
247
- output_lines.append(f"**Total object pairs detected: {len(pair_items)}**\n")
248
- elif isinstance(pairs, list) and pairs:
249
- has_content = True
250
- for item in pairs:
251
- output_lines.append(f"- {item}")
252
- output_lines.append("")
253
-
254
- # If no content was added, show the raw summary for debugging
255
- if not has_content:
256
- output_lines.append("## Raw Summary Data\n")
257
- output_lines.append("```json")
258
- import json
259
- output_lines.append(json.dumps(summary, indent=2, default=str))
260
- output_lines.append("```")
261
-
262
- return "\n".join(output_lines)
263
 
264
 
265
  @lru_cache(maxsize=1)
@@ -394,9 +194,10 @@ def process_video(
394
  summary = results_dict.get("summary") or {}
395
 
396
  if result_video_path and os.path.exists(result_video_path):
397
- gradio_tmp = Path(
398
- os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
399
- ) / "vine_outputs"
 
400
  gradio_tmp.mkdir(parents=True, exist_ok=True)
401
  dest_path = gradio_tmp / Path(result_video_path).name
402
  try:
@@ -411,47 +212,7 @@ def process_video(
411
  "Warning: annotated video not found or empty; check visualization settings."
412
  )
413
 
414
- # Debug: Print summary structure
415
- import json
416
- print("=" * 80)
417
- print("SUMMARY DEBUG OUTPUT:")
418
- print(f"Summary type: {type(summary)}")
419
- print(f"Summary keys: {summary.keys() if isinstance(summary, dict) else 'N/A'}")
420
- if isinstance(summary, dict):
421
- print("\nFULL SUMMARY JSON:")
422
- print(json.dumps(summary, indent=2, default=str))
423
- print("\n" + "=" * 80)
424
-
425
- # Check for any keys that might contain binary relation data
426
- print("\nLOOKING FOR BINARY RELATION DATA:")
427
- possible_keys = ['binary', 'binary_keywords', 'binary_relations', 'object_pairs',
428
- 'pairs', 'relations', 'interactions', 'pairwise']
429
- for pkey in possible_keys:
430
- if pkey in summary:
431
- print(f" FOUND: '{pkey}' -> {summary[pkey]}")
432
-
433
- print("\nALL KEYS IN SUMMARY:")
434
- for key in summary.keys():
435
- print(f"\n{key}:")
436
- print(f" Type: {type(summary[key])}")
437
- if isinstance(summary[key], dict):
438
- print(f" Length: {len(summary[key])}")
439
- print(f" Keys (first 10): {list(summary[key].keys())[:10]}")
440
- # Print all items for anything that might be binary relations
441
- if any(term in key.lower() for term in ['binary', 'pair', 'relation', 'interaction']):
442
- print(f" ALL ITEMS:")
443
- for k, v in list(summary[key].items())[:20]: # First 20 items
444
- print(f" {k}: {v}")
445
- else:
446
- print(f" Sample: {dict(list(summary[key].items())[:2])}")
447
- elif isinstance(summary[key], list):
448
- print(f" Length: {len(summary[key])}")
449
- print(f" Sample: {summary[key][:2]}")
450
- print("=" * 80)
451
-
452
- # Format summary as readable markdown text, filtering by confidence threshold
453
- formatted_summary = format_summary(summary, binary_confidence_threshold)
454
- return video_path_for_ui, formatted_summary
455
 
456
 
457
  def _video_component(label: str, *, is_output: bool = False):
@@ -523,25 +284,25 @@ with _create_blocks() as demo:
523
  label="Categorical Keywords",
524
  placeholder="e.g., person, car, dog",
525
  value="person, car, dog",
526
- info="Objects to detect in the video (comma-separated)"
527
  )
528
  unary_input = gr.Textbox(
529
  label="Unary Keywords",
530
  placeholder="e.g., walking, running, standing",
531
  value="walking, running, standing",
532
- info="Single-object actions to detect (comma-separated)"
533
  )
534
  binary_input = gr.Textbox(
535
  label="Binary Keywords",
536
  placeholder="e.g., chasing, carrying",
537
- info="Object-to-object interactions to detect (comma-separated)"
538
  )
539
 
540
  gr.Markdown("#### Processing Settings")
541
  fps_input = gr.Number(
542
  label="Output FPS",
543
  value=1,
544
- info="Frames per second for processing (lower = faster)"
545
  )
546
 
547
  with gr.Accordion("Advanced Settings", open=False):
@@ -551,7 +312,7 @@ with _create_blocks() as demo:
551
  maximum=0.9,
552
  value=0.35,
553
  step=0.05,
554
- info="Confidence threshold for object detection"
555
  )
556
  text_threshold_input = gr.Slider(
557
  label="Text Threshold",
@@ -559,7 +320,7 @@ with _create_blocks() as demo:
559
  maximum=0.9,
560
  value=0.25,
561
  step=0.05,
562
- info="Confidence threshold for text-based detection"
563
  )
564
  binary_confidence_input = gr.Slider(
565
  label="Binary Relation Confidence Threshold",
@@ -567,7 +328,7 @@ with _create_blocks() as demo:
567
  maximum=1.0,
568
  value=0.8,
569
  step=0.05,
570
- info="Minimum confidence to show binary relations and object pairs"
571
  )
572
 
573
  submit_btn = gr.Button("🚀 Process Video", variant="primary", size="lg")
@@ -579,10 +340,7 @@ with _create_blocks() as demo:
579
  video_output = _video_component("Annotated Video Output", is_output=True)
580
 
581
  gr.Markdown("### Detection Summary")
582
- summary_output = gr.Markdown(
583
- value="Results will appear here after processing...",
584
- elem_classes=["summary-output"]
585
- )
586
 
587
  gr.Markdown(
588
  """
 
60
  )
61
 
62
 
63
 
64
 
65
  @lru_cache(maxsize=1)
 
194
  summary = results_dict.get("summary") or {}
195
 
196
  if result_video_path and os.path.exists(result_video_path):
197
+ gradio_tmp = (
198
+ Path(os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir()))
199
+ / "vine_outputs"
200
+ )
201
  gradio_tmp.mkdir(parents=True, exist_ok=True)
202
  dest_path = gradio_tmp / Path(result_video_path).name
203
  try:
 
212
  "Warning: annotated video not found or empty; check visualization settings."
213
  )
214
 
215
+ return video_path_for_ui, summary
 
216
 
217
 
218
  def _video_component(label: str, *, is_output: bool = False):
 
284
  label="Categorical Keywords",
285
  placeholder="e.g., person, car, dog",
286
  value="person, car, dog",
287
+ info="Objects to detect in the video (comma-separated)",
288
  )
289
  unary_input = gr.Textbox(
290
  label="Unary Keywords",
291
  placeholder="e.g., walking, running, standing",
292
  value="walking, running, standing",
293
+ info="Single-object actions to detect (comma-separated)",
294
  )
295
  binary_input = gr.Textbox(
296
  label="Binary Keywords",
297
  placeholder="e.g., chasing, carrying",
298
+ info="Object-to-object interactions to detect (comma-separated)",
299
  )
300
 
301
  gr.Markdown("#### Processing Settings")
302
  fps_input = gr.Number(
303
  label="Output FPS",
304
  value=1,
305
+ info="Frames per second for processing (lower = faster)",
306
  )
307
 
308
  with gr.Accordion("Advanced Settings", open=False):
 
312
  maximum=0.9,
313
  value=0.35,
314
  step=0.05,
315
+ info="Confidence threshold for object detection",
316
  )
317
  text_threshold_input = gr.Slider(
318
  label="Text Threshold",
 
320
  maximum=0.9,
321
  value=0.25,
322
  step=0.05,
323
+ info="Confidence threshold for text-based detection",
324
  )
325
  binary_confidence_input = gr.Slider(
326
  label="Binary Relation Confidence Threshold",
 
328
  maximum=1.0,
329
  value=0.8,
330
  step=0.05,
331
+ info="Minimum confidence to show binary relations and object pairs",
332
  )
333
 
334
  submit_btn = gr.Button("🚀 Process Video", variant="primary", size="lg")
 
340
  video_output = _video_component("Annotated Video Output", is_output=True)
341
 
342
  gr.Markdown("### Detection Summary")
343
+ summary_output = gr.JSON(label="Summary of Detected Events")
344
 
345
  gr.Markdown(
346
  """
outputs/debug_crops/frame_0_obj_0.jpg CHANGED
outputs/debug_crops/frame_0_obj_1.jpg CHANGED
outputs/debug_crops/frame_0_obj_2.jpg CHANGED
outputs/debug_crops/frame_0_obj_3.jpg CHANGED
outputs/debug_crops/frame_0_obj_4.jpg CHANGED
outputs/debug_crops/frame_0_obj_5.jpg CHANGED
outputs/debug_crops/frame_1_obj_0.jpg CHANGED
outputs/debug_crops/frame_1_obj_1.jpg CHANGED
outputs/debug_crops/frame_1_obj_2.jpg CHANGED
outputs/debug_crops/frame_1_obj_3.jpg CHANGED
outputs/debug_crops/frame_1_obj_4.jpg CHANGED
outputs/debug_crops/frame_1_obj_5.jpg CHANGED
outputs/debug_crops/frame_1_obj_6.jpg CHANGED
src/LASER/laser/models/model_utils.py CHANGED
@@ -6,20 +6,22 @@ import torch
6
  import jax.numpy as jnp
7
  import jax
8
 
 
9
  def increase_brightness(img, alpha=0.2):
10
  height, width, _ = img.shape
11
- white_img = np.zeros([height,width,3],dtype=np.uint8)
12
- white_img.fill(255) # or img[:] = 255
13
 
14
- dst = cv2.addWeighted(img, alpha , white_img, 1-alpha, 0)
15
  return dst
16
 
 
17
  def increase_brightness_except(img, bbox_ls, alpha=0.2):
18
  height, width, _ = img.shape
19
- white_img = np.zeros([height,width,3],dtype=np.uint8)
20
- white_img.fill(255) # or img[:] = 255
21
 
22
- output_img = cv2.addWeighted(img, alpha , white_img, 1-alpha, 0)
23
 
24
  for x1, y1, x2, y2 in bbox_ls:
25
  output_img[y1:y2, x1:x2] = img[y1:y2, x1:x2]
@@ -28,12 +30,12 @@ def increase_brightness_except(img, bbox_ls, alpha=0.2):
28
 
29
  def extract_single_object(img, mask, alpha=0.8):
30
  """OpenCV version of extract_single_object that works with numpy arrays.
31
-
32
  Args:
33
  img: numpy array of shape (height, width, 3)
34
  mask: numpy array of shape (height, width, 1) or (height, width)
35
  alpha: float between 0 and 1 for blending
36
-
37
  Returns:
38
  numpy array of shape (height, width, 3)
39
  """
@@ -51,18 +53,21 @@ def extract_single_object(img, mask, alpha=0.8):
51
  masked_white_img = np.where(mask, white_img, img)
52
 
53
  # Blend the original image with the masked white image
54
- output_img = cv2.addWeighted(img.astype(np.uint8), 1-alpha, masked_white_img.astype(np.uint8), alpha, 0)
 
 
55
 
56
  return output_img
57
 
 
58
  def extract_single_object_jax(img, mask, alpha=0.8):
59
  """JAX version of extract_single_object that works with JAX arrays.
60
-
61
  Args:
62
  img: JAX array of shape (height, width, 3)
63
  mask: JAX array of shape (height, width, 1) or (height, width)
64
  alpha: float between 0 and 1 for blending
65
-
66
  Returns:
67
  JAX array of shape (height, width, 3)
68
  """
@@ -80,10 +85,11 @@ def extract_single_object_jax(img, mask, alpha=0.8):
80
  masked_white_img = jnp.where(mask, white_img, img)
81
 
82
  # Blend the original image with the masked white image
83
- output_img = img * (1-alpha) + masked_white_img * alpha
84
 
85
  return output_img
86
 
 
87
  def crop_image_contain_bboxes(img, bbox_ls, data_id):
88
  all_bx1 = []
89
  all_by1 = []
@@ -92,9 +98,11 @@ def crop_image_contain_bboxes(img, bbox_ls, data_id):
92
 
93
  for bbox in bbox_ls:
94
  if isinstance(bbox, dict):
95
- bx1, by1, bx2, by2 = bbox['x1'], bbox['y1'], bbox['x2'], bbox['y2']
96
  elif isinstance(bbox, (list, tuple, np.ndarray)):
97
- bx1, by1, bx2, by2 = map(int, bbox[:4]) # Convert first 4 elements to integers
 
 
98
  else:
99
  raise ValueError(f"Unsupported bbox format: {type(bbox)}")
100
 
@@ -111,13 +119,36 @@ def crop_image_contain_bboxes(img, bbox_ls, data_id):
111
  y1 = min(all_by1)
112
  y2 = max(all_by2)
113
 
114
- assert(x1 < x2), f"image bbox issue: {data_id}"
115
- assert(y1 < y2), f"image bbox issue: {data_id}"
116
 
117
  return img[y1:y2, x1:x2]
118
 
 
119
  def extract_object_subject(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8):
120
- # Ensure the masks are 2D and binary (0 or 1)
121
  if red_mask.ndim == 3:
122
  red_mask = red_mask[:, :, 0]
123
  if blue_mask.ndim == 3:
@@ -125,44 +156,62 @@ def extract_object_subject(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8)
125
 
126
  red_mask = red_mask.astype(bool)
127
  blue_mask = blue_mask.astype(bool)
 
 
128
  non_masked_area = ~(red_mask | blue_mask)
129
 
130
- # Split the image into its color channels (B, G, R)
131
  b, g, r = cv2.split(img)
132
 
133
- # Adjust the red channel based on the red mask
134
- r = np.where(red_mask, np.clip(r + (255 - r) * alpha, 0, 255), r).astype(np.uint8)
135
 
136
- # Adjust the blue channel based on the blue mask
137
- b = np.where(blue_mask, np.clip(b + (255 - b) * alpha, 0, 255), b).astype(np.uint8)
138
-
139
- # Merge the channels back together
140
  output_img = cv2.merge((b, g, r))
141
 
 
142
  white_img = np.full_like(output_img, 255, dtype=np.uint8)
143
- # Expand non_masked_area to 3D for proper broadcasting with 3-channel images
144
- non_masked_area_3d = np.expand_dims(non_masked_area, axis=-1)
145
- output_img = np.where(non_masked_area_3d, cv2.addWeighted(output_img, 1 - white_alpha, white_img, white_alpha, 0), output_img)
146
 
147
  return output_img
148
 
149
 
150
  def extract_object_subject_jax(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8):
151
  """JAX version of extract_object_subject that works with JAX arrays.
152
-
153
  Args:
154
  img: JAX array of shape (height, width, 3) in BGR format
155
  red_mask: JAX array of shape (height, width, 1) or (height, width)
156
  blue_mask: JAX array of shape (height, width, 1) or (height, width)
157
  alpha: float between 0 and 1 for color highlighting
158
  white_alpha: float between 0 and 1 for background blending
159
-
160
  Returns:
161
  JAX array of shape (height, width, 3) in BGR format with uint8 dtype
162
  """
163
  # Convert input image to float32 for calculations
164
  img = img.astype(jnp.float32)
165
-
166
  # Ensure the masks are binary (0 or 1)
167
  red_mask = red_mask.astype(bool)
168
  blue_mask = blue_mask.astype(bool)
@@ -179,54 +228,58 @@ def extract_object_subject_jax(img, red_mask, blue_mask, alpha=0.5, white_alpha=
179
  r = img[..., 2] # Red channel
180
 
181
  # Adjust the red channel based on the red mask
182
- r = jnp.where(red_mask[..., 0],
183
- jnp.clip(r + (255 - r) * alpha, 0, 255),
184
- r)
185
 
186
  # Adjust the blue channel based on the blue mask
187
- b = jnp.where(blue_mask[..., 0],
188
- jnp.clip(b + (255 - b) * alpha, 0, 255),
189
- b)
190
 
191
  # Stack the channels back together
192
  output_img = jnp.stack([b, g, r], axis=-1)
193
 
194
  # Create white background and blend
195
  white_img = jnp.full_like(output_img, 255.0, dtype=jnp.float32)
196
- output_img = jnp.where(non_masked_area,
197
- output_img * (1 - white_alpha) + white_img * white_alpha,
198
- output_img)
 
 
199
 
200
  # Round to nearest integer and cast to uint8
201
  output_img = jnp.round(output_img)
202
  return output_img.astype(jnp.uint8)
203
 
204
- def increase_brightness_draw_outer_edge(img, bbox_ls, alpha=0.2, colormap_name='Set1', thickness=2):
205
  if isinstance(img, torch.Tensor):
206
  img = img.cpu().numpy().astype(np.uint8)
207
  else:
208
  img = img.astype(np.uint8)
209
  height, width, _ = img.shape
210
- white_img = np.zeros([height,width,3],dtype=np.uint8)
211
- white_img.fill(255) # or img[:] = 255
212
 
213
- output_img = cv2.addWeighted(img, alpha , white_img, 1-alpha, 0)
214
  colormap = plt.colormaps[colormap_name]
215
 
216
  for bbox_id, (x1, y1, x2, y2) in enumerate(bbox_ls):
217
  output_img[y1:y2, x1:x2] = img[y1:y2, x1:x2]
218
- color = [c * 255 for c in mpl.colors.to_rgb(colormap(bbox_id))]
219
  # print(f"color: {color}")
220
  output_img = cv2.rectangle(output_img, (x1, y1), (x2, y2), color, thickness)
221
 
222
  return torch.tensor(output_img, dtype=torch.float32)
223
 
 
224
  def get_print_hook(name):
225
  def print_hook(grad):
226
  print(f"{name}: \n {grad} \n")
227
  return grad
 
228
  return print_hook
229
 
 
230
  def segment_list(l, n=5):
231
  current_seg = []
232
  all_segs = []
@@ -242,18 +295,22 @@ def segment_list(l, n=5):
242
 
243
  return all_segs
244
 
 
245
  def get_tensor_size(a):
246
  return a.element_size() * a.nelement()
247
 
 
248
  def comp_diff(v1, v2):
249
  return 2 * torch.abs(v1 - v2) / (v1 + v2)
250
 
 
251
  def gather_names(pred_res):
252
  all_names = set()
253
  for name, _ in pred_res:
254
  all_names.add(name)
255
  return list(all_names)
256
 
 
257
  def extract_nl_feats(tokenizer, model, names, device):
258
  if len(names) == 0:
259
  features = []
@@ -262,14 +319,23 @@ def extract_nl_feats(tokenizer, model, names, device):
262
  features = model.get_text_features(**name_tokens)
263
  return features
264
 
265
- def extract_all_nl_feats(tokenizer, model, batch_size, batched_names, batched_unary_kws, batched_binary_kws, device):
  batched_obj_name_features = [[] for _ in range(batch_size)]
267
  batched_unary_nl_features = [[] for _ in range(batch_size)]
268
  batched_binary_nl_features = [[] for _ in range(batch_size)]
269
-
270
- for vid, (object_names, unary_kws, binary_kws) in \
271
- enumerate(zip(batched_names, batched_unary_kws, batched_binary_kws)):
272
 
273
  obj_name_features = extract_nl_feats(tokenizer, model, object_names, device)
274
  batched_obj_name_features[vid] = obj_name_features
275
 
@@ -279,22 +345,31 @@ def extract_all_nl_feats(tokenizer, model, batch_size, batched_names, batched_un
279
  binary_features = extract_nl_feats(tokenizer, model, binary_kws, device)
280
  batched_binary_nl_features[vid] = binary_features
281
 
282
- return batched_obj_name_features, batched_unary_nl_features, batched_binary_nl_features
283
 
284
- def single_object_crop(batch_size, batched_videos, batched_object_ids, batched_bboxes, batched_video_splits):
 
 
285
  batched_frame_bboxes = {}
286
  batched_cropped_objs = [[] for _ in range(batch_size)]
287
 
288
  for (video_id, frame_id, obj_id), bbox in zip(batched_object_ids, batched_bboxes):
289
  overall_frame_id = batched_video_splits[video_id] + frame_id
290
  if type(bbox) == dict:
291
- bx1, by1, bx2, by2 = bbox['x1'], bbox['y1'], bbox['x2'], bbox['y2']
292
  else:
293
  bx1, by1, bx2, by2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
294
 
295
  assert by2 > by1
296
  assert bx2 > bx1
297
- batched_cropped_objs[video_id].append((batched_videos[overall_frame_id][by1:by2, bx1:bx2]))
 
 
298
  batched_frame_bboxes[video_id, frame_id, obj_id] = (bx1, by1, bx2, by2)
299
 
300
  return batched_cropped_objs, batched_frame_bboxes
 
6
  import jax.numpy as jnp
7
  import jax
8
 
9
+
10
  def increase_brightness(img, alpha=0.2):
11
  height, width, _ = img.shape
12
+ white_img = np.zeros([height, width, 3], dtype=np.uint8)
13
+ white_img.fill(255) # or img[:] = 255
14
 
15
+ dst = cv2.addWeighted(img, alpha, white_img, 1 - alpha, 0)
16
  return dst
17
 
18
+
19
  def increase_brightness_except(img, bbox_ls, alpha=0.2):
20
  height, width, _ = img.shape
21
+ white_img = np.zeros([height, width, 3], dtype=np.uint8)
22
+ white_img.fill(255) # or img[:] = 255
23
 
24
+ output_img = cv2.addWeighted(img, alpha, white_img, 1 - alpha, 0)
25
 
26
  for x1, y1, x2, y2 in bbox_ls:
27
  output_img[y1:y2, x1:x2] = img[y1:y2, x1:x2]
 
30
 
31
  def extract_single_object(img, mask, alpha=0.8):
32
  """OpenCV version of extract_single_object that works with numpy arrays.
33
+
34
  Args:
35
  img: numpy array of shape (height, width, 3)
36
  mask: numpy array of shape (height, width, 1) or (height, width)
37
  alpha: float between 0 and 1 for blending
38
+
39
  Returns:
40
  numpy array of shape (height, width, 3)
41
  """
 
53
  masked_white_img = np.where(mask, white_img, img)
54
 
55
  # Blend the original image with the masked white image
56
+ output_img = cv2.addWeighted(
57
+ img.astype(np.uint8), 1 - alpha, masked_white_img.astype(np.uint8), alpha, 0
58
+ )
59
 
60
  return output_img
61
 
62
+
63
  def extract_single_object_jax(img, mask, alpha=0.8):
64
  """JAX version of extract_single_object that works with JAX arrays.
65
+
66
  Args:
67
  img: JAX array of shape (height, width, 3)
68
  mask: JAX array of shape (height, width, 1) or (height, width)
69
  alpha: float between 0 and 1 for blending
70
+
71
  Returns:
72
  JAX array of shape (height, width, 3)
73
  """
 
85
  masked_white_img = jnp.where(mask, white_img, img)
86
 
87
  # Blend the original image with the masked white image
88
+ output_img = img * (1 - alpha) + masked_white_img * alpha
89
 
90
  return output_img
91
 
92
+
93
  def crop_image_contain_bboxes(img, bbox_ls, data_id):
94
  all_bx1 = []
95
  all_by1 = []
 
98
 
99
  for bbox in bbox_ls:
100
  if isinstance(bbox, dict):
101
+ bx1, by1, bx2, by2 = bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]
102
  elif isinstance(bbox, (list, tuple, np.ndarray)):
103
+ bx1, by1, bx2, by2 = map(
104
+ int, bbox[:4]
105
+ ) # Convert first 4 elements to integers
106
  else:
107
  raise ValueError(f"Unsupported bbox format: {type(bbox)}")
108
 
 
119
  y1 = min(all_by1)
120
  y2 = max(all_by2)
121
 
122
+ assert x1 < x2, f"image bbox issue: {data_id}"
123
+ assert y1 < y2, f"image bbox issue: {data_id}"
124
 
125
  return img[y1:y2, x1:x2]
126
 
127
+
128
+ import numpy as np
129
+ import cv2
130
+
131
+
132
  def extract_object_subject(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8):
133
+ """
134
+ Blend subject/object regions into the image:
135
+ - red_mask: subject
136
+ - blue_mask: object
137
+ - alpha: how strong color highlight is
138
+ - white_alpha: how strongly to fade background toward white
139
+ """
140
+
141
+ # Ensure img is uint8 HxWx3
142
+ img = np.asarray(img)
143
+ if img.ndim == 2:
144
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
145
+ if img.dtype != np.uint8:
146
+ img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
147
+
148
+ # Normalize masks to 2D
149
+ red_mask = np.asarray(red_mask)
150
+ blue_mask = np.asarray(blue_mask)
151
+
152
  if red_mask.ndim == 3:
153
  red_mask = red_mask[:, :, 0]
154
  if blue_mask.ndim == 3:
 
156
 
157
  red_mask = red_mask.astype(bool)
158
  blue_mask = blue_mask.astype(bool)
159
+
160
+ # Background = areas not in either mask
161
  non_masked_area = ~(red_mask | blue_mask)
162
 
163
+ # Split channels
164
  b, g, r = cv2.split(img)
165
 
166
+ # Highlight red region
167
+ r = np.where(
168
+ red_mask,
169
+ np.clip(r + (255 - r) * alpha, 0, 255),
170
+ r,
171
+ )
172
+
173
+ # Highlight blue region
174
+ b = np.where(
175
+ blue_mask,
176
+ np.clip(b + (255 - b) * alpha, 0, 255),
177
+ b,
178
+ )
179
+
180
+ # Ensure proper dtype
181
+ b = b.astype(np.uint8)
182
+ g = g.astype(np.uint8)
183
+ r = r.astype(np.uint8)
184
 
185
  output_img = cv2.merge((b, g, r))
186
 
187
+ # Fade non-masked area toward white
188
  white_img = np.full_like(output_img, 255, dtype=np.uint8)
189
+ non_masked_area_3d = non_masked_area[
190
+ ..., None
191
+ ] # (H, W, 1) -> broadcast to (H, W, 3)
192
+
193
+ faded = cv2.addWeighted(output_img, 1 - white_alpha, white_img, white_alpha, 0)
194
+ output_img = np.where(non_masked_area_3d, faded, output_img)
195
 
196
  return output_img
197
 
198
 
199
  def extract_object_subject_jax(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8):
200
  """JAX version of extract_object_subject that works with JAX arrays.
201
+
202
  Args:
203
  img: JAX array of shape (height, width, 3) in BGR format
204
  red_mask: JAX array of shape (height, width, 1) or (height, width)
205
  blue_mask: JAX array of shape (height, width, 1) or (height, width)
206
  alpha: float between 0 and 1 for color highlighting
207
  white_alpha: float between 0 and 1 for background blending
208
+
209
  Returns:
210
  JAX array of shape (height, width, 3) in BGR format with uint8 dtype
211
  """
212
  # Convert input image to float32 for calculations
213
  img = img.astype(jnp.float32)
214
+
215
  # Ensure the masks are binary (0 or 1)
216
  red_mask = red_mask.astype(bool)
217
  blue_mask = blue_mask.astype(bool)
 
228
  r = img[..., 2] # Red channel
229
 
230
  # Adjust the red channel based on the red mask
231
+ r = jnp.where(red_mask[..., 0], jnp.clip(r + (255 - r) * alpha, 0, 255), r)
 
 
232
 
233
  # Adjust the blue channel based on the blue mask
234
+ b = jnp.where(blue_mask[..., 0], jnp.clip(b + (255 - b) * alpha, 0, 255), b)
 
 
235
 
236
  # Stack the channels back together
237
  output_img = jnp.stack([b, g, r], axis=-1)
238
 
239
  # Create white background and blend
240
  white_img = jnp.full_like(output_img, 255.0, dtype=jnp.float32)
241
+ output_img = jnp.where(
242
+ non_masked_area,
243
+ output_img * (1 - white_alpha) + white_img * white_alpha,
244
+ output_img,
245
+ )
246
 
247
  # Round to nearest integer and cast to uint8
248
  output_img = jnp.round(output_img)
249
  return output_img.astype(jnp.uint8)
250
 
251
+
252
+ def increase_brightness_draw_outer_edge(
253
+ img, bbox_ls, alpha=0.2, colormap_name="Set1", thickness=2
254
+ ):
255
  if isinstance(img, torch.Tensor):
256
  img = img.cpu().numpy().astype(np.uint8)
257
  else:
258
  img = img.astype(np.uint8)
259
  height, width, _ = img.shape
260
+ white_img = np.zeros([height, width, 3], dtype=np.uint8)
261
+ white_img.fill(255) # or img[:] = 255
262
 
263
+ output_img = cv2.addWeighted(img, alpha, white_img, 1 - alpha, 0)
264
  colormap = plt.colormaps[colormap_name]
265
 
266
  for bbox_id, (x1, y1, x2, y2) in enumerate(bbox_ls):
267
  output_img[y1:y2, x1:x2] = img[y1:y2, x1:x2]
268
+ color = [c * 255 for c in mpl.colors.to_rgb(colormap(bbox_id))]
269
  # print(f"color: {color}")
270
  output_img = cv2.rectangle(output_img, (x1, y1), (x2, y2), color, thickness)
271
 
272
  return torch.tensor(output_img, dtype=torch.float32)
273
 
274
+
275
  def get_print_hook(name):
276
  def print_hook(grad):
277
  print(f"{name}: \n {grad} \n")
278
  return grad
279
+
280
  return print_hook
281
 
282
+
283
  def segment_list(l, n=5):
284
  current_seg = []
285
  all_segs = []
 
295
 
296
  return all_segs
297
 
298
+
299
  def get_tensor_size(a):
300
  return a.element_size() * a.nelement()
301
 
302
+
303
  def comp_diff(v1, v2):
304
  return 2 * torch.abs(v1 - v2) / (v1 + v2)
305
 
306
+
307
  def gather_names(pred_res):
308
  all_names = set()
309
  for name, _ in pred_res:
310
  all_names.add(name)
311
  return list(all_names)
312
 
313
+
314
  def extract_nl_feats(tokenizer, model, names, device):
315
  if len(names) == 0:
316
  features = []
 
319
  features = model.get_text_features(**name_tokens)
320
  return features
321
 
322
+
323
+ def extract_all_nl_feats(
324
+ tokenizer,
325
+ model,
326
+ batch_size,
327
+ batched_names,
328
+ batched_unary_kws,
329
+ batched_binary_kws,
330
+ device,
331
+ ):
332
  batched_obj_name_features = [[] for _ in range(batch_size)]
333
  batched_unary_nl_features = [[] for _ in range(batch_size)]
334
  batched_binary_nl_features = [[] for _ in range(batch_size)]
335
 
336
+ for vid, (object_names, unary_kws, binary_kws) in enumerate(
337
+ zip(batched_names, batched_unary_kws, batched_binary_kws)
338
+ ):
339
  obj_name_features = extract_nl_feats(tokenizer, model, object_names, device)
340
  batched_obj_name_features[vid] = obj_name_features
341
 
 
345
  binary_features = extract_nl_feats(tokenizer, model, binary_kws, device)
346
  batched_binary_nl_features[vid] = binary_features
347
 
348
+ return (
349
+ batched_obj_name_features,
350
+ batched_unary_nl_features,
351
+ batched_binary_nl_features,
352
+ )
353
+
354
 
355
+ def single_object_crop(
356
+ batch_size, batched_videos, batched_object_ids, batched_bboxes, batched_video_splits
357
+ ):
358
  batched_frame_bboxes = {}
359
  batched_cropped_objs = [[] for _ in range(batch_size)]
360
 
361
  for (video_id, frame_id, obj_id), bbox in zip(batched_object_ids, batched_bboxes):
362
  overall_frame_id = batched_video_splits[video_id] + frame_id
363
  if type(bbox) == dict:
364
+ bx1, by1, bx2, by2 = bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]
365
  else:
366
  bx1, by1, bx2, by2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
367
 
368
  assert by2 > by1
369
  assert bx2 > bx1
370
+ batched_cropped_objs[video_id].append(
371
+ (batched_videos[overall_frame_id][by1:by2, bx1:bx2])
372
+ )
373
  batched_frame_bboxes[video_id, frame_id, obj_id] = (bx1, by1, bx2, by2)
374
 
375
  return batched_cropped_objs, batched_frame_bboxes
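
A minimal usage sketch for the rewritten extract_object_subject above, run on synthetic data. The import path mirrors src/LASER/laser/models/model_utils.py and assumes the laser package is importable; otherwise the function body from the hunk can be pasted locally:

    import numpy as np
    from laser.models.model_utils import extract_object_subject  # assumed import path

    # Synthetic 100x100 gray BGR frame with two disjoint masks.
    img = np.full((100, 100, 3), 128, dtype=np.uint8)
    red_mask = np.zeros((100, 100), dtype=bool)
    blue_mask = np.zeros((100, 100), dtype=bool)
    red_mask[10:40, 10:40] = True    # subject region, pushed toward red
    blue_mask[60:90, 60:90] = True   # object region, pushed toward blue

    out = extract_object_subject(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8)
    print(out.shape, out.dtype)  # (100, 100, 3) uint8; unmasked area faded toward white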
vine_hf/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/__init__.cpython-310.pyc and b/vine_hf/__pycache__/__init__.cpython-310.pyc differ
 
vine_hf/__pycache__/flattening.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/flattening.cpython-310.pyc and b/vine_hf/__pycache__/flattening.cpython-310.pyc differ
 
vine_hf/__pycache__/vine_config.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_config.cpython-310.pyc and b/vine_hf/__pycache__/vine_config.cpython-310.pyc differ
 
vine_hf/__pycache__/vine_model.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_model.cpython-310.pyc and b/vine_hf/__pycache__/vine_model.cpython-310.pyc differ
 
vine_hf/__pycache__/vine_pipeline.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_pipeline.cpython-310.pyc and b/vine_hf/__pycache__/vine_pipeline.cpython-310.pyc differ
 
vine_hf/__pycache__/vis_utils.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vis_utils.cpython-310.pyc and b/vine_hf/__pycache__/vis_utils.cpython-310.pyc differ
 
vine_hf/vine_pipeline.py CHANGED
@@ -586,8 +586,17 @@ class VinePipeline(Pipeline):
586
  import subprocess
587
 
588
  try:
589
  ffmpeg_cmd = [
590
- "ffmpeg",
591
  "-y",
592
  "-f",
593
  "rawvideo",
@@ -657,6 +666,10 @@ class VinePipeline(Pipeline):
657
  out = None
658
  used_codec = None
659
 
660
  for codec in codecs_to_try:
661
  try:
662
  fourcc = cv2.VideoWriter_fourcc(*codec)
@@ -679,19 +692,37 @@ class VinePipeline(Pipeline):
679
 
680
  print(f"Using OpenCV with codec: {used_codec}")
681
 
 
682
  for frame in video_tensor:
683
  if len(frame.shape) == 3 and frame.shape[2] == 3:
684
  frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
685
  else:
686
  frame_bgr = frame
 
687
  if frame_bgr.dtype != np.uint8:
688
  frame_bgr = (
689
  (frame_bgr * 255).astype(np.uint8)
690
  if frame_bgr.max() <= 1
691
  else frame_bgr.astype(np.uint8)
692
  )
693
  out.write(frame_bgr)
 
694
 
 
695
  out.release()
696
  return temp_path
697
 
 
586
  import subprocess
587
 
588
  try:
589
+ # Try to get FFmpeg from imageio-ffmpeg first, then fall back to system FFmpeg
590
+ try:
591
+ import imageio_ffmpeg
592
+ ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
593
+ print(f"Using FFmpeg from imageio-ffmpeg: {ffmpeg_exe}")
594
+ except ImportError:
595
+ ffmpeg_exe = "ffmpeg"
596
+ print("Using system FFmpeg")
597
+
598
  ffmpeg_cmd = [
599
+ ffmpeg_exe,
600
  "-y",
601
  "-f",
602
  "rawvideo",
 
666
  out = None
667
  used_codec = None
668
 
669
+ # Debug: Print video tensor info
670
+ print(f"DEBUG: video_tensor shape: {video_tensor.shape}, dtype: {video_tensor.dtype}")
671
+ print(f"DEBUG: Expected dimensions - width: {width}, height: {height}, fps: {fps}")
672
+
673
  for codec in codecs_to_try:
674
  try:
675
  fourcc = cv2.VideoWriter_fourcc(*codec)
 
692
 
693
  print(f"Using OpenCV with codec: {used_codec}")
694
 
695
+ frame_count = 0
696
  for frame in video_tensor:
697
+ # Debug: Print first frame info
698
+ if frame_count == 0:
699
+ print(f"DEBUG: First frame shape: {frame.shape}, dtype: {frame.dtype}")
700
+ print(f"DEBUG: First frame min: {frame.min()}, max: {frame.max()}, mean: {frame.mean()}")
701
+
702
  if len(frame.shape) == 3 and frame.shape[2] == 3:
703
  frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
704
  else:
705
  frame_bgr = frame
706
+
707
  if frame_bgr.dtype != np.uint8:
708
  frame_bgr = (
709
  (frame_bgr * 255).astype(np.uint8)
710
  if frame_bgr.max() <= 1
711
  else frame_bgr.astype(np.uint8)
712
  )
713
+
714
+ # Debug: Check if frame dimensions match VideoWriter expectations
715
+ if frame_count == 0:
716
+ print(f"DEBUG: After conversion - frame_bgr shape: {frame_bgr.shape}, dtype: {frame_bgr.dtype}")
717
+ print(f"DEBUG: After conversion - min: {frame_bgr.min()}, max: {frame_bgr.max()}")
718
+ actual_height, actual_width = frame_bgr.shape[:2]
719
+ if actual_height != height or actual_width != width:
720
+ print(f"WARNING: Frame size mismatch! Expected ({height}, {width}), got ({actual_height}, {actual_width})")
721
+
722
  out.write(frame_bgr)
723
+ frame_count += 1
724
 
725
+ print(f"DEBUG: Wrote {frame_count} frames to video")
726
  out.release()
727
  return temp_path
728
 
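The vine_pipeline.py hunks above resolve the FFmpeg binary through imageio-ffmpeg and fall back to a system ffmpeg when that package is missing. A standalone sketch of the same pattern; the version probe at the end is illustrative and not part of the commit:

    import subprocess

    def resolve_ffmpeg() -> str:
        # Prefer the binary bundled with imageio-ffmpeg, fall back to whatever is on PATH.
        try:
            import imageio_ffmpeg
            return imageio_ffmpeg.get_ffmpeg_exe()
        except ImportError:
            return "ffmpeg"

    ffmpeg_exe = resolve_ffmpeg()
    # Sanity-check that the resolved binary actually runs.
    subprocess.run([ffmpeg_exe, "-version"], check=True, capture_output=True)
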
vine_hf/vis_utils.py CHANGED
@@ -54,10 +54,12 @@ from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox
54
  # All rendered frames returned by functions are RGB np.ndarray images suitable for saving or video writing.
55
  ########################################################################################
56
 
 
57
  def clean_label(label):
58
  """Replace underscores and slashes with spaces for uniformity."""
59
  return label.replace("_", " ").replace("/", " ")
60
 
 
61
  # Should be performed somewhere else I believe
62
  def format_cate_preds(cate_preds):
63
  # Group object predictions from the model output.
@@ -72,6 +74,7 @@ def format_cate_preds(cate_preds):
72
  obj_pred_dict[oid].sort(key=lambda x: x[1], reverse=True)
73
  return obj_pred_dict
74
 
 
75
  def format_binary_cate_preds(binary_preds):
76
  frame_binary_preds = []
77
  for key, score in binary_preds.items():
@@ -85,6 +88,7 @@ def format_binary_cate_preds(binary_preds):
85
  frame_binary_preds.sort(key=lambda x: x[3], reverse=True)
86
  return frame_binary_preds
87
 
 
88
  _FONT = cv2.FONT_HERSHEY_SIMPLEX
89
 
90
 
@@ -106,7 +110,9 @@ def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.n
106
  return mask_np > 0
107
 
108
 
109
- def _sanitize_bbox(bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int) -> Optional[Tuple[int, int, int, int]]:
 
 
110
  if bbox is None:
111
  return None
112
  if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
@@ -164,7 +170,16 @@ def _draw_label_block(
164
  cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
165
  text_x = left_x + 4
166
  text_y = min(bottom_y - baseline - 2, img_h - 1)
167
- cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
168
  y_cursor = bottom_y
169
  else:
170
  for text in lines:
@@ -177,7 +192,16 @@ def _draw_label_block(
177
  cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
178
  text_x = left_x + 4
179
  text_y = min(bottom_y - baseline - 2, img_h - 1)
180
- cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
181
  y_cursor = top_y
182
 
183
 
@@ -198,13 +222,26 @@ def _draw_centered_label(
198
  top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
199
  right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
200
  bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
201
- cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1)
 
 
202
  text_x = left_x + 4
203
  text_y = min(bottom_y - baseline - 2, img_h - 1)
204
- cv2.putText(image, text, (text_x, text_y), _FONT, font_scale, (0, 0, 0), thickness, cv2.LINE_AA)
205
 
206
 
207
- def _extract_frame_entities(store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int) -> Dict[int, Any]:
 
 
208
  if isinstance(store, dict):
209
  frame_entry = store.get(frame_idx, {})
210
  elif isinstance(store, list) and 0 <= frame_idx < len(store):
@@ -271,7 +308,9 @@ def render_sam_frames(
271
  continue
272
  color = _object_color_bgr(obj_id)
273
  alpha = 0.45
274
- overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(color, dtype=np.float32)
 
 
275
 
276
  annotated = np.clip(overlay, 0, 255).astype(np.uint8)
277
  frame_h, frame_w = annotated.shape[:2]
@@ -329,7 +368,9 @@ def render_vine_frame_sets(
329
  cat_label_lookup: Dict[int, Tuple[str, float]],
330
  unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
331
  binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
332
- masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
 
 
333
  binary_confidence_threshold: float = 0.0,
334
  ) -> Dict[str, List[np.ndarray]]:
335
  frame_groups: Dict[str, List[np.ndarray]] = {
@@ -347,7 +388,9 @@ def render_vine_frame_sets(
347
  base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
348
  frame_h, frame_w = base_bgr.shape[:2]
349
  frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
350
- frame_masks = _extract_frame_entities(masks, frame_idx) if masks is not None else {}
 
 
351
 
352
  objects_bgr = base_bgr.copy()
353
  unary_bgr = base_bgr.copy()
@@ -393,16 +436,36 @@ def render_vine_frame_sets(
393
  for obj_id, bbox in bbox_lookup.items():
394
  title = titles_lookup.get(obj_id)
395
  unary_lines = unary_lines_lookup.get(obj_id, [])
396
- _draw_bbox_with_label(objects_bgr, bbox, obj_id, title=title, label_position="top")
397
- _draw_bbox_with_label(unary_bgr, bbox, obj_id, title=title, label_position="top")
398
  if unary_lines:
399
  anchor, direction = _label_anchor_and_direction(bbox, "bottom")
400
- _draw_label_block(unary_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
401
- _draw_bbox_with_label(binary_bgr, bbox, obj_id, title=title, label_position="top")
402
- _draw_bbox_with_label(all_bgr, bbox, obj_id, title=title, label_position="top")
403
  if unary_lines:
404
  anchor, direction = _label_anchor_and_direction(bbox, "bottom")
405
- _draw_label_block(all_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
406
 
407
  # First pass: collect all pairs above threshold and deduplicate bidirectional pairs
408
  pairs_to_draw = {} # (min_id, max_id) -> (subj_id, obj_id, prob, relation)
@@ -432,15 +495,24 @@ def render_vine_frame_sets(
432
  subj_bbox = bbox_lookup.get(subj_id)
433
  obj_bbox = bbox_lookup.get(obj_id)
434
  start, end = relation_line(subj_bbox, obj_bbox)
435
- color = tuple(int(c) for c in np.clip(
436
- (np.array(_object_color_bgr(subj_id), dtype=np.float32) +
437
- np.array(_object_color_bgr(obj_id), dtype=np.float32)) / 2.0,
438
- 0, 255
439
- ))
440
  label_text = f"{relation} {prob:.2f}"
441
  mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
442
  # Draw arrowed lines showing direction from subject to object (smaller arrow tip)
443
- cv2.arrowedLine(binary_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05)
 
 
444
  cv2.arrowedLine(all_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05)
445
  _draw_centered_label(binary_bgr, label_text, mid_point, color)
446
  _draw_centered_label(all_bgr, label_text, mid_point, color)
@@ -459,7 +531,9 @@ def render_vine_frames(
459
  cat_label_lookup: Dict[int, Tuple[str, float]],
460
  unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
461
  binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
462
- masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
 
 
463
  binary_confidence_threshold: float = 0.0,
464
  ) -> List[np.ndarray]:
465
  return render_vine_frame_sets(
@@ -471,11 +545,12 @@ def render_vine_frames(
471
  masks,
472
  binary_confidence_threshold,
473
  ).get("all", [])
474
-
 
475
  def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
476
  all_colors = []
477
  all_texts = []
478
- for (obj_id, bbox, gt_label) in gt_labels:
479
  preds = obj_pred_dict.get(obj_id, [])
480
  if len(preds) == 0:
481
  top1 = "N/A"
@@ -485,143 +560,214 @@ def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
485
  topk_labels = [p[0] for p in preds[:topk_object]]
486
  # Compare cleaned labels.
487
  if top1.lower() == gt_label.lower():
488
- box_color = (0, 255, 0) # bright green for correct
489
  elif gt_label.lower() in [p.lower() for p in topk_labels]:
490
- box_color = (0, 165, 255) # bright orange for partial match
491
  else:
492
- box_color = (0, 0, 255) # bright red for incorrect
493
-
494
  label_text = f"ID:{obj_id}/P:{top1}/GT:{gt_label}"
495
  all_colors.append(box_color)
496
  all_texts.append(label_text)
497
  return all_colors, all_texts
498
 
 
499
  def plot_unary(frame_img, gt_labels, all_colors, all_texts):
500
-
501
- for (obj_id, bbox, gt_label), box_color, label_text in zip(gt_labels, all_colors, all_texts):
 
502
  x1, y1, x2, y2 = map(int, bbox)
503
  cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
504
- (tw, th), baseline = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
505
- cv2.rectangle(frame_img, (x1, y1 - th - baseline - 4), (x1 + tw, y1), box_color, -1)
506
- cv2.putText(frame_img, label_text, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX,
507
- 0.5, (0, 0, 0), 1, cv2.LINE_AA)
508
-
509
  return frame_img
510
 
511
- def get_white_pane(pane_height,
512
- pane_width=600,
513
- header_height = 50,
514
- header_font = cv2.FONT_HERSHEY_SIMPLEX,
515
- header_font_scale = 0.7,
516
- header_thickness = 2,
517
- header_color = (0, 0, 0)):
518
- # Create an expanded white pane to display text info.
519
  white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)
520
-
521
  # --- Adjust pane split: make predictions column wider (60% vs. 40%) ---
522
  left_width = int(pane_width * 0.6)
523
  right_width = pane_width - left_width
524
  left_pane = white_pane[:, :left_width, :].copy()
525
  right_pane = white_pane[:, left_width:, :].copy()
526
-
527
- cv2.putText(left_pane, "Binary Predictions", (10, header_height - 30),
528
- header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
529
- cv2.putText(right_pane, "Ground Truth", (10, header_height - 30),
530
- header_font, header_font_scale, header_color, header_thickness, cv2.LINE_AA)
531
-
532
  return white_pane
533
 
 
534
  # This is for ploting binary prediction results with frame-based scene graphs
535
- def plot_binary_sg(frame_img,
536
- white_pane,
537
- bin_preds,
538
- gt_relations,
539
- topk_binary,
540
- header_height=50,
541
- indicator_size=20,
542
- pane_width=600):
543
- # Leave vertical space for the headers.
 
 
544
  line_height = 30 # vertical spacing per line
545
- x_text = 10 # left margin for text
546
  y_text_left = header_height + 10 # starting y for left pane text
547
- y_text_right = header_height + 10 # starting y for right pane text
548
-
549
  # Left section: top-k binary predictions.
550
  left_width = int(pane_width * 0.6)
551
  right_width = pane_width - left_width
552
  left_pane = white_pane[:, :left_width, :].copy()
553
  right_pane = white_pane[:, left_width:, :].copy()
554
-
555
- for (subj, pred_rel, obj, score) in bin_preds[:topk_binary]:
556
- correct = any((subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
557
- for gt in gt_relations)
 
 
558
  indicator_color = (0, 255, 0) if correct else (0, 0, 255)
559
- cv2.rectangle(left_pane, (x_text, y_text_left - indicator_size + 5),
560
- (x_text + indicator_size, y_text_left + 5), indicator_color, -1)
561
  text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
562
- cv2.putText(left_pane, text, (x_text + indicator_size + 5, y_text_left + 5),
563
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
564
  y_text_left += line_height
565
-
566
  # Right section: ground truth binary relations.
567
  for gt in gt_relations:
568
  if len(gt) != 3:
569
  continue
570
  text = f"{gt[0]} - {gt[2]} - {gt[1]}"
571
- cv2.putText(right_pane, text, (x_text, y_text_right + 5),
572
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
573
  y_text_right += line_height
574
-
575
  # Combine the two text panes and then with the frame image.
576
  combined_pane = np.hstack((left_pane, right_pane))
577
  combined_image = np.hstack((frame_img, combined_pane))
578
  return combined_image
579
 
580
- def visualized_frame(frame_img,
581
- bboxes,
582
- object_ids,
583
- gt_labels,
584
- cate_preds,
585
- binary_preds,
586
- gt_relations,
587
- topk_object,
588
- topk_binary,
589
- phase="unary"):
590
-
 
 
591
  """Return the combined annotated frame for frame index i as an image (in BGR)."""
592
  # Get the frame image (assuming batched_data['batched_reshaped_raw_videos'] is a list of frames)
593
 
594
  # --- Process Object Predictions (for overlaying bboxes) ---
595
  if phase == "unary":
596
  objs = []
597
- for ((_, f_id, obj_id), bbox, gt_label) in zip(object_ids, bboxes, gt_labels):
598
  gt_label = clean_label(gt_label)
599
  objs.append((obj_id, bbox, gt_label))
600
-
601
  formatted_cate_preds = format_cate_preds(cate_preds)
602
- all_colors, all_texts = color_for_cate_correctness(formatted_cate_preds, gt_labels, topk_object)
 
 
603
  updated_frame_img = plot_unary(frame_img, gt_labels, all_colors, all_texts)
604
  return updated_frame_img
605
-
606
  else:
607
  # --- Process Binary Predictions & Ground Truth for the Text Pane ---
608
  formatted_binary_preds = format_binary_cate_preds(binary_preds)
609
-
610
  # Ground truth binary relations for the frame.
611
  # Clean ground truth relations.
612
- gt_relations = [(clean_label(str(s)), clean_label(str(o)), clean_label(rel)) for s, o, rel in gt_relations]
613
-
614
  pane_width = 600 # increased pane width for more horizontal space
615
  pane_height = frame_img.shape[0]
616
-
617
  # --- Add header labels to each text pane with extra space ---
618
  header_height = 50 # increased header space
619
- white_pane = get_white_pane(pane_height, pane_width, header_height=header_height)
620
-
621
- combined_image = plot_binary_sg(frame_img, white_pane, formatted_binary_preds, gt_relations, topk_binary)
622
-
623
  return combined_image
624
 
 
625
  def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
626
  # Ensure mask is a numpy array
627
  mask = np.array(mask)
@@ -644,7 +790,7 @@ def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
644
  color = list(cmap((cmap_idx * 47) % 256))
645
  color[3] = 0.5
646
  color = np.array(color)
647
-
648
  # Expand mask to (H, W, 1) for broadcasting
649
  mask_expanded = mask[..., None]
650
  mask_image = mask_expanded * color.reshape(1, 1, -1)
@@ -663,7 +809,7 @@ def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
663
  linewidth=1.5,
664
  edgecolor=color[:3],
665
  facecolor="none",
666
- alpha=color[3]
667
  )
668
  ax.add_patch(rect)
669
  ax.text(
@@ -673,10 +819,11 @@ def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
673
  color="white",
674
  fontsize=6,
675
  backgroundcolor=np.array(color),
676
- alpha=1
677
  )
678
  ax.imshow(mask_image)
679
 
 
680
  def save_mask_one_image(frame_image, masks, save_path):
681
  """Render masks on top of a frame and store the visualization on disk."""
682
  fig, ax = plt.subplots(1, figsize=(6, 6))
@@ -695,9 +842,7 @@ def save_mask_one_image(frame_image, masks, save_path):
695
 
696
  prepared_masks = {
697
  obj_id: (
698
- mask.detach().cpu().numpy()
699
- if torch.is_tensor(mask)
700
- else np.asarray(mask)
701
  )
702
  for obj_id, mask in mask_iter
703
  }
@@ -711,54 +856,61 @@ def save_mask_one_image(frame_image, masks, save_path):
711
  fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
712
  plt.close(fig)
713
  return save_path
714
-
715
- def get_video_masks_visualization(video_tensor,
716
- video_masks,
717
- video_id,
718
- video_save_base_dir,
719
- oid_class_pred=None,
720
- sample_rate = 1):
721
-
 
 
722
  video_save_dir = os.path.join(video_save_base_dir, video_id)
723
  if not os.path.exists(video_save_dir):
724
  os.makedirs(video_save_dir, exist_ok=True)
725
-
726
  for frame_id, image in enumerate(video_tensor):
727
  if frame_id not in video_masks:
728
  print("No mask for Frame", frame_id)
729
  continue
730
-
731
  masks = video_masks[frame_id]
732
  save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
733
  get_mask_one_image(image, masks, oid_class_pred)
734
 
 
735
  def get_mask_one_image(frame_image, masks, oid_class_pred=None):
736
  # Create a figure and axis
737
  fig, ax = plt.subplots(1, figsize=(6, 6))
738
 
739
  # Display the frame image
740
  ax.imshow(frame_image)
741
- ax.axis('off')
742
 
743
  if type(masks) == list:
744
  masks = {i: m for i, m in enumerate(masks)}
745
-
746
  # Add the masks
747
  for obj_id, mask in masks.items():
748
- det_class = f"{obj_id}. {oid_class_pred[obj_id]}" if not oid_class_pred is None else None
749
  show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)
750
 
751
  # Show the plot
752
  return fig, ax
753
 
 
754
  def save_video(frames, output_filename, output_fps):
755
-
756
  # --- Create a video from all frames ---
757
  num_frames = len(frames)
758
  frame_h, frame_w = frames.shape[:2]
759
 
760
  # Use a codec supported by VS Code (H.264 via 'avc1').
761
- fourcc = cv2.VideoWriter_fourcc(*'avc1')
762
  out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))
763
 
764
  print(f"Processing {num_frames} frames...")
@@ -766,23 +918,26 @@ def save_video(frames, output_filename, output_fps):
766
  vis_frame = get_visualized_frame(i)
767
  out.write(vis_frame)
768
  if i % 10 == 0:
769
- print(f"Processed frame {i+1}/{num_frames}")
770
 
771
  out.release()
772
  print(f"Video saved as {output_filename}")
773
-
774
 
775
  def list_depth(lst):
776
  """Calculates the depth of a nested list."""
777
  if not (isinstance(lst, list) or isinstance(lst, torch.Tensor)):
778
  return 0
779
- elif (isinstance(lst, torch.Tensor) and lst.shape == torch.Size([])) or (isinstance(lst, list) and len(lst) == 0):
780
  return 1
781
  else:
782
  return 1 + max(list_depth(item) for item in lst)
783
-
 
784
  def normalize_prompt(points, labels):
785
- if list_depth(points) == 3:
786
  points = torch.stack([p.unsqueeze(0) for p in points])
787
  labels = torch.stack([l.unsqueeze(0) for l in labels])
788
  return points, labels
@@ -791,36 +946,56 @@ def normalize_prompt(points, labels):
791
  def show_box(box, ax, object_id):
792
  if len(box) == 0:
793
  return
794
-
795
  cmap = plt.get_cmap("gist_rainbow")
796
  cmap_idx = 0 if object_id is None else object_id
797
  color = list(cmap((cmap_idx * 47) % 256))
798
-
799
  x0, y0 = box[0], box[1]
800
  w, h = box[2] - box[0], box[3] - box[1]
801
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0,0,0,0), lw=2))
802
-
803
  def show_points(coords, labels, ax, object_id=None, marker_size=375):
804
  if len(labels) == 0:
805
  return
806
-
807
- pos_points = coords[labels==1]
808
- neg_points = coords[labels==0]
809
-
810
  cmap = plt.get_cmap("gist_rainbow")
811
  cmap_idx = 0 if object_id is None else object_id
812
  color = list(cmap((cmap_idx * 47) % 256))
813
-
814
- ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='P', s=marker_size, edgecolor=color, linewidth=1.25)
815
- ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='s', s=marker_size, edgecolor=color, linewidth=1.25)
816
-
817
  def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
818
  # Create a figure and axis
819
  fig, ax = plt.subplots(1, figsize=(6, 6))
820
 
821
  # Display the frame image
822
  ax.imshow(frame_image)
823
- ax.axis('off')
824
 
825
  points, labels = normalize_prompt(points, labels)
826
  if type(boxes) == torch.Tensor:
@@ -837,40 +1012,50 @@ def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
837
  pass
838
  else:
839
  raise Exception()
840
-
841
  for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
842
  if not len(point_ls) == 0:
843
  show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)
844
-
845
  # Show the plot
846
  plt.savefig(save_path)
847
  plt.close()
848
-
849
- def save_video_prompts_visualization(video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir):
850
  video_save_dir = os.path.join(video_save_base_dir, video_id)
851
  if not os.path.exists(video_save_dir):
852
  os.makedirs(video_save_dir, exist_ok=True)
853
-
854
  for frame_id, image in enumerate(video_tensor):
855
  boxes, points, labels = [], [], []
856
-
857
  if frame_id in video_boxes:
858
  boxes = video_boxes[frame_id]
859
-
860
  if frame_id in video_points:
861
  points = video_points[frame_id]
862
  if frame_id in video_labels:
863
  labels = video_labels[frame_id]
864
-
865
  save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
866
  save_prompts_one_image(image, boxes, points, labels, save_path)
867
-
868
 
869
- def save_video_masks_visualization(video_tensor, video_masks, video_id, video_save_base_dir, oid_class_pred=None, sample_rate = 1):
870
  video_save_dir = os.path.join(video_save_base_dir, video_id)
871
  if not os.path.exists(video_save_dir):
872
  os.makedirs(video_save_dir, exist_ok=True)
873
-
874
  for frame_id, image in enumerate(video_tensor):
875
  if random.random() > sample_rate:
876
  continue
@@ -880,18 +1065,17 @@ def save_video_masks_visualization(video_tensor, video_masks, video_id, video_sa
880
  masks = video_masks[frame_id]
881
  save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
882
  save_mask_one_image(image, masks, save_path)
883
-
884
 
885
 
886
- def get_color(obj_id, cmap_name="gist_rainbow",alpha=0.5):
887
  cmap = plt.get_cmap(cmap_name)
888
  cmap_idx = 0 if obj_id is None else obj_id
889
  color = list(cmap((cmap_idx * 47) % 256))
890
  color[3] = 0.5
891
  color = np.array(color)
892
  return color
893
-
894
-
895
  def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
896
  return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0)
897
 
@@ -906,7 +1090,9 @@ def relation_line(
906
  """
907
  center1 = _bbox_center(bbox1)
908
  center2 = _bbox_center(bbox2)
909
- if math.isclose(center1[0], center2[0], abs_tol=1e-3) and math.isclose(center1[1], center2[1], abs_tol=1e-3):
910
  offset = max(1.0, (bbox2[2] - bbox2[0]) * 0.05)
911
  center2 = (center2[0] + offset, center2[1])
912
  start = (int(round(center1[0])), int(round(center1[1])))
@@ -915,57 +1101,68 @@ def relation_line(
915
  end = (end[0] + 1, end[1])
916
  return start, end
917
 
 
918
  def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
919
  # Create a figure and axis
920
  fig, ax = plt.subplots(1, figsize=(6, 6))
921
 
922
  # Display the frame image
923
  ax.imshow(frame_image)
924
- ax.axis('off')
925
-
926
  all_objs_to_show = set()
927
  all_lines_to_show = []
928
-
929
  # print(rel_pred_ls[0])
930
  for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
931
- all_objs_to_show.add(from_obj_id)
932
- all_objs_to_show.add(to_obj_id)
933
-
934
  from_mask = masks[from_obj_id]
935
  bbox1 = mask_to_bbox(from_mask)
936
  to_mask = masks[to_obj_id]
937
  bbox2 = mask_to_bbox(to_mask)
938
-
939
  c1, c2 = shortest_line_between_bboxes(bbox1, bbox2)
940
-
941
  line_color = get_color(from_obj_id)
942
  face_color = get_color(to_obj_id)
943
  line = c1, c2, face_color, line_color, rel_text
944
  all_lines_to_show.append(line)
945
-
946
  masks_to_show = {}
947
  for oid in all_objs_to_show:
948
  masks_to_show[oid] = masks[oid]
949
-
950
  # Add the masks
951
  for obj_id, mask in masks_to_show.items():
952
  show_mask(mask, ax, obj_id=obj_id, random_color=False)
953
 
954
- for (from_pt_x, from_pt_y), (to_pt_x, to_pt_y), face_color, line_color, rel_text in all_lines_to_show:
955
-
956
- plt.plot([from_pt_x, to_pt_x], [from_pt_y, to_pt_y], color=line_color, linestyle='-', linewidth=3)
957
  mid_pt_x = (from_pt_x + to_pt_x) / 2
958
  mid_pt_y = (from_pt_y + to_pt_y) / 2
959
  ax.text(
960
- mid_pt_x - 5,
961
- mid_pt_y,
962
- rel_text,
963
- color="white",
964
- fontsize=6,
965
- backgroundcolor=np.array(line_color),
966
- bbox=dict(facecolor=face_color, edgecolor=line_color, boxstyle='round,pad=1'),
967
- alpha=1
968
- )
969
-
970
  # Show the plot
971
  return fig, ax
 
54
  # All rendered frames returned by functions are RGB np.ndarray images suitable for saving or video writing.
55
  ########################################################################################
56
 
57
+
58
  def clean_label(label):
59
  """Replace underscores and slashes with spaces for uniformity."""
60
  return label.replace("_", " ").replace("/", " ")
61
 
62
+
63
  # Should be performed somewhere else I believe
64
  def format_cate_preds(cate_preds):
65
  # Group object predictions from the model output.
 
74
  obj_pred_dict[oid].sort(key=lambda x: x[1], reverse=True)
75
  return obj_pred_dict
76
 
77
+
78
  def format_binary_cate_preds(binary_preds):
79
  frame_binary_preds = []
80
  for key, score in binary_preds.items():
 
88
  frame_binary_preds.sort(key=lambda x: x[3], reverse=True)
89
  return frame_binary_preds
90
 
91
+
92
  _FONT = cv2.FONT_HERSHEY_SIMPLEX
93
 
94
 
 
110
  return mask_np > 0
111
 
112
 
113
+ def _sanitize_bbox(
114
+ bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int
115
+ ) -> Optional[Tuple[int, int, int, int]]:
116
  if bbox is None:
117
  return None
118
  if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
 
170
  cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
171
  text_x = left_x + 4
172
  text_y = min(bottom_y - baseline - 2, img_h - 1)
173
+ cv2.putText(
174
+ image,
175
+ text,
176
+ (text_x, text_y),
177
+ _FONT,
178
+ font_scale,
179
+ (0, 0, 0),
180
+ thickness,
181
+ cv2.LINE_AA,
182
+ )
183
  y_cursor = bottom_y
184
  else:
185
  for text in lines:
 
192
  cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
193
  text_x = left_x + 4
194
  text_y = min(bottom_y - baseline - 2, img_h - 1)
195
+ cv2.putText(
196
+ image,
197
+ text,
198
+ (text_x, text_y),
199
+ _FONT,
200
+ font_scale,
201
+ (0, 0, 0),
202
+ thickness,
203
+ cv2.LINE_AA,
204
+ )
205
  y_cursor = top_y
206
 
207
 
 
222
  top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
223
  right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
224
  bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
225
+ cv2.rectangle(
226
+ image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1
227
+ )
228
  text_x = left_x + 4
229
  text_y = min(bottom_y - baseline - 2, img_h - 1)
230
+ cv2.putText(
231
+ image,
232
+ text,
233
+ (text_x, text_y),
234
+ _FONT,
235
+ font_scale,
236
+ (0, 0, 0),
237
+ thickness,
238
+ cv2.LINE_AA,
239
+ )
240
 
241
 
242
+ def _extract_frame_entities(
243
+ store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int
244
+ ) -> Dict[int, Any]:
245
  if isinstance(store, dict):
246
  frame_entry = store.get(frame_idx, {})
247
  elif isinstance(store, list) and 0 <= frame_idx < len(store):
 
308
  continue
309
  color = _object_color_bgr(obj_id)
310
  alpha = 0.45
311
+ overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(
312
+ color, dtype=np.float32
313
+ )
314
 
315
  annotated = np.clip(overlay, 0, 255).astype(np.uint8)
316
  frame_h, frame_w = annotated.shape[:2]
 
368
  cat_label_lookup: Dict[int, Tuple[str, float]],
369
  unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
370
  binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
371
+ masks: Union[
372
+ Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None
373
+ ] = None,
374
  binary_confidence_threshold: float = 0.0,
375
  ) -> Dict[str, List[np.ndarray]]:
376
  frame_groups: Dict[str, List[np.ndarray]] = {
 
388
  base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
389
  frame_h, frame_w = base_bgr.shape[:2]
390
  frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
391
+ frame_masks = (
392
+ _extract_frame_entities(masks, frame_idx) if masks is not None else {}
393
+ )
394
 
395
  objects_bgr = base_bgr.copy()
396
  unary_bgr = base_bgr.copy()
 
436
  for obj_id, bbox in bbox_lookup.items():
437
  title = titles_lookup.get(obj_id)
438
  unary_lines = unary_lines_lookup.get(obj_id, [])
439
+ _draw_bbox_with_label(
440
+ objects_bgr, bbox, obj_id, title=title, label_position="top"
441
+ )
442
+ _draw_bbox_with_label(
443
+ unary_bgr, bbox, obj_id, title=title, label_position="top"
444
+ )
445
  if unary_lines:
446
  anchor, direction = _label_anchor_and_direction(bbox, "bottom")
447
+ _draw_label_block(
448
+ unary_bgr,
449
+ unary_lines,
450
+ anchor,
451
+ _object_color_bgr(obj_id),
452
+ direction=direction,
453
+ )
454
+ _draw_bbox_with_label(
455
+ binary_bgr, bbox, obj_id, title=title, label_position="top"
456
+ )
457
+ _draw_bbox_with_label(
458
+ all_bgr, bbox, obj_id, title=title, label_position="top"
459
+ )
460
  if unary_lines:
461
  anchor, direction = _label_anchor_and_direction(bbox, "bottom")
462
+ _draw_label_block(
463
+ all_bgr,
464
+ unary_lines,
465
+ anchor,
466
+ _object_color_bgr(obj_id),
467
+ direction=direction,
468
+ )
469
 
470
  # First pass: collect all pairs above threshold and deduplicate bidirectional pairs
471
  pairs_to_draw = {} # (min_id, max_id) -> (subj_id, obj_id, prob, relation)
 
495
  subj_bbox = bbox_lookup.get(subj_id)
496
  obj_bbox = bbox_lookup.get(obj_id)
497
  start, end = relation_line(subj_bbox, obj_bbox)
498
+ color = tuple(
499
+ int(c)
500
+ for c in np.clip(
501
+ (
502
+ np.array(_object_color_bgr(subj_id), dtype=np.float32)
503
+ + np.array(_object_color_bgr(obj_id), dtype=np.float32)
504
+ )
505
+ / 2.0,
506
+ 0,
507
+ 255,
508
+ )
509
+ )
510
  label_text = f"{relation} {prob:.2f}"
511
  mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
512
  # Draw arrowed lines showing direction from subject to object (smaller arrow tip)
513
+ cv2.arrowedLine(
514
+ binary_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05
515
+ )
516
  cv2.arrowedLine(all_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05)
517
  _draw_centered_label(binary_bgr, label_text, mid_point, color)
518
  _draw_centered_label(all_bgr, label_text, mid_point, color)
 
531
  cat_label_lookup: Dict[int, Tuple[str, float]],
532
  unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
533
  binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
534
+ masks: Union[
535
+ Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None
536
+ ] = None,
537
  binary_confidence_threshold: float = 0.0,
538
  ) -> List[np.ndarray]:
539
  return render_vine_frame_sets(
 
545
  masks,
546
  binary_confidence_threshold,
547
  ).get("all", [])
548
+
549
+
550
  def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
551
  all_colors = []
552
  all_texts = []
553
+ for obj_id, bbox, gt_label in gt_labels:
554
  preds = obj_pred_dict.get(obj_id, [])
555
  if len(preds) == 0:
556
  top1 = "N/A"
 
560
  topk_labels = [p[0] for p in preds[:topk_object]]
561
  # Compare cleaned labels.
562
  if top1.lower() == gt_label.lower():
563
+ box_color = (0, 255, 0) # bright green for correct
564
  elif gt_label.lower() in [p.lower() for p in topk_labels]:
565
+ box_color = (0, 165, 255) # bright orange for partial match
566
  else:
567
+ box_color = (0, 0, 255) # bright red for incorrect
568
+
569
  label_text = f"ID:{obj_id}/P:{top1}/GT:{gt_label}"
570
  all_colors.append(box_color)
571
  all_texts.append(label_text)
572
  return all_colors, all_texts
573
 
574
+
575
  def plot_unary(frame_img, gt_labels, all_colors, all_texts):
576
+ for (obj_id, bbox, gt_label), box_color, label_text in zip(
577
+ gt_labels, all_colors, all_texts
578
+ ):
579
  x1, y1, x2, y2 = map(int, bbox)
580
  cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
581
+ (tw, th), baseline = cv2.getTextSize(
582
+ label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
583
+ )
584
+ cv2.rectangle(
585
+ frame_img, (x1, y1 - th - baseline - 4), (x1 + tw, y1), box_color, -1
586
+ )
587
+ cv2.putText(
588
+ frame_img,
589
+ label_text,
590
+ (x1, y1 - 2),
591
+ cv2.FONT_HERSHEY_SIMPLEX,
592
+ 0.5,
593
+ (0, 0, 0),
594
+ 1,
595
+ cv2.LINE_AA,
596
+ )
597
+
598
  return frame_img
599
 
600
+
601
+ def get_white_pane(
602
+ pane_height,
603
+ pane_width=600,
604
+ header_height=50,
605
+ header_font=cv2.FONT_HERSHEY_SIMPLEX,
606
+ header_font_scale=0.7,
607
+ header_thickness=2,
608
+ header_color=(0, 0, 0),
609
+ ):
610
+ # Create an expanded white pane to display text info.
611
  white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)
612
+
613
  # --- Adjust pane split: make predictions column wider (60% vs. 40%) ---
614
  left_width = int(pane_width * 0.6)
615
  right_width = pane_width - left_width
616
  left_pane = white_pane[:, :left_width, :].copy()
617
  right_pane = white_pane[:, left_width:, :].copy()
618
+
619
+ cv2.putText(
620
+ left_pane,
621
+ "Binary Predictions",
622
+ (10, header_height - 30),
623
+ header_font,
624
+ header_font_scale,
625
+ header_color,
626
+ header_thickness,
627
+ cv2.LINE_AA,
628
+ )
629
+ cv2.putText(
630
+ right_pane,
631
+ "Ground Truth",
632
+ (10, header_height - 30),
633
+ header_font,
634
+ header_font_scale,
635
+ header_color,
636
+ header_thickness,
637
+ cv2.LINE_AA,
638
+ )
639
+
640
  return white_pane
641
 
642
+
643
  # This is for plotting binary prediction results with frame-based scene graphs
644
+ def plot_binary_sg(
645
+ frame_img,
646
+ white_pane,
647
+ bin_preds,
648
+ gt_relations,
649
+ topk_binary,
650
+ header_height=50,
651
+ indicator_size=20,
652
+ pane_width=600,
653
+ ):
654
+ # Leave vertical space for the headers.
655
  line_height = 30 # vertical spacing per line
656
+ x_text = 10 # left margin for text
657
  y_text_left = header_height + 10 # starting y for left pane text
658
+ y_text_right = header_height + 10 # starting y for right pane text
659
+
660
  # Left section: top-k binary predictions.
661
  left_width = int(pane_width * 0.6)
662
  right_width = pane_width - left_width
663
  left_pane = white_pane[:, :left_width, :].copy()
664
  right_pane = white_pane[:, left_width:, :].copy()
665
+
666
+ for subj, pred_rel, obj, score in bin_preds[:topk_binary]:
667
+ correct = any(
668
+ (subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
669
+ for gt in gt_relations
670
+ )
671
  indicator_color = (0, 255, 0) if correct else (0, 0, 255)
672
+ cv2.rectangle(
673
+ left_pane,
674
+ (x_text, y_text_left - indicator_size + 5),
675
+ (x_text + indicator_size, y_text_left + 5),
676
+ indicator_color,
677
+ -1,
678
+ )
679
  text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
680
+ cv2.putText(
681
+ left_pane,
682
+ text,
683
+ (x_text + indicator_size + 5, y_text_left + 5),
684
+ cv2.FONT_HERSHEY_SIMPLEX,
685
+ 0.6,
686
+ (0, 0, 0),
687
+ 1,
688
+ cv2.LINE_AA,
689
+ )
690
  y_text_left += line_height
691
+
692
  # Right section: ground truth binary relations.
693
  for gt in gt_relations:
694
  if len(gt) != 3:
695
  continue
696
  text = f"{gt[0]} - {gt[2]} - {gt[1]}"
697
+ cv2.putText(
698
+ right_pane,
699
+ text,
700
+ (x_text, y_text_right + 5),
701
+ cv2.FONT_HERSHEY_SIMPLEX,
702
+ 0.6,
703
+ (0, 0, 0),
704
+ 1,
705
+ cv2.LINE_AA,
706
+ )
707
  y_text_right += line_height
708
+
709
  # Combine the two text panes and then with the frame image.
710
  combined_pane = np.hstack((left_pane, right_pane))
711
  combined_image = np.hstack((frame_img, combined_pane))
712
  return combined_image
713
 
714
+
715
+ def visualized_frame(
716
+ frame_img,
717
+ bboxes,
718
+ object_ids,
719
+ gt_labels,
720
+ cate_preds,
721
+ binary_preds,
722
+ gt_relations,
723
+ topk_object,
724
+ topk_binary,
725
+ phase="unary",
726
+ ):
727
  """Return the combined annotated frame for frame index i as an image (in BGR)."""
728
  # Get the frame image (assuming batched_data['batched_reshaped_raw_videos'] is a list of frames)
729
 
730
  # --- Process Object Predictions (for overlaying bboxes) ---
731
  if phase == "unary":
732
  objs = []
733
+ for (_, f_id, obj_id), bbox, gt_label in zip(object_ids, bboxes, gt_labels):
734
  gt_label = clean_label(gt_label)
735
  objs.append((obj_id, bbox, gt_label))
736
+
737
  formatted_cate_preds = format_cate_preds(cate_preds)
738
+ all_colors, all_texts = color_for_cate_correctness(
739
+ formatted_cate_preds, gt_labels, topk_object
740
+ )
741
  updated_frame_img = plot_unary(frame_img, gt_labels, all_colors, all_texts)
742
  return updated_frame_img
743
+
744
  else:
745
  # --- Process Binary Predictions & Ground Truth for the Text Pane ---
746
  formatted_binary_preds = format_binary_cate_preds(binary_preds)
747
+
748
  # Ground truth binary relations for the frame.
749
  # Clean ground truth relations.
750
+ gt_relations = [
751
+ (clean_label(str(s)), clean_label(str(o)), clean_label(rel))
752
+ for s, o, rel in gt_relations
753
+ ]
754
+
755
  pane_width = 600 # increased pane width for more horizontal space
756
  pane_height = frame_img.shape[0]
757
+
758
  # --- Add header labels to each text pane with extra space ---
759
  header_height = 50 # increased header space
760
+ white_pane = get_white_pane(
761
+ pane_height, pane_width, header_height=header_height
762
+ )
763
+
764
+ combined_image = plot_binary_sg(
765
+ frame_img, white_pane, formatted_binary_preds, gt_relations, topk_binary
766
+ )
767
+
768
  return combined_image
769
 
770
+
771
  def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
772
  # Ensure mask is a numpy array
773
  mask = np.array(mask)
 
790
  color = list(cmap((cmap_idx * 47) % 256))
791
  color[3] = 0.5
792
  color = np.array(color)
793
+
794
  # Expand mask to (H, W, 1) for broadcasting
795
  mask_expanded = mask[..., None]
796
  mask_image = mask_expanded * color.reshape(1, 1, -1)
 
809
  linewidth=1.5,
810
  edgecolor=color[:3],
811
  facecolor="none",
812
+ alpha=color[3],
813
  )
814
  ax.add_patch(rect)
815
  ax.text(
 
819
  color="white",
820
  fontsize=6,
821
  backgroundcolor=np.array(color),
822
+ alpha=1,
823
  )
824
  ax.imshow(mask_image)
825
 
826
+
827
  def save_mask_one_image(frame_image, masks, save_path):
828
  """Render masks on top of a frame and store the visualization on disk."""
829
  fig, ax = plt.subplots(1, figsize=(6, 6))
 
842
 
843
  prepared_masks = {
844
  obj_id: (
845
+ mask.detach().cpu().numpy() if torch.is_tensor(mask) else np.asarray(mask)
846
  )
847
  for obj_id, mask in mask_iter
848
  }
 
856
  fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
857
  plt.close(fig)
858
  return save_path
859
+
860
+
861
+ def get_video_masks_visualization(
862
+ video_tensor,
863
+ video_masks,
864
+ video_id,
865
+ video_save_base_dir,
866
+ oid_class_pred=None,
867
+ sample_rate=1,
868
+ ):
869
  video_save_dir = os.path.join(video_save_base_dir, video_id)
870
  if not os.path.exists(video_save_dir):
871
  os.makedirs(video_save_dir, exist_ok=True)
872
+
873
  for frame_id, image in enumerate(video_tensor):
874
  if frame_id not in video_masks:
875
  print("No mask for Frame", frame_id)
876
  continue
877
+
878
  masks = video_masks[frame_id]
879
  save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
880
  get_mask_one_image(image, masks, oid_class_pred)
881
 
882
+
883
  def get_mask_one_image(frame_image, masks, oid_class_pred=None):
884
  # Create a figure and axis
885
  fig, ax = plt.subplots(1, figsize=(6, 6))
886
 
887
  # Display the frame image
888
  ax.imshow(frame_image)
889
+ ax.axis("off")
890
 
891
  if type(masks) == list:
892
  masks = {i: m for i, m in enumerate(masks)}
893
+
894
  # Add the masks
895
  for obj_id, mask in masks.items():
896
+ det_class = (
897
+ f"{obj_id}. {oid_class_pred[obj_id]}"
898
+ if not oid_class_pred is None
899
+ else None
900
+ )
901
  show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)
902
 
903
  # Show the plot
904
  return fig, ax
905
 
906
+
907
  def save_video(frames, output_filename, output_fps):
 
908
  # --- Create a video from all frames ---
909
  num_frames = len(frames)
910
  frame_h, frame_w = frames.shape[:2]
911
 
912
  # Use a codec supported by VS Code (H.264 via 'avc1').
913
+ fourcc = cv2.VideoWriter_fourcc(*"avc1")
914
  out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))
915
 
916
  print(f"Processing {num_frames} frames...")
 
918
  vis_frame = get_visualized_frame(i)
919
  out.write(vis_frame)
920
  if i % 10 == 0:
921
+ print(f"Processed frame {i + 1}/{num_frames}")
922
 
923
  out.release()
924
  print(f"Video saved as {output_filename}")
925
+
926
 
927
  def list_depth(lst):
928
  """Calculates the depth of a nested list."""
929
  if not (isinstance(lst, list) or isinstance(lst, torch.Tensor)):
930
  return 0
931
+ elif (isinstance(lst, torch.Tensor) and lst.shape == torch.Size([])) or (
932
+ isinstance(lst, list) and len(lst) == 0
933
+ ):
934
  return 1
935
  else:
936
  return 1 + max(list_depth(item) for item in lst)
937
+
938
+
939
  def normalize_prompt(points, labels):
940
+ if list_depth(points) == 3:
941
  points = torch.stack([p.unsqueeze(0) for p in points])
942
  labels = torch.stack([l.unsqueeze(0) for l in labels])
943
  return points, labels
 
946
  def show_box(box, ax, object_id):
947
  if len(box) == 0:
948
  return
949
+
950
  cmap = plt.get_cmap("gist_rainbow")
951
  cmap_idx = 0 if object_id is None else object_id
952
  color = list(cmap((cmap_idx * 47) % 256))
953
+
954
  x0, y0 = box[0], box[1]
955
  w, h = box[2] - box[0], box[3] - box[1]
956
+ ax.add_patch(
957
+ plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0, 0, 0, 0), lw=2)
958
+ )
959
+
960
+
961
  def show_points(coords, labels, ax, object_id=None, marker_size=375):
962
  if len(labels) == 0:
963
  return
964
+
965
+ pos_points = coords[labels == 1]
966
+ neg_points = coords[labels == 0]
967
+
968
  cmap = plt.get_cmap("gist_rainbow")
969
  cmap_idx = 0 if object_id is None else object_id
970
  color = list(cmap((cmap_idx * 47) % 256))
971
+
972
+ ax.scatter(
973
+ pos_points[:, 0],
974
+ pos_points[:, 1],
975
+ color="green",
976
+ marker="P",
977
+ s=marker_size,
978
+ edgecolor=color,
979
+ linewidth=1.25,
980
+ )
981
+ ax.scatter(
982
+ neg_points[:, 0],
983
+ neg_points[:, 1],
984
+ color="red",
985
+ marker="s",
986
+ s=marker_size,
987
+ edgecolor=color,
988
+ linewidth=1.25,
989
+ )
990
+
991
+
992
  def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
993
  # Create a figure and axis
994
  fig, ax = plt.subplots(1, figsize=(6, 6))
995
 
996
  # Display the frame image
997
  ax.imshow(frame_image)
998
+ ax.axis("off")
999
 
1000
  points, labels = normalize_prompt(points, labels)
1001
  if type(boxes) == torch.Tensor:
 
1012
  pass
1013
  else:
1014
  raise Exception()
1015
+
1016
  for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
1017
  if not len(point_ls) == 0:
1018
  show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)
1019
+
1020
  # Show the plot
1021
  plt.savefig(save_path)
1022
  plt.close()
1023
+
1024
+
1025
+ def save_video_prompts_visualization(
1026
+ video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir
1027
+ ):
1028
  video_save_dir = os.path.join(video_save_base_dir, video_id)
1029
  if not os.path.exists(video_save_dir):
1030
  os.makedirs(video_save_dir, exist_ok=True)
1031
+
1032
  for frame_id, image in enumerate(video_tensor):
1033
  boxes, points, labels = [], [], []
1034
+
1035
  if frame_id in video_boxes:
1036
  boxes = video_boxes[frame_id]
1037
+
1038
  if frame_id in video_points:
1039
  points = video_points[frame_id]
1040
  if frame_id in video_labels:
1041
  labels = video_labels[frame_id]
1042
+
1043
  save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
1044
  save_prompts_one_image(image, boxes, points, labels, save_path)
 
1045
 
1046
+
1047
+ def save_video_masks_visualization(
1048
+ video_tensor,
1049
+ video_masks,
1050
+ video_id,
1051
+ video_save_base_dir,
1052
+ oid_class_pred=None,
1053
+ sample_rate=1,
1054
+ ):
1055
  video_save_dir = os.path.join(video_save_base_dir, video_id)
1056
  if not os.path.exists(video_save_dir):
1057
  os.makedirs(video_save_dir, exist_ok=True)
1058
+
1059
  for frame_id, image in enumerate(video_tensor):
1060
  if random.random() > sample_rate:
1061
  continue
 
1065
  masks = video_masks[frame_id]
1066
  save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
1067
  save_mask_one_image(image, masks, save_path)
 
1068
 
1069
 
1070
+ def get_color(obj_id, cmap_name="gist_rainbow", alpha=0.5):
1071
  cmap = plt.get_cmap(cmap_name)
1072
  cmap_idx = 0 if obj_id is None else obj_id
1073
  color = list(cmap((cmap_idx * 47) % 256))
1074
  color[3] = 0.5
1075
  color = np.array(color)
1076
  return color
1077
+
1078
+
1079
  def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
1080
  return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0)
1081
 
 
1090
  """
1091
  center1 = _bbox_center(bbox1)
1092
  center2 = _bbox_center(bbox2)
1093
+ if math.isclose(center1[0], center2[0], abs_tol=1e-3) and math.isclose(
1094
+ center1[1], center2[1], abs_tol=1e-3
1095
+ ):
1096
  offset = max(1.0, (bbox2[2] - bbox2[0]) * 0.05)
1097
  center2 = (center2[0] + offset, center2[1])
1098
  start = (int(round(center1[0])), int(round(center1[1])))
 
1101
  end = (end[0] + 1, end[1])
1102
  return start, end
1103
 
1104
+
1105
  def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
1106
  # Create a figure and axis
1107
  fig, ax = plt.subplots(1, figsize=(6, 6))
1108
 
1109
  # Display the frame image
1110
  ax.imshow(frame_image)
1111
+ ax.axis("off")
1112
+
1113
  all_objs_to_show = set()
1114
  all_lines_to_show = []
1115
+
1116
  # print(rel_pred_ls[0])
1117
  for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
1118
+ all_objs_to_show.add(from_obj_id)
1119
+ all_objs_to_show.add(to_obj_id)
1120
+
1121
  from_mask = masks[from_obj_id]
1122
  bbox1 = mask_to_bbox(from_mask)
1123
  to_mask = masks[to_obj_id]
1124
  bbox2 = mask_to_bbox(to_mask)
1125
+
1126
  c1, c2 = shortest_line_between_bboxes(bbox1, bbox2)
1127
+
1128
  line_color = get_color(from_obj_id)
1129
  face_color = get_color(to_obj_id)
1130
  line = c1, c2, face_color, line_color, rel_text
1131
  all_lines_to_show.append(line)
1132
+
1133
  masks_to_show = {}
1134
  for oid in all_objs_to_show:
1135
  masks_to_show[oid] = masks[oid]
1136
+
1137
  # Add the masks
1138
  for obj_id, mask in masks_to_show.items():
1139
  show_mask(mask, ax, obj_id=obj_id, random_color=False)
1140
 
1141
+ for (from_pt_x, from_pt_y), (
1142
+ to_pt_x,
1143
+ to_pt_y,
1144
+ ), face_color, line_color, rel_text in all_lines_to_show:
1145
+ plt.plot(
1146
+ [from_pt_x, to_pt_x],
1147
+ [from_pt_y, to_pt_y],
1148
+ color=line_color,
1149
+ linestyle="-",
1150
+ linewidth=3,
1151
+ )
1152
  mid_pt_x = (from_pt_x + to_pt_x) / 2
1153
  mid_pt_y = (from_pt_y + to_pt_y) / 2
1154
  ax.text(
1155
+ mid_pt_x - 5,
1156
+ mid_pt_y,
1157
+ rel_text,
1158
+ color="white",
1159
+ fontsize=6,
1160
+ backgroundcolor=np.array(line_color),
1161
+ bbox=dict(
1162
+ facecolor=face_color, edgecolor=line_color, boxstyle="round,pad=1"
1163
+ ),
1164
+ alpha=1,
1165
+ )
1166
+
1167
  # Show the plot
1168
  return fig, ax