moqingyan123 committed on
Commit 17f21ca · 1 Parent(s): f71f431
app.py CHANGED
@@ -60,6 +60,208 @@ print(
 )
 
 
+def format_summary(summary, binary_confidence_threshold=0.8):
+    """
+    Format the summary dictionary into a readable markdown string.
+    Filters binary relations by confidence threshold.
+    """
+    if not summary or not isinstance(summary, dict):
+        return "# Detection Summary\n\nNo events detected or processing in progress..."
+
+    output_lines = ["# Detection Summary\n"]
+    has_content = False
+
+    # Categorical keywords
+    if "categorical_keywords" in summary and summary["categorical_keywords"]:
+        output_lines.append("## Categorical Keywords\n")
+        cate = summary["categorical_keywords"]
+        if isinstance(cate, dict) and cate:
+            has_content = True
+            for kw, info in cate.items():
+                output_lines.append(f"**{kw}**")
+                if isinstance(info, dict):
+                    for key, val in info.items():
+                        output_lines.append(f"  - {key}: {val}")
+                else:
+                    output_lines.append(f"  - {info}")
+                output_lines.append("")
+        elif isinstance(cate, list) and cate:
+            has_content = True
+            for item in cate:
+                output_lines.append(f"- {item}")
+            output_lines.append("")
+
+    # Unary keywords
+    if "unary_keywords" in summary and summary["unary_keywords"]:
+        output_lines.append("## Unary Keywords\n")
+        unary = summary["unary_keywords"]
+        if isinstance(unary, dict) and unary:
+            has_content = True
+            for kw, info in unary.items():
+                output_lines.append(f"**{kw}**")
+                if isinstance(info, dict):
+                    for key, val in info.items():
+                        output_lines.append(f"  - {key}: {val}")
+                else:
+                    output_lines.append(f"  - {info}")
+                output_lines.append("")
+        elif isinstance(unary, list) and unary:
+            has_content = True
+            for item in unary:
+                output_lines.append(f"- {item}")
+            output_lines.append("")
+
+    # Binary keywords - show ALL binary relations for debugging
+    print("DEBUG: Checking binary_keywords...")
+    print(f"  'binary_keywords' in summary: {'binary_keywords' in summary}")
+    if 'binary_keywords' in summary:
+        print(f"  summary['binary_keywords'] truthy: {bool(summary['binary_keywords'])}")
+        print(f"  summary['binary_keywords'] type: {type(summary['binary_keywords'])}")
+        print(f"  summary['binary_keywords'] value: {summary['binary_keywords']}")
+
+    if "binary_keywords" in summary and summary["binary_keywords"]:
+        output_lines.append("## Binary Keywords\n")
+        binary = summary["binary_keywords"]
+        print(f"DEBUG: Processing binary keywords, type: {type(binary)}, length: {len(binary) if isinstance(binary, (dict, list)) else 'N/A'}")
+        if isinstance(binary, dict) and binary:
+            has_content = True
+            # Show all binary relations, sorted by confidence
+            binary_items = []
+            for kw, info in binary.items():
+                if isinstance(info, dict):
+                    confidence = info.get("confidence", info.get("score", 0))
+                    binary_items.append((kw, info, confidence))
+                else:
+                    binary_items.append((kw, info, 0))
+
+            # Sort by confidence descending
+            binary_items.sort(key=lambda x: x[2], reverse=True)
+
+            high_conf_count = 0
+            low_conf_count = 0
+
+            # Show high confidence items first
+            output_lines.append(f"### High Confidence (≥ {binary_confidence_threshold})\n")
+            for kw, info, confidence in binary_items:
+                if confidence >= binary_confidence_threshold:
+                    high_conf_count += 1
+                    if isinstance(info, dict):
+                        output_lines.append(f"**{kw}** (confidence: {confidence:.2f})")
+                        for key, val in info.items():
+                            if key not in ["confidence", "score"]:
+                                output_lines.append(f"  - {key}: {val}")
+                    else:
+                        output_lines.append(f"**{kw}**: {info}")
+                    output_lines.append("")
+
+            if high_conf_count == 0:
+                output_lines.append(f"*No binary relations found with confidence ≥ {binary_confidence_threshold}*\n")
+
+            # Show lower confidence items for debugging
+            output_lines.append(f"### Lower Confidence (< {binary_confidence_threshold})\n")
+            for kw, info, confidence in binary_items:
+                if confidence < binary_confidence_threshold:
+                    low_conf_count += 1
+                    if isinstance(info, dict):
+                        output_lines.append(f"**{kw}** (confidence: {confidence:.2f})")
+                        for key, val in info.items():
+                            if key not in ["confidence", "score"]:
+                                output_lines.append(f"  - {key}: {val}")
+                    else:
+                        output_lines.append(f"**{kw}**: {info}")
+                    output_lines.append("")
+
+            if low_conf_count == 0:
+                output_lines.append(f"*No binary relations found with confidence < {binary_confidence_threshold}*\n")
+
+            output_lines.append(f"**Total binary relations detected: {len(binary_items)}**\n")
+        elif isinstance(binary, list) and binary:
+            has_content = True
+            for item in binary:
+                output_lines.append(f"- {item}")
+            output_lines.append("")
+
+    # Object pairs - show ALL object pair interactions for debugging
+    print("DEBUG: Checking object_pairs...")
+    print(f"  'object_pairs' in summary: {'object_pairs' in summary}")
+    if 'object_pairs' in summary:
+        print(f"  summary['object_pairs'] truthy: {bool(summary['object_pairs'])}")
+        print(f"  summary['object_pairs'] type: {type(summary['object_pairs'])}")
+        print(f"  summary['object_pairs'] value: {summary['object_pairs']}")
+
+    if "object_pairs" in summary and summary["object_pairs"]:
+        output_lines.append("## Object Pair Interactions\n")
+        pairs = summary["object_pairs"]
+        print(f"DEBUG: Processing object pairs, type: {type(pairs)}, length: {len(pairs) if isinstance(pairs, (dict, list)) else 'N/A'}")
+        if isinstance(pairs, dict) and pairs:
+            has_content = True
+            # Show all object pairs, sorted by confidence
+            pair_items = []
+            for pair, info in pairs.items():
+                if isinstance(info, dict):
+                    confidence = info.get("confidence", info.get("score", 0))
+                    pair_items.append((pair, info, confidence))
+                else:
+                    pair_items.append((pair, info, 0))
+
+            # Sort by confidence descending
+            pair_items.sort(key=lambda x: x[2], reverse=True)
+
+            high_conf_count = 0
+            low_conf_count = 0
+
+            # Show high confidence items first
+            output_lines.append(f"### High Confidence (≥ {binary_confidence_threshold})\n")
+            for pair, info, confidence in pair_items:
+                if confidence >= binary_confidence_threshold:
+                    high_conf_count += 1
+                    if isinstance(info, dict):
+                        output_lines.append(f"**{pair}** (confidence: {confidence:.2f})")
+                        for key, val in info.items():
+                            if key not in ["confidence", "score"]:
+                                output_lines.append(f"  - {key}: {val}")
+                    else:
+                        output_lines.append(f"**{pair}**: {info}")
+                    output_lines.append("")
+
+            if high_conf_count == 0:
+                output_lines.append(f"*No object pairs found with confidence ≥ {binary_confidence_threshold}*\n")
+
+            # Show lower confidence items for debugging
+            output_lines.append(f"### Lower Confidence (< {binary_confidence_threshold})\n")
+            for pair, info, confidence in pair_items:
+                if confidence < binary_confidence_threshold:
+                    low_conf_count += 1
+                    if isinstance(info, dict):
+                        output_lines.append(f"**{pair}** (confidence: {confidence:.2f})")
+                        for key, val in info.items():
+                            if key not in ["confidence", "score"]:
+                                output_lines.append(f"  - {key}: {val}")
+                    else:
+                        output_lines.append(f"**{pair}**: {info}")
+                    output_lines.append("")
+
+            if low_conf_count == 0:
+                output_lines.append(f"*No object pairs found with confidence < {binary_confidence_threshold}*\n")
+
+            output_lines.append(f"**Total object pairs detected: {len(pair_items)}**\n")
+        elif isinstance(pairs, list) and pairs:
+            has_content = True
+            for item in pairs:
+                output_lines.append(f"- {item}")
+            output_lines.append("")
+
+    # If no content was added, show the raw summary for debugging
+    if not has_content:
+        output_lines.append("## Raw Summary Data\n")
+        output_lines.append("```json")
+        import json
+        output_lines.append(json.dumps(summary, indent=2, default=str))
+        output_lines.append("```")
+
+    return "\n".join(output_lines)
+
+
 @lru_cache(maxsize=1)
 def _load_vine_pipeline():
     """
@@ -96,16 +298,16 @@ def _load_vine_pipeline():
     )
 
 
-@spaces.GPU(duration=300)  # Up to ~5 minutes of H200 ZeroGPU time per call
+@spaces.GPU(duration=120)  # Up to ~2 minutes of H200 ZeroGPU time per call
 def process_video(
     video_file,
     categorical_keywords,
     unary_keywords,
     binary_keywords,
-    object_pairs,
     output_fps,
     box_threshold,
     text_threshold,
+    binary_confidence_threshold,
 ):
     vine_pipe = _load_vine_pipeline()
 
@@ -130,11 +332,17 @@ def process_video(
     binary_keywords = (
         [kw.strip() for kw in binary_keywords.split(",")] if binary_keywords else []
     )
-    object_pairs = (
-        [tuple(map(int, pair.split("-"))) for pair in object_pairs.split(",")]
-        if object_pairs
-        else []
-    )
+
+    # Debug: Print what we're sending to the pipeline
+    print("\n" + "=" * 80)
+    print("INPUT TO VINE PIPELINE:")
+    print(f"  categorical_keywords: {categorical_keywords}")
+    print(f"  unary_keywords: {unary_keywords}")
+    print(f"  binary_keywords: {binary_keywords}")
+    print("=" * 80 + "\n")
+
+    # Object pairs are now optional - an empty list will auto-generate all pairs in vine_model.py
+    object_pairs = []
 
     results = vine_pipe(
         inputs=video_file,
@@ -150,8 +358,17 @@ def process_video(
         box_threshold=box_threshold,
         text_threshold=text_threshold,
         target_fps=output_fps,
+        binary_confidence_threshold=binary_confidence_threshold,
     )
 
+    # Debug: Print what the pipeline returned
+    print("\n" + "=" * 80)
+    print("PIPELINE RESULTS DEBUG:")
+    print(f"  results type: {type(results)}")
+    if isinstance(results, dict):
+        print(f"  results keys: {list(results.keys())}")
+    print("=" * 80 + "\n")
+
     vine_pipe.box_threshold = box_threshold
     vine_pipe.text_threshold = text_threshold
     vine_pipe.target_fps = output_fps
@@ -194,7 +411,47 @@ def process_video(
             "Warning: annotated video not found or empty; check visualization settings."
         )
 
-    return video_path_for_ui, summary
+    # Debug: Print summary structure
+    import json
+    print("=" * 80)
+    print("SUMMARY DEBUG OUTPUT:")
+    print(f"Summary type: {type(summary)}")
+    print(f"Summary keys: {summary.keys() if isinstance(summary, dict) else 'N/A'}")
+    if isinstance(summary, dict):
+        print("\nFULL SUMMARY JSON:")
+        print(json.dumps(summary, indent=2, default=str))
+    print("\n" + "=" * 80)
+
+    # Check for any keys that might contain binary relation data
+    print("\nLOOKING FOR BINARY RELATION DATA:")
+    possible_keys = ['binary', 'binary_keywords', 'binary_relations', 'object_pairs',
+                     'pairs', 'relations', 'interactions', 'pairwise']
+    for pkey in possible_keys:
+        if pkey in summary:
+            print(f"  FOUND: '{pkey}' -> {summary[pkey]}")
+
+    print("\nALL KEYS IN SUMMARY:")
+    for key in summary.keys():
+        print(f"\n{key}:")
+        print(f"  Type: {type(summary[key])}")
+        if isinstance(summary[key], dict):
+            print(f"  Length: {len(summary[key])}")
+            print(f"  Keys (first 10): {list(summary[key].keys())[:10]}")
+            # Print all items for anything that might be binary relations
+            if any(term in key.lower() for term in ['binary', 'pair', 'relation', 'interaction']):
+                print("  ALL ITEMS:")
+                for k, v in list(summary[key].items())[:20]:  # First 20 items
+                    print(f"    {k}: {v}")
+            else:
+                print(f"  Sample: {dict(list(summary[key].items())[:2])}")
+        elif isinstance(summary[key], list):
+            print(f"  Length: {len(summary[key])}")
+            print(f"  Sample: {summary[key][:2]}")
+    print("=" * 80)
+
+    # Format summary as readable markdown text, filtering by confidence threshold
+    formatted_summary = format_summary(summary, binary_confidence_threshold)
+    return video_path_for_ui, formatted_summary
 
 
 def _video_component(label: str, *, is_output: bool = False):
@@ -214,6 +471,9 @@ def _video_component(label: str, *, is_output: bool = False):
         kwargs["type"] = "filepath"
     if "sources" in sig.parameters:
         kwargs["sources"] = ["upload"]
+    # Restrict to MP4 files only
+    if "file_types" in sig.parameters:
+        kwargs["file_types"] = [".mp4"]
 
     if is_output and "autoplay" in sig.parameters:
         kwargs["autoplay"] = True
@@ -240,40 +500,103 @@ def _create_blocks():
     return gr.Blocks(**blocks_kwargs)
 
 
-# Create Gradio interface
+# Create Gradio interface with two-column layout
 with _create_blocks() as demo:
-    video_input = _video_component("Upload Video", is_output=False)
-    categorical_input = gr.Textbox(
-        label="Categorical Keywords (comma-separated)",
-        value="person, car, tree, background",
-    )
-    unary_input = gr.Textbox(
-        label="Unary Keywords (comma-separated)", value="walking, running, standing"
-    )
-    binary_input = gr.Textbox(
-        label="Binary Keywords (comma-separated)",
-        placeholder="e.g., chasing, carrying",
-    )
-    pairs_input = gr.Textbox(
-        label="Object Pairs (comma-separated indices)",
-        placeholder="e.g., 0-1,0-2 for pairs of objects",
-    )
-    fps_input = gr.Number(
-        label="Output FPS (affects processing speed)", value=1  # default 1 FPS
-    )
-
-    with gr.Accordion("Advanced Settings", open=False):
-        box_threshold_input = gr.Slider(
-            label="Box Threshold", minimum=0.1, maximum=0.9, value=0.35, step=0.05
-        )
-        text_threshold_input = gr.Slider(
-            label="Text Threshold", minimum=0.1, maximum=0.9, value=0.25, step=0.05
-        )
-
-    submit_btn = gr.Button("Process Video", variant="primary")
-
-    video_output = _video_component("Output Video with Annotations", is_output=True)
-    json_output = gr.JSON(label="Summary of Detected Events")
+    gr.Markdown(
+        """
+        # 🎬 VINE: Video-based Interaction and Event Detection
+
+        Upload an MP4 video and specify keywords to detect objects, actions, and interactions in your video.
+        """
+    )
+
+    with gr.Row():
+        # Left column: Inputs
+        with gr.Column(scale=1):
+            gr.Markdown("### Input Configuration")
+
+            video_input = _video_component("Upload Video (MP4 only)", is_output=False)
+            gr.Markdown("*Note: Only MP4 format is currently supported*")
+
+            gr.Markdown("#### Detection Keywords")
+            categorical_input = gr.Textbox(
+                label="Categorical Keywords",
+                placeholder="e.g., person, car, dog",
+                value="person, car, dog",
+                info="Objects to detect in the video (comma-separated)"
+            )
+            unary_input = gr.Textbox(
+                label="Unary Keywords",
+                placeholder="e.g., walking, running, standing",
+                value="walking, running, standing",
+                info="Single-object actions to detect (comma-separated)"
+            )
+            binary_input = gr.Textbox(
+                label="Binary Keywords",
+                placeholder="e.g., chasing, carrying",
+                info="Object-to-object interactions to detect (comma-separated)"
+            )
+
+            gr.Markdown("#### Processing Settings")
+            fps_input = gr.Number(
+                label="Output FPS",
+                value=1,
+                info="Frames per second for processing (lower = faster)"
+            )
+
+            with gr.Accordion("Advanced Settings", open=False):
+                box_threshold_input = gr.Slider(
+                    label="Box Threshold",
+                    minimum=0.1,
+                    maximum=0.9,
+                    value=0.35,
+                    step=0.05,
+                    info="Confidence threshold for object detection"
+                )
+                text_threshold_input = gr.Slider(
+                    label="Text Threshold",
+                    minimum=0.1,
+                    maximum=0.9,
+                    value=0.25,
+                    step=0.05,
+                    info="Confidence threshold for text-based detection"
+                )
+                binary_confidence_input = gr.Slider(
+                    label="Binary Relation Confidence Threshold",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.8,
+                    step=0.05,
+                    info="Minimum confidence to show binary relations and object pairs"
+                )
+
+            submit_btn = gr.Button("🚀 Process Video", variant="primary", size="lg")
+
+        # Right column: Outputs
+        with gr.Column(scale=1):
+            gr.Markdown("### Results")
+
+            video_output = _video_component("Annotated Video Output", is_output=True)
+
+            gr.Markdown("### Detection Summary")
+            summary_output = gr.Markdown(
+                value="Results will appear here after processing...",
+                elem_classes=["summary-output"]
+            )
+
+    gr.Markdown(
+        """
+        ---
+        ### How to Use
+        1. Upload an MP4 video file
+        2. Specify the objects, actions, and interactions you want to detect
+        3. Adjust processing settings if needed (including binary relation confidence threshold)
+        4. Click "Process Video" to analyze
+
+        The system will automatically detect all binary relations between detected objects
+        and show only those with confidence above the threshold (default: 0.8).
+        """
+    )
 
     submit_btn.click(
         fn=process_video,
@@ -282,12 +605,12 @@ with _create_blocks() as demo:
             categorical_input,
            unary_input,
            binary_input,
-            pairs_input,
            fps_input,
            box_threshold_input,
            text_threshold_input,
+            binary_confidence_input,
        ],
-        outputs=[video_output, json_output],
+        outputs=[video_output, summary_output],
    )
 
 if __name__ == "__main__":
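
A quick way to sanity-check the new format_summary helper outside the Space is to feed it a hand-built summary dict. This is a minimal sketch, not part of the commit: the sample values are invented, and the "binary_keywords" payload shape follows the docstring added to _generate_summary in vine_hf/vine_pipeline.py below.

    # Smoke test for format_summary; the sample summary dict is invented to
    # match the "<from_id>-<to_id>" payload documented in vine_pipeline.py.
    sample_summary = {
        "categorical_keywords": {"person": {"count": 2}},
        "unary_keywords": {"walking": {"frames": [0, 1]}},
        "binary_keywords": {
            "0-1": {"predicate": "chasing", "confidence": 0.91, "frame_id": 3},
            "1-0": {"predicate": "chasing", "confidence": 0.42, "frame_id": 3},
        },
    }

    markdown = format_summary(sample_summary, binary_confidence_threshold=0.8)
    # "0-1" lands under "High Confidence"; "1-0" under "Lower Confidence".
    print(markdown)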
outputs/debug_crops/frame_0_obj_0.jpg CHANGED
outputs/debug_crops/frame_0_obj_1.jpg CHANGED
outputs/debug_crops/frame_0_obj_2.jpg CHANGED
outputs/debug_crops/frame_0_obj_3.jpg CHANGED
outputs/debug_crops/frame_0_obj_4.jpg CHANGED
outputs/debug_crops/frame_0_obj_5.jpg CHANGED
outputs/debug_crops/frame_1_obj_0.jpg CHANGED
outputs/debug_crops/frame_1_obj_1.jpg CHANGED
outputs/debug_crops/frame_1_obj_2.jpg CHANGED
outputs/debug_crops/frame_1_obj_3.jpg CHANGED
outputs/debug_crops/frame_1_obj_5.jpg CHANGED
src/LASER/laser/models/model_utils.py CHANGED
@@ -117,7 +117,12 @@ def crop_image_contain_bboxes(img, bbox_ls, data_id):
     return img[y1:y2, x1:x2]
 
 def extract_object_subject(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8):
-    # Ensure the masks are binary (0 or 1)
+    # Ensure the masks are 2D and binary (0 or 1)
+    if red_mask.ndim == 3:
+        red_mask = red_mask[:, :, 0]
+    if blue_mask.ndim == 3:
+        blue_mask = blue_mask[:, :, 0]
+
     red_mask = red_mask.astype(bool)
     blue_mask = blue_mask.astype(bool)
     non_masked_area = ~(red_mask | blue_mask)
@@ -126,16 +131,18 @@ def extract_object_subject(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8)
     b, g, r = cv2.split(img)
 
     # Adjust the red channel based on the red mask
-    r = np.where(red_mask[:, :, 0], np.clip(r + (255 - r) * alpha, 0, 255), r).astype(np.uint8)
+    r = np.where(red_mask, np.clip(r + (255 - r) * alpha, 0, 255), r).astype(np.uint8)
 
     # Adjust the blue channel based on the blue mask
-    b = np.where(blue_mask[:, :, 0], np.clip(b + (255 - b) * alpha, 0, 255), b).astype(np.uint8)
+    b = np.where(blue_mask, np.clip(b + (255 - b) * alpha, 0, 255), b).astype(np.uint8)
 
     # Merge the channels back together
     output_img = cv2.merge((b, g, r))
 
     white_img = np.full_like(output_img, 255, dtype=np.uint8)
-    output_img = np.where(non_masked_area, cv2.addWeighted(output_img, 1 - white_alpha, white_img, white_alpha, 0), output_img)
+    # Expand non_masked_area to 3D for proper broadcasting with 3-channel images
+    non_masked_area_3d = np.expand_dims(non_masked_area, axis=-1)
+    output_img = np.where(non_masked_area_3d, cv2.addWeighted(output_img, 1 - white_alpha, white_img, white_alpha, 0), output_img)
 
     return output_img
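
The broadcasting fix above matters because NumPy aligns shapes from the trailing axis: a (H, W) boolean mask cannot broadcast against an (H, W, 3) image. A minimal sketch of the failure mode and the fix, with made-up shapes:

    import numpy as np

    img = np.zeros((4, 4, 3), dtype=np.uint8)   # 3-channel BGR image
    mask = np.ones((4, 4), dtype=bool)          # 2D mask, as after the patch

    # np.where(mask, img, img) raises: trailing axes 4 vs 3 do not match.
    # Expanding to (4, 4, 1) lets each mask pixel broadcast across channels.
    mask_3d = np.expand_dims(mask, axis=-1)
    out = np.where(mask_3d, img, img)
    print(out.shape)  # (4, 4, 3)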
 
vine_hf/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/__init__.cpython-310.pyc and b/vine_hf/__pycache__/__init__.cpython-310.pyc differ
 
vine_hf/__pycache__/flattening.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/flattening.cpython-310.pyc and b/vine_hf/__pycache__/flattening.cpython-310.pyc differ
 
vine_hf/__pycache__/vine_config.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_config.cpython-310.pyc and b/vine_hf/__pycache__/vine_config.cpython-310.pyc differ
 
vine_hf/__pycache__/vine_model.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_model.cpython-310.pyc and b/vine_hf/__pycache__/vine_model.cpython-310.pyc differ
 
vine_hf/__pycache__/vine_pipeline.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_pipeline.cpython-310.pyc and b/vine_hf/__pycache__/vine_pipeline.cpython-310.pyc differ
 
vine_hf/__pycache__/vis_utils.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vis_utils.cpython-310.pyc and b/vine_hf/__pycache__/vis_utils.cpython-310.pyc differ
 
vine_hf/vine_model.py CHANGED
@@ -388,6 +388,24 @@ class VineModel(PreTrainedModel):
         batched_binary_kws = [list(binary_keywords)]
 
         batched_obj_pairs: List[Tuple[int, int, Tuple[int, int]]] = []
+
+        # Auto-generate all object pairs if binary_keywords provided but object_pairs is empty
+        if not object_pairs and binary_keywords:
+            # Get all unique object IDs across all frames
+            all_object_ids = set()
+            for frame_masks in masks.values():
+                all_object_ids.update(frame_masks.keys())
+
+            # Generate all bidirectional pairs (i, j) where i != j
+            object_pairs = []
+            sorted_ids = sorted(all_object_ids)
+            for from_oid in sorted_ids:
+                for to_oid in sorted_ids:
+                    if from_oid != to_oid:
+                        object_pairs.append((from_oid, to_oid))
+
+            print(f"Auto-generated {len(object_pairs)} bidirectional object pairs for binary relation detection: {object_pairs}")
+
         if object_pairs:
             for frame_id, frame_masks in masks.items():
                 if frame_id >= num_frames:
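
For reference, the nested loop added above enumerates every ordered pair of distinct object IDs; it is equivalent to itertools.permutations over the sorted IDs. A sketch with invented sample IDs:

    from itertools import permutations

    all_object_ids = {0, 1, 2}  # hypothetical detected object IDs
    object_pairs = list(permutations(sorted(all_object_ids), 2))
    print(object_pairs)  # [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]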
vine_hf/vine_pipeline.py CHANGED
@@ -125,6 +125,8 @@ class VinePipeline(Pipeline):
             postprocess_kwargs["return_top_k"] = kwargs["return_top_k"]
         if "self.visualize" in kwargs:
             postprocess_kwargs["self.visualize"] = kwargs["self.visualize"]
+        if "binary_confidence_threshold" in kwargs:
+            postprocess_kwargs["binary_confidence_threshold"] = kwargs["binary_confidence_threshold"]
 
         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
 
@@ -781,6 +783,9 @@ class VinePipeline(Pipeline):
         if debug_visualizations is None:
             debug_visualizations = self.debug_visualizations
 
+        # Get binary confidence threshold from kwargs (default 0.0 means show all)
+        binary_confidence_threshold = kwargs.get("binary_confidence_threshold", 0.0)
+
         vine_frame_sets = render_vine_frame_sets(
             frames_np,
             bboxes,
@@ -788,6 +793,7 @@ class VinePipeline(Pipeline):
             unary_lookup,
             binary_lookup,
             visualization_data.get("sam_masks"),
+            binary_confidence_threshold,
         )
 
         vine_visuals: Dict[str, Dict[str, Any]] = {}
@@ -872,11 +878,27 @@ class VinePipeline(Pipeline):
                     "top_categories": [{"label": str, "probability": float}, ...],
                     "top_unary": [{"frame_id": int, "predicate": str, "probability": float}, ...],
                 }
+            },
+            "binary_keywords": {
+                "<from_id>-<to_id>": {"predicate": str, "confidence": float, "frame_id": int}
             }
         }
         """
         categorical_preds = model_outputs.get("categorical_predictions", {})
         unary_preds = model_outputs.get("unary_predictions", {})
+        binary_preds = model_outputs.get("binary_predictions", {})
+
+        # Debug: Print binary predictions
+        print("\n" + "=" * 80)
+        print("DEBUG _generate_summary: Binary predictions from model")
+        print(f"  Type: {type(binary_preds)}")
+        print(f"  Length: {len(binary_preds) if isinstance(binary_preds, dict) else 'N/A'}")
+        print(f"  Keys (first 20): {list(binary_preds.keys())[:20] if isinstance(binary_preds, dict) else 'N/A'}")
+        if isinstance(binary_preds, dict) and len(binary_preds) > 0:
+            print("  Sample entries:")
+            for i, (key, val) in enumerate(list(binary_preds.items())[:5]):
+                print(f"    {key}: {val}")
+        print("=" * 80 + "\n")
 
         unary_by_obj: Dict[int, List[Tuple[float, str, int]]] = {}
         for (frame_id, obj_id), preds in unary_preds.items():
@@ -886,6 +908,24 @@ class VinePipeline(Pipeline):
             )
             unary_by_obj.setdefault(obj_id, []).append((prob_val, predicate, frame_id))
 
+        # Process binary predictions
+        binary_keywords: Dict[str, Dict[str, Any]] = {}
+        for (frame_id, (from_id, to_id)), preds in binary_preds.items():
+            for prob, predicate in preds:
+                prob_val = (
+                    float(prob.detach().cpu()) if torch.is_tensor(prob) else float(prob)
+                )
+                pair_key = f"{from_id}-{to_id}"
+                # Keep only the highest confidence prediction for each pair
+                if pair_key not in binary_keywords or prob_val > binary_keywords[pair_key]["confidence"]:
+                    binary_keywords[pair_key] = {
+                        "predicate": predicate,
+                        "confidence": prob_val,
+                        "frame_id": int(frame_id),
+                        "from_id": int(from_id),
+                        "to_id": int(to_id),
+                    }
+
         objects_summary: Dict[str, Dict[str, Any]] = {}
         all_obj_ids = set(categorical_preds.keys()) | set(unary_by_obj.keys())
 
@@ -927,4 +967,10 @@ class VinePipeline(Pipeline):
             "num_objects_detected": len(objects_summary),
             "objects": objects_summary,
         }
+
+        # Add binary keywords to summary if any exist
+        if binary_keywords:
+            summary["binary_keywords"] = binary_keywords
+            print(f"\nDEBUG: Added {len(binary_keywords)} binary keywords to summary")
+
         return summary
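
For context, a hedged sketch of how the new kwarg travels through a pipeline call, mirroring process_video in app.py; the video path is a placeholder and other kwargs such as box_threshold and target_fps are elided:

    results = vine_pipe(
        inputs="example.mp4",             # placeholder path
        categorical_keywords=["person", "car"],
        unary_keywords=["walking"],
        binary_keywords=["chasing"],
        object_pairs=[],                  # empty -> auto-generated in vine_model.py
        binary_confidence_threshold=0.8,  # routed to postprocess by _sanitize_parameters
    )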
vine_hf/vis_utils.py CHANGED
@@ -330,6 +330,7 @@ def render_vine_frame_sets(
     unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
     binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
     masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
+    binary_confidence_threshold: float = 0.0,
 ) -> Dict[str, List[np.ndarray]]:
     frame_groups: Dict[str, List[np.ndarray]] = {
         "object": [],
@@ -403,6 +404,9 @@ def render_vine_frame_sets(
             anchor, direction = _label_anchor_and_direction(bbox, "bottom")
             _draw_label_block(all_bgr, unary_lines, anchor, _object_color_bgr(obj_id), direction=direction)
 
+        # First pass: collect all pairs above threshold and deduplicate bidirectional pairs
+        pairs_to_draw = {}  # (min_id, max_id) -> (subj_id, obj_id, prob, relation)
+
         for obj_pair, relation_preds in binary_lookup.get(frame_idx, []):
             if len(obj_pair) != 2 or not relation_preds:
                 continue
@@ -411,17 +415,33 @@ def render_vine_frame_sets(
             obj_bbox = bbox_lookup.get(obj_id)
             if not subj_bbox or not obj_bbox:
                 continue
+            prob, relation = relation_preds[0]
+            # Filter by confidence threshold
+            if prob < binary_confidence_threshold:
+                continue
+
+            # Create canonical key (smaller_id, larger_id) for deduplication
+            pair_key = (min(subj_id, obj_id), max(subj_id, obj_id))
+
+            # Keep the higher confidence direction
+            if pair_key not in pairs_to_draw or prob > pairs_to_draw[pair_key][2]:
+                pairs_to_draw[pair_key] = (subj_id, obj_id, prob, relation)
+
+        # Second pass: draw the selected pairs
+        for subj_id, obj_id, prob, relation in pairs_to_draw.values():
+            subj_bbox = bbox_lookup.get(subj_id)
+            obj_bbox = bbox_lookup.get(obj_id)
             start, end = relation_line(subj_bbox, obj_bbox)
             color = tuple(int(c) for c in np.clip(
                 (np.array(_object_color_bgr(subj_id), dtype=np.float32) +
                  np.array(_object_color_bgr(obj_id), dtype=np.float32)) / 2.0,
                 0, 255
             ))
-            prob, relation = relation_preds[0]
             label_text = f"{relation} {prob:.2f}"
             mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
-            cv2.line(binary_bgr, start, end, color, 6, cv2.LINE_AA)
-            cv2.line(all_bgr, start, end, color, 6, cv2.LINE_AA)
+            # Draw arrowed lines showing direction from subject to object (smaller arrow tip)
+            cv2.arrowedLine(binary_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05)
+            cv2.arrowedLine(all_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05)
             _draw_centered_label(binary_bgr, label_text, mid_point, color)
             _draw_centered_label(all_bgr, label_text, mid_point, color)
 
@@ -440,6 +460,7 @@ def render_vine_frames(
     unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None] = None,
+    binary_confidence_threshold: float = 0.0,
 ) -> List[np.ndarray]:
     return render_vine_frame_sets(
         frames,
@@ -448,6 +469,7 @@ def render_vine_frames(
         unary_lookup,
         binary_lookup,
         masks,
+        binary_confidence_threshold,
     ).get("all", [])
 
 def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
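
The two-pass deduplication above can be exercised standalone. This sketch uses invented predictions and keeps, for each unordered pair, only the higher-confidence direction:

    # Hypothetical per-frame predictions: (subj_id, obj_id) -> (prob, relation)
    preds = {(0, 1): (0.9, "chasing"), (1, 0): (0.4, "chasing"), (2, 0): (0.7, "carrying")}

    pairs_to_draw = {}
    for (subj_id, obj_id), (prob, relation) in preds.items():
        pair_key = (min(subj_id, obj_id), max(subj_id, obj_id))  # canonical unordered key
        if pair_key not in pairs_to_draw or prob > pairs_to_draw[pair_key][2]:
            pairs_to_draw[pair_key] = (subj_id, obj_id, prob, relation)

    print(sorted(pairs_to_draw.values()))
    # [(0, 1, 0.9, 'chasing'), (2, 0, 0.7, 'carrying')]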