Spaces: Running on T4

updates

- app.py +13 -255
- outputs/debug_crops/frame_0_obj_0.jpg +0 -0
- outputs/debug_crops/frame_0_obj_1.jpg +0 -0
- outputs/debug_crops/frame_0_obj_2.jpg +0 -0
- outputs/debug_crops/frame_0_obj_3.jpg +0 -0
- outputs/debug_crops/frame_0_obj_4.jpg +0 -0
- outputs/debug_crops/frame_0_obj_5.jpg +0 -0
- outputs/debug_crops/frame_1_obj_0.jpg +0 -0
- outputs/debug_crops/frame_1_obj_1.jpg +0 -0
- outputs/debug_crops/frame_1_obj_2.jpg +0 -0
- outputs/debug_crops/frame_1_obj_3.jpg +0 -0
- outputs/debug_crops/frame_1_obj_4.jpg +0 -0
- outputs/debug_crops/frame_1_obj_5.jpg +0 -0
- outputs/debug_crops/frame_1_obj_6.jpg +0 -0
- src/LASER/laser/models/model_utils.py +127 -52
- vine_hf/__pycache__/__init__.cpython-310.pyc +0 -0
- vine_hf/__pycache__/flattening.cpython-310.pyc +0 -0
- vine_hf/__pycache__/vine_config.cpython-310.pyc +0 -0
- vine_hf/__pycache__/vine_model.cpython-310.pyc +0 -0
- vine_hf/__pycache__/vine_pipeline.cpython-310.pyc +0 -0
- vine_hf/__pycache__/vis_utils.cpython-310.pyc +0 -0
- vine_hf/vine_pipeline.py +32 -1
- vine_hf/vis_utils.py +372 -175
app.py
CHANGED

@@ -60,206 +60,6 @@ print(
 )
 
 
-def format_summary(summary, binary_confidence_threshold=0.8):
-    """
-    Format the summary dictionary into a readable markdown string.
-    Filters binary relations by confidence threshold.
-    """
-    if not summary or not isinstance(summary, dict):
-        return "# Detection Summary\n\nNo events detected or processing in progress..."
-
-    output_lines = ["# Detection Summary\n"]
-    has_content = False
-
-    # Categorical keywords
-    if "categorical_keywords" in summary and summary["categorical_keywords"]:
-        output_lines.append("## Categorical Keywords\n")
-        cate = summary["categorical_keywords"]
-        if isinstance(cate, dict) and cate:
-            has_content = True
-            for kw, info in cate.items():
-                output_lines.append(f"**{kw}**")
-                if isinstance(info, dict):
-                    for key, val in info.items():
-                        output_lines.append(f"  - {key}: {val}")
-                else:
-                    output_lines.append(f"  - {info}")
-                output_lines.append("")
-        elif isinstance(cate, list) and cate:
-            has_content = True
-            for item in cate:
-                output_lines.append(f"- {item}")
-            output_lines.append("")
-
-    # Unary keywords
-    if "unary_keywords" in summary and summary["unary_keywords"]:
-        output_lines.append("## Unary Keywords\n")
-        unary = summary["unary_keywords"]
-        if isinstance(unary, dict) and unary:
-            has_content = True
-            for kw, info in unary.items():
-                output_lines.append(f"**{kw}**")
-                if isinstance(info, dict):
-                    for key, val in info.items():
-                        output_lines.append(f"  - {key}: {val}")
-                else:
-                    output_lines.append(f"  - {info}")
-                output_lines.append("")
-        elif isinstance(unary, list) and unary:
-            has_content = True
-            for item in unary:
-                output_lines.append(f"- {item}")
-            output_lines.append("")
-
-    # Binary keywords - show ALL binary relations for debugging
-    print(f"DEBUG: Checking binary_keywords...")
-    print(f"  'binary_keywords' in summary: {'binary_keywords' in summary}")
-    if 'binary_keywords' in summary:
-        print(f"  summary['binary_keywords'] truthy: {bool(summary['binary_keywords'])}")
-        print(f"  summary['binary_keywords'] type: {type(summary['binary_keywords'])}")
-        print(f"  summary['binary_keywords'] value: {summary['binary_keywords']}")
-
-    if "binary_keywords" in summary and summary["binary_keywords"]:
-        output_lines.append(f"## Binary Keywords\n")
-        binary = summary["binary_keywords"]
-        print(f"DEBUG: Processing binary keywords, type: {type(binary)}, length: {len(binary) if isinstance(binary, (dict, list)) else 'N/A'}")
-        if isinstance(binary, dict) and binary:
-            has_content = True
-            # Show all binary relations, sorted by confidence
-            binary_items = []
-            for kw, info in binary.items():
-                if isinstance(info, dict):
-                    confidence = info.get("confidence", info.get("score", 0))
-                    binary_items.append((kw, info, confidence))
-                else:
-                    binary_items.append((kw, info, 0))
-
-            # Sort by confidence descending
-            binary_items.sort(key=lambda x: x[2], reverse=True)
-
-            high_conf_count = 0
-            low_conf_count = 0
-
-            # Show high confidence items first
-            output_lines.append(f"### High Confidence (≥ {binary_confidence_threshold})\n")
-            for kw, info, confidence in binary_items:
-                if confidence >= binary_confidence_threshold:
-                    high_conf_count += 1
-                    if isinstance(info, dict):
-                        output_lines.append(f"**{kw}** (confidence: {confidence:.2f})")
-                        for key, val in info.items():
-                            if key not in ["confidence", "score"]:
-                                output_lines.append(f"  - {key}: {val}")
-                    else:
-                        output_lines.append(f"**{kw}**: {info}")
-                    output_lines.append("")
-
-            if high_conf_count == 0:
-                output_lines.append(f"*No binary relations found with confidence ≥ {binary_confidence_threshold}*\n")
-
-            # Show lower confidence items for debugging
-            output_lines.append(f"### Lower Confidence (< {binary_confidence_threshold})\n")
-            for kw, info, confidence in binary_items:
-                if confidence < binary_confidence_threshold:
-                    low_conf_count += 1
-                    if isinstance(info, dict):
-                        output_lines.append(f"**{kw}** (confidence: {confidence:.2f})")
-                        for key, val in info.items():
-                            if key not in ["confidence", "score"]:
-                                output_lines.append(f"  - {key}: {val}")
-                    else:
-                        output_lines.append(f"**{kw}**: {info}")
-                    output_lines.append("")
-
-            if low_conf_count == 0:
-                output_lines.append(f"*No binary relations found with confidence < {binary_confidence_threshold}*\n")
-
-            output_lines.append(f"**Total binary relations detected: {len(binary_items)}**\n")
-        elif isinstance(binary, list) and binary:
-            has_content = True
-            for item in binary:
-                output_lines.append(f"- {item}")
-            output_lines.append("")
-
-    # Object pairs - show ALL object pair interactions for debugging
-    print(f"DEBUG: Checking object_pairs...")
-    print(f"  'object_pairs' in summary: {'object_pairs' in summary}")
-    if 'object_pairs' in summary:
-        print(f"  summary['object_pairs'] truthy: {bool(summary['object_pairs'])}")
-        print(f"  summary['object_pairs'] type: {type(summary['object_pairs'])}")
-        print(f"  summary['object_pairs'] value: {summary['object_pairs']}")
-
-    if "object_pairs" in summary and summary["object_pairs"]:
-        output_lines.append(f"## Object Pair Interactions\n")
-        pairs = summary["object_pairs"]
-        print(f"DEBUG: Processing object pairs, type: {type(pairs)}, length: {len(pairs) if isinstance(pairs, (dict, list)) else 'N/A'}")
-        if isinstance(pairs, dict) and pairs:
-            has_content = True
-            # Show all object pairs, sorted by confidence
-            pair_items = []
-            for pair, info in pairs.items():
-                if isinstance(info, dict):
-                    confidence = info.get("confidence", info.get("score", 0))
-                    pair_items.append((pair, info, confidence))
-                else:
-                    pair_items.append((pair, info, 0))
-
-            # Sort by confidence descending
-            pair_items.sort(key=lambda x: x[2], reverse=True)
-
-            high_conf_count = 0
-            low_conf_count = 0
-
-            # Show high confidence items first
-            output_lines.append(f"### High Confidence (≥ {binary_confidence_threshold})\n")
-            for pair, info, confidence in pair_items:
-                if confidence >= binary_confidence_threshold:
-                    high_conf_count += 1
-                    if isinstance(info, dict):
-                        output_lines.append(f"**{pair}** (confidence: {confidence:.2f})")
-                        for key, val in info.items():
-                            if key not in ["confidence", "score"]:
-                                output_lines.append(f"  - {key}: {val}")
-                    else:
-                        output_lines.append(f"**{pair}**: {info}")
-                    output_lines.append("")
-
-            if high_conf_count == 0:
-                output_lines.append(f"*No object pairs found with confidence ≥ {binary_confidence_threshold}*\n")
-
-            # Show lower confidence items for debugging
-            output_lines.append(f"### Lower Confidence (< {binary_confidence_threshold})\n")
-            for pair, info, confidence in pair_items:
-                if confidence < binary_confidence_threshold:
-                    low_conf_count += 1
-                    if isinstance(info, dict):
-                        output_lines.append(f"**{pair}** (confidence: {confidence:.2f})")
-                        for key, val in info.items():
-                            if key not in ["confidence", "score"]:
-                                output_lines.append(f"  - {key}: {val}")
-                    else:
-                        output_lines.append(f"**{pair}**: {info}")
-                    output_lines.append("")
-
-            if low_conf_count == 0:
-                output_lines.append(f"*No object pairs found with confidence < {binary_confidence_threshold}*\n")
-
-            output_lines.append(f"**Total object pairs detected: {len(pair_items)}**\n")
-        elif isinstance(pairs, list) and pairs:
-            has_content = True
-            for item in pairs:
-                output_lines.append(f"- {item}")
-            output_lines.append("")
-
-    # If no content was added, show the raw summary for debugging
-    if not has_content:
-        output_lines.append("## Raw Summary Data\n")
-        output_lines.append("```json")
-        import json
-        output_lines.append(json.dumps(summary, indent=2, default=str))
-        output_lines.append("```")
-
-    return "\n".join(output_lines)
 
 
 @lru_cache(maxsize=1)

@@ -394,9 +194,10 @@ def process_video(
     summary = results_dict.get("summary") or {}
 
     if result_video_path and os.path.exists(result_video_path):
-        gradio_tmp =
-        os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
-
+        gradio_tmp = (
+            Path(os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir()))
+            / "vine_outputs"
+        )
         gradio_tmp.mkdir(parents=True, exist_ok=True)
         dest_path = gradio_tmp / Path(result_video_path).name
         try:

@@ -411,47 +212,7 @@ def process_video(
             "Warning: annotated video not found or empty; check visualization settings."
         )
 
-
-    import json
-    print("=" * 80)
-    print("SUMMARY DEBUG OUTPUT:")
-    print(f"Summary type: {type(summary)}")
-    print(f"Summary keys: {summary.keys() if isinstance(summary, dict) else 'N/A'}")
-    if isinstance(summary, dict):
-        print("\nFULL SUMMARY JSON:")
-        print(json.dumps(summary, indent=2, default=str))
-        print("\n" + "=" * 80)
-
-        # Check for any keys that might contain binary relation data
-        print("\nLOOKING FOR BINARY RELATION DATA:")
-        possible_keys = ['binary', 'binary_keywords', 'binary_relations', 'object_pairs',
-                         'pairs', 'relations', 'interactions', 'pairwise']
-        for pkey in possible_keys:
-            if pkey in summary:
-                print(f"  FOUND: '{pkey}' -> {summary[pkey]}")
-
-        print("\nALL KEYS IN SUMMARY:")
-        for key in summary.keys():
-            print(f"\n{key}:")
-            print(f"  Type: {type(summary[key])}")
-            if isinstance(summary[key], dict):
-                print(f"  Length: {len(summary[key])}")
-                print(f"  Keys (first 10): {list(summary[key].keys())[:10]}")
-                # Print all items for anything that might be binary relations
-                if any(term in key.lower() for term in ['binary', 'pair', 'relation', 'interaction']):
-                    print(f"  ALL ITEMS:")
-                    for k, v in list(summary[key].items())[:20]:  # First 20 items
-                        print(f"    {k}: {v}")
-                else:
-                    print(f"  Sample: {dict(list(summary[key].items())[:2])}")
-            elif isinstance(summary[key], list):
-                print(f"  Length: {len(summary[key])}")
-                print(f"  Sample: {summary[key][:2]}")
-        print("=" * 80)
-
-    # Format summary as readable markdown text, filtering by confidence threshold
-    formatted_summary = format_summary(summary, binary_confidence_threshold)
-    return video_path_for_ui, formatted_summary
+    return video_path_for_ui, summary
 
 
 def _video_component(label: str, *, is_output: bool = False):

@@ -523,25 +284,25 @@ with _create_blocks() as demo:
     label="Categorical Keywords",
     placeholder="e.g., person, car, dog",
     value="person, car, dog",
-    info="Objects to detect in the video (comma-separated)"
+    info="Objects to detect in the video (comma-separated)",
 )
 unary_input = gr.Textbox(
     label="Unary Keywords",
     placeholder="e.g., walking, running, standing",
     value="walking, running, standing",
-    info="Single-object actions to detect (comma-separated)"
+    info="Single-object actions to detect (comma-separated)",
 )
 binary_input = gr.Textbox(
     label="Binary Keywords",
     placeholder="e.g., chasing, carrying",
-    info="Object-to-object interactions to detect (comma-separated)"
+    info="Object-to-object interactions to detect (comma-separated)",
 )
 
 gr.Markdown("#### Processing Settings")
 fps_input = gr.Number(
     label="Output FPS",
     value=1,
-    info="Frames per second for processing (lower = faster)"
+    info="Frames per second for processing (lower = faster)",
 )
 
 with gr.Accordion("Advanced Settings", open=False):

@@ -551,7 +312,7 @@ with _create_blocks() as demo:
     maximum=0.9,
     value=0.35,
     step=0.05,
-    info="Confidence threshold for object detection"
+    info="Confidence threshold for object detection",
 )
 text_threshold_input = gr.Slider(
     label="Text Threshold",

@@ -559,7 +320,7 @@ with _create_blocks() as demo:
     maximum=0.9,
     value=0.25,
     step=0.05,
-    info="Confidence threshold for text-based detection"
+    info="Confidence threshold for text-based detection",
 )
 binary_confidence_input = gr.Slider(
     label="Binary Relation Confidence Threshold",

@@ -567,7 +328,7 @@ with _create_blocks() as demo:
     maximum=1.0,
     value=0.8,
     step=0.05,
-    info="Minimum confidence to show binary relations and object pairs"
+    info="Minimum confidence to show binary relations and object pairs",
 )
 
 submit_btn = gr.Button("🚀 Process Video", variant="primary", size="lg")

@@ -579,10 +340,7 @@ with _create_blocks() as demo:
 video_output = _video_component("Annotated Video Output", is_output=True)
 
 gr.Markdown("### Detection Summary")
-summary_output = gr.
-    value="Results will appear here after processing...",
-    elem_classes=["summary-output"]
-)
+summary_output = gr.JSON(label="Summary of Detected Events")
 
 gr.Markdown(
     """
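The net effect of the app.py changes: the long format_summary() markdown renderer and the debug dumps are removed, process_video() now returns the raw summary dict, and the UI shows it in a gr.JSON component. Below is a minimal sketch of that wiring, assuming a process_video() with the same return shape as in the diff; the component layout and the dummy summary values are illustrative, not taken from the actual Space.

import gradio as gr

def process_video(video_path, binary_confidence_threshold=0.8):
    # ... detection pipeline runs here (omitted) ...
    summary = {
        "categorical_keywords": {"person": {"count": 2}},                 # dummy values
        "binary_keywords": {"person chasing dog": {"confidence": 0.91}},  # dummy values
    }
    return video_path, summary  # raw dict; no markdown formatting step anymore

with gr.Blocks() as demo:
    video_input = gr.Video(label="Input Video")
    video_output = gr.Video(label="Annotated Video Output")
    gr.Markdown("### Detection Summary")
    # gr.JSON renders the nested dict directly, replacing the old format_summary() text
    summary_output = gr.JSON(label="Summary of Detected Events")
    gr.Button("Process Video").click(
        process_video, inputs=video_input, outputs=[video_output, summary_output]
    )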
outputs/debug_crops/frame_0_obj_0.jpg CHANGED
outputs/debug_crops/frame_0_obj_1.jpg CHANGED
outputs/debug_crops/frame_0_obj_2.jpg CHANGED
outputs/debug_crops/frame_0_obj_3.jpg CHANGED
outputs/debug_crops/frame_0_obj_4.jpg CHANGED
outputs/debug_crops/frame_0_obj_5.jpg CHANGED
outputs/debug_crops/frame_1_obj_0.jpg CHANGED
outputs/debug_crops/frame_1_obj_1.jpg CHANGED
outputs/debug_crops/frame_1_obj_2.jpg CHANGED
outputs/debug_crops/frame_1_obj_3.jpg CHANGED
outputs/debug_crops/frame_1_obj_4.jpg CHANGED
outputs/debug_crops/frame_1_obj_5.jpg CHANGED
outputs/debug_crops/frame_1_obj_6.jpg CHANGED
src/LASER/laser/models/model_utils.py
CHANGED

@@ -6,20 +6,22 @@ import torch
 import jax.numpy as jnp
 import jax
 
+
 def increase_brightness(img, alpha=0.2):
     height, width, _ = img.shape
-    white_img = np.zeros([height,width,3],dtype=np.uint8)
-    white_img.fill(255)
+    white_img = np.zeros([height, width, 3], dtype=np.uint8)
+    white_img.fill(255)  # or img[:] = 255
 
-    dst = cv2.addWeighted(img, alpha
+    dst = cv2.addWeighted(img, alpha, white_img, 1 - alpha, 0)
     return dst
 
+
 def increase_brightness_except(img, bbox_ls, alpha=0.2):
     height, width, _ = img.shape
-    white_img = np.zeros([height,width,3],dtype=np.uint8)
-    white_img.fill(255)
+    white_img = np.zeros([height, width, 3], dtype=np.uint8)
+    white_img.fill(255)  # or img[:] = 255
 
-    output_img = cv2.addWeighted(img, alpha
+    output_img = cv2.addWeighted(img, alpha, white_img, 1 - alpha, 0)
 
     for x1, y1, x2, y2 in bbox_ls:
         output_img[y1:y2, x1:x2] = img[y1:y2, x1:x2]

@@ -28,12 +30,12 @@ def increase_brightness_except(img, bbox_ls, alpha=0.2):
 
 def extract_single_object(img, mask, alpha=0.8):
     """OpenCV version of extract_single_object that works with numpy arrays.
-
+
     Args:
         img: numpy array of shape (height, width, 3)
         mask: numpy array of shape (height, width, 1) or (height, width)
         alpha: float between 0 and 1 for blending
-
+
     Returns:
         numpy array of shape (height, width, 3)
     """

@@ -51,18 +53,21 @@ def extract_single_object(img, mask, alpha=0.8):
     masked_white_img = np.where(mask, white_img, img)
 
     # Blend the original image with the masked white image
-    output_img = cv2.addWeighted(
+    output_img = cv2.addWeighted(
+        img.astype(np.uint8), 1 - alpha, masked_white_img.astype(np.uint8), alpha, 0
+    )
 
     return output_img
 
+
 def extract_single_object_jax(img, mask, alpha=0.8):
     """JAX version of extract_single_object that works with JAX arrays.
-
+
     Args:
         img: JAX array of shape (height, width, 3)
         mask: JAX array of shape (height, width, 1) or (height, width)
         alpha: float between 0 and 1 for blending
-
+
     Returns:
         JAX array of shape (height, width, 3)
     """

@@ -80,10 +85,11 @@ def extract_single_object_jax(img, mask, alpha=0.8):
     masked_white_img = jnp.where(mask, white_img, img)
 
     # Blend the original image with the masked white image
-    output_img = img * (1-alpha) + masked_white_img * alpha
+    output_img = img * (1 - alpha) + masked_white_img * alpha
 
     return output_img
 
+
 def crop_image_contain_bboxes(img, bbox_ls, data_id):
     all_bx1 = []
     all_by1 = []

@@ -92,9 +98,11 @@ def crop_image_contain_bboxes(img, bbox_ls, data_id):
 
     for bbox in bbox_ls:
         if isinstance(bbox, dict):
-            bx1, by1, bx2, by2 = bbox[
+            bx1, by1, bx2, by2 = bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]
         elif isinstance(bbox, (list, tuple, np.ndarray)):
-            bx1, by1, bx2, by2 = map(
+            bx1, by1, bx2, by2 = map(
+                int, bbox[:4]
+            )  # Convert first 4 elements to integers
         else:
             raise ValueError(f"Unsupported bbox format: {type(bbox)}")
 

@@ -111,13 +119,36 @@ def crop_image_contain_bboxes(img, bbox_ls, data_id):
     y1 = min(all_by1)
     y2 = max(all_by2)
 
-    assert
-    assert
+    assert x1 < x2, f"image bbox issue: {data_id}"
+    assert y1 < y2, f"image bbox issue: {data_id}"
 
     return img[y1:y2, x1:x2]
 
+
+import numpy as np
+import cv2
+
+
 def extract_object_subject(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8):
-
+    """
+    Blend subject/object regions into the image:
+      - red_mask: subject
+      - blue_mask: object
+      - alpha: how strong color highlight is
+      - white_alpha: how strongly to fade background toward white
+    """
+
+    # Ensure img is uint8 HxWx3
+    img = np.asarray(img)
+    if img.ndim == 2:
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+    if img.dtype != np.uint8:
+        img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
+
+    # Normalize masks to 2D
+    red_mask = np.asarray(red_mask)
+    blue_mask = np.asarray(blue_mask)
+
     if red_mask.ndim == 3:
         red_mask = red_mask[:, :, 0]
     if blue_mask.ndim == 3:

@@ -125,44 +156,62 @@ def extract_object_subject(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8)
 
     red_mask = red_mask.astype(bool)
     blue_mask = blue_mask.astype(bool)
+
+    # Background = areas not in either mask
     non_masked_area = ~(red_mask | blue_mask)
 
-    # Split
+    # Split channels
     b, g, r = cv2.split(img)
 
-    #
-    r = np.where(
+    # Highlight red region
+    r = np.where(
+        red_mask,
+        np.clip(r + (255 - r) * alpha, 0, 255),
+        r,
+    )
 
-    # Adjust the blue channel based on the blue mask
-    b = np.where(blue_mask, np.clip(b + (255 - b) * alpha, 0, 255), b).astype(np.uint8)
-
-    # Merge the channels back together
+    # Highlight blue region
+    b = np.where(
+        blue_mask,
+        np.clip(b + (255 - b) * alpha, 0, 255),
+        b,
+    )
+
+    # Ensure proper dtype
+    b = b.astype(np.uint8)
+    g = g.astype(np.uint8)
+    r = r.astype(np.uint8)
 
     output_img = cv2.merge((b, g, r))
 
+    # Fade non-masked area toward white
     white_img = np.full_like(output_img, 255, dtype=np.uint8)
-
-
-
+    non_masked_area_3d = non_masked_area[
+        ..., None
+    ]  # (H, W, 1) -> broadcast to (H, W, 3)
+
+    faded = cv2.addWeighted(output_img, 1 - white_alpha, white_img, white_alpha, 0)
+    output_img = np.where(non_masked_area_3d, faded, output_img)
 
     return output_img
 
 
 def extract_object_subject_jax(img, red_mask, blue_mask, alpha=0.5, white_alpha=0.8):
     """JAX version of extract_object_subject that works with JAX arrays.
-
+
     Args:
         img: JAX array of shape (height, width, 3) in BGR format
         red_mask: JAX array of shape (height, width, 1) or (height, width)
         blue_mask: JAX array of shape (height, width, 1) or (height, width)
         alpha: float between 0 and 1 for color highlighting
         white_alpha: float between 0 and 1 for background blending
-
+
     Returns:
         JAX array of shape (height, width, 3) in BGR format with uint8 dtype
     """
     # Convert input image to float32 for calculations
     img = img.astype(jnp.float32)
-
+
     # Ensure the masks are binary (0 or 1)
     red_mask = red_mask.astype(bool)
     blue_mask = blue_mask.astype(bool)

@@ -179,54 +228,58 @@ def extract_object_subject_jax(img, red_mask, blue_mask, alpha=0.5, white_alpha=
     r = img[..., 2]  # Red channel
 
     # Adjust the red channel based on the red mask
-    r = jnp.where(red_mask[..., 0],
-                  jnp.clip(r + (255 - r) * alpha, 0, 255),
-                  r)
+    r = jnp.where(red_mask[..., 0], jnp.clip(r + (255 - r) * alpha, 0, 255), r)
 
     # Adjust the blue channel based on the blue mask
-    b = jnp.where(blue_mask[..., 0],
-                  jnp.clip(b + (255 - b) * alpha, 0, 255),
-                  b)
+    b = jnp.where(blue_mask[..., 0], jnp.clip(b + (255 - b) * alpha, 0, 255), b)
 
     # Stack the channels back together
     output_img = jnp.stack([b, g, r], axis=-1)
 
     # Create white background and blend
     white_img = jnp.full_like(output_img, 255.0, dtype=jnp.float32)
-    output_img = jnp.where(
-
-
+    output_img = jnp.where(
+        non_masked_area,
+        output_img * (1 - white_alpha) + white_img * white_alpha,
+        output_img,
+    )
 
     # Round to nearest integer and cast to uint8
     output_img = jnp.round(output_img)
     return output_img.astype(jnp.uint8)
 
-
+
+def increase_brightness_draw_outer_edge(
+    img, bbox_ls, alpha=0.2, colormap_name="Set1", thickness=2
+):
     if isinstance(img, torch.Tensor):
         img = img.cpu().numpy().astype(np.uint8)
     else:
         img = img.astype(np.uint8)
     height, width, _ = img.shape
-    white_img = np.zeros([height,width,3],dtype=np.uint8)
-    white_img.fill(255)
+    white_img = np.zeros([height, width, 3], dtype=np.uint8)
+    white_img.fill(255)  # or img[:] = 255
 
-    output_img = cv2.addWeighted(img, alpha
+    output_img = cv2.addWeighted(img, alpha, white_img, 1 - alpha, 0)
     colormap = plt.colormaps[colormap_name]
 
     for bbox_id, (x1, y1, x2, y2) in enumerate(bbox_ls):
         output_img[y1:y2, x1:x2] = img[y1:y2, x1:x2]
-        color =
+        color = [c * 255 for c in mpl.colors.to_rgb(colormap(bbox_id))]
         # print(f"color: {color}")
         output_img = cv2.rectangle(output_img, (x1, y1), (x2, y2), color, thickness)
 
     return torch.tensor(output_img, dtype=torch.float32)
 
+
 def get_print_hook(name):
     def print_hook(grad):
         print(f"{name}: \n {grad} \n")
         return grad
+
     return print_hook
 
+
 def segment_list(l, n=5):
     current_seg = []
     all_segs = []

@@ -242,18 +295,22 @@ def segment_list(l, n=5):
 
     return all_segs
 
+
 def get_tensor_size(a):
     return a.element_size() * a.nelement()
 
+
 def comp_diff(v1, v2):
     return 2 * torch.abs(v1 - v2) / (v1 + v2)
 
+
 def gather_names(pred_res):
     all_names = set()
     for name, _ in pred_res:
         all_names.add(name)
     return list(all_names)
 
+
 def extract_nl_feats(tokenizer, model, names, device):
     if len(names) == 0:
         features = []

@@ -262,14 +319,23 @@ def extract_nl_feats(tokenizer, model, names, device):
     features = model.get_text_features(**name_tokens)
     return features
 
-
+
+def extract_all_nl_feats(
+    tokenizer,
+    model,
+    batch_size,
+    batched_names,
+    batched_unary_kws,
+    batched_binary_kws,
+    device,
+):
     batched_obj_name_features = [[] for _ in range(batch_size)]
     batched_unary_nl_features = [[] for _ in range(batch_size)]
     batched_binary_nl_features = [[] for _ in range(batch_size)]
-
-    for vid, (object_names, unary_kws, binary_kws) in \
-        enumerate(zip(batched_names, batched_unary_kws, batched_binary_kws)):
+
+    for vid, (object_names, unary_kws, binary_kws) in enumerate(
+        zip(batched_names, batched_unary_kws, batched_binary_kws)
+    ):
         obj_name_features = extract_nl_feats(tokenizer, model, object_names, device)
         batched_obj_name_features[vid] = obj_name_features
 

@@ -279,22 +345,31 @@ def extract_all_nl_feats(tokenizer, model, batch_size, batched_names, batched_un
     binary_features = extract_nl_feats(tokenizer, model, binary_kws, device)
     batched_binary_nl_features[vid] = binary_features
 
-    return
+    return (
+        batched_obj_name_features,
+        batched_unary_nl_features,
+        batched_binary_nl_features,
+    )
+
 
-def single_object_crop(
+def single_object_crop(
+    batch_size, batched_videos, batched_object_ids, batched_bboxes, batched_video_splits
+):
     batched_frame_bboxes = {}
     batched_cropped_objs = [[] for _ in range(batch_size)]
 
     for (video_id, frame_id, obj_id), bbox in zip(batched_object_ids, batched_bboxes):
         overall_frame_id = batched_video_splits[video_id] + frame_id
         if type(bbox) == dict:
-            bx1, by1, bx2, by2 = bbox[
+            bx1, by1, bx2, by2 = bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]
         else:
            bx1, by1, bx2, by2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
 
         assert by2 > by1
         assert bx2 > bx1
-        batched_cropped_objs[video_id].append(
+        batched_cropped_objs[video_id].append(
+            (batched_videos[overall_frame_id][by1:by2, bx1:bx2])
+        )
         batched_frame_bboxes[video_id, frame_id, obj_id] = (bx1, by1, bx2, by2)
 
     return batched_cropped_objs, batched_frame_bboxes
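Several of the helpers in this file (increase_brightness, increase_brightness_except, increase_brightness_draw_outer_edge) rely on the same cv2.addWeighted blend toward a white image. A standalone sketch of just that blend, using a synthetic frame rather than real pipeline data:

import cv2
import numpy as np

alpha = 0.2
img = np.full((120, 160, 3), 60, dtype=np.uint8)  # dark dummy frame
white = np.full_like(img, 255)

# out = alpha * img + (1 - alpha) * white, i.e. the frame fades toward white
brightened = cv2.addWeighted(img, alpha, white, 1 - alpha, 0)
print(brightened[0, 0])  # [216 216 216] for alpha=0.2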
vine_hf/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/__init__.cpython-310.pyc and b/vine_hf/__pycache__/__init__.cpython-310.pyc differ

vine_hf/__pycache__/flattening.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/flattening.cpython-310.pyc and b/vine_hf/__pycache__/flattening.cpython-310.pyc differ

vine_hf/__pycache__/vine_config.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_config.cpython-310.pyc and b/vine_hf/__pycache__/vine_config.cpython-310.pyc differ

vine_hf/__pycache__/vine_model.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_model.cpython-310.pyc and b/vine_hf/__pycache__/vine_model.cpython-310.pyc differ

vine_hf/__pycache__/vine_pipeline.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vine_pipeline.cpython-310.pyc and b/vine_hf/__pycache__/vine_pipeline.cpython-310.pyc differ

vine_hf/__pycache__/vis_utils.cpython-310.pyc CHANGED
Binary files a/vine_hf/__pycache__/vis_utils.cpython-310.pyc and b/vine_hf/__pycache__/vis_utils.cpython-310.pyc differ
vine_hf/vine_pipeline.py
CHANGED

@@ -586,8 +586,17 @@ class VinePipeline(Pipeline):
         import subprocess
 
         try:
+            # Try to get FFmpeg from imageio-ffmpeg first, then fall back to system FFmpeg
+            try:
+                import imageio_ffmpeg
+                ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
+                print(f"Using FFmpeg from imageio-ffmpeg: {ffmpeg_exe}")
+            except ImportError:
+                ffmpeg_exe = "ffmpeg"
+                print("Using system FFmpeg")
+
             ffmpeg_cmd = [
-
+                ffmpeg_exe,
                 "-y",
                 "-f",
                 "rawvideo",

@@ -657,6 +666,10 @@ class VinePipeline(Pipeline):
         out = None
         used_codec = None
 
+        # Debug: Print video tensor info
+        print(f"DEBUG: video_tensor shape: {video_tensor.shape}, dtype: {video_tensor.dtype}")
+        print(f"DEBUG: Expected dimensions - width: {width}, height: {height}, fps: {fps}")
+
         for codec in codecs_to_try:
             try:
                 fourcc = cv2.VideoWriter_fourcc(*codec)

@@ -679,19 +692,37 @@ class VinePipeline(Pipeline):
 
         print(f"Using OpenCV with codec: {used_codec}")
 
+        frame_count = 0
         for frame in video_tensor:
+            # Debug: Print first frame info
+            if frame_count == 0:
+                print(f"DEBUG: First frame shape: {frame.shape}, dtype: {frame.dtype}")
+                print(f"DEBUG: First frame min: {frame.min()}, max: {frame.max()}, mean: {frame.mean()}")
+
             if len(frame.shape) == 3 and frame.shape[2] == 3:
                 frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
             else:
                 frame_bgr = frame
+
             if frame_bgr.dtype != np.uint8:
                 frame_bgr = (
                     (frame_bgr * 255).astype(np.uint8)
                     if frame_bgr.max() <= 1
                     else frame_bgr.astype(np.uint8)
                 )
+
+            # Debug: Check if frame dimensions match VideoWriter expectations
+            if frame_count == 0:
+                print(f"DEBUG: After conversion - frame_bgr shape: {frame_bgr.shape}, dtype: {frame_bgr.dtype}")
+                print(f"DEBUG: After conversion - min: {frame_bgr.min()}, max: {frame_bgr.max()}")
+                actual_height, actual_width = frame_bgr.shape[:2]
+                if actual_height != height or actual_width != width:
+                    print(f"WARNING: Frame size mismatch! Expected ({height}, {width}), got ({actual_height}, {actual_width})")
+
             out.write(frame_bgr)
+            frame_count += 1
 
+        print(f"DEBUG: Wrote {frame_count} frames to video")
         out.release()
         return temp_path
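The first vine_pipeline.py hunk resolves the FFmpeg binary through imageio-ffmpeg before falling back to whatever `ffmpeg` is on PATH. The same pattern in isolation; the rawvideo arguments below are placeholders, not the pipeline's actual encoding settings:

import subprocess

def resolve_ffmpeg():
    try:
        import imageio_ffmpeg
        return imageio_ffmpeg.get_ffmpeg_exe()  # bundled binary, no system install needed
    except ImportError:
        return "ffmpeg"                         # fall back to system FFmpeg

ffmpeg_cmd = [
    resolve_ffmpeg(), "-y",
    "-f", "rawvideo", "-pix_fmt", "rgb24", "-s", "640x480", "-r", "1", "-i", "-",
    "-c:v", "libx264", "-pix_fmt", "yuv420p", "output.mp4",
]
# frames would then be piped in via:
# proc = subprocess.Popen(ffmpeg_cmd, stdin=subprocess.PIPE)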
vine_hf/vis_utils.py
CHANGED

(Only the removed side of this diff survived extraction; several removed lines are truncated and are left as bare "-" markers.)

@@ -54,10 +54,12 @@ from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox
 # All rendered frames returned by functions are RGB np.ndarray images suitable for saving or video writing.
 ########################################################################################
 
 def clean_label(label):
     """Replace underscores and slashes with spaces for uniformity."""
     return label.replace("_", " ").replace("/", " ")
 
 # Should be performed somewhere else I believe
 def format_cate_preds(cate_preds):
     # Group object predictions from the model output.

@@ -72,6 +74,7 @@ def format_cate_preds(cate_preds):
     obj_pred_dict[oid].sort(key=lambda x: x[1], reverse=True)
     return obj_pred_dict
 
 def format_binary_cate_preds(binary_preds):
     frame_binary_preds = []
     for key, score in binary_preds.items():

@@ -85,6 +88,7 @@ def format_binary_cate_preds(binary_preds):
     frame_binary_preds.sort(key=lambda x: x[3], reverse=True)
     return frame_binary_preds
 
 _FONT = cv2.FONT_HERSHEY_SIMPLEX
 

@@ -106,7 +110,9 @@ def _to_numpy_mask(mask: Union[np.ndarray, torch.Tensor, None]) -> Optional[np.n
     return mask_np > 0
 
 
-def _sanitize_bbox(
     if bbox is None:
         return None
     if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:

@@ -164,7 +170,16 @@ def _draw_label_block(
     cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
     text_x = left_x + 4
     text_y = min(bottom_y - baseline - 2, img_h - 1)
-    cv2.putText(
     y_cursor = bottom_y
 else:
     for text in lines:

@@ -177,7 +192,16 @@ def _draw_label_block(
     cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
     text_x = left_x + 4
     text_y = min(bottom_y - baseline - 2, img_h - 1)
-    cv2.putText(
     y_cursor = top_y
 

@@ -198,13 +222,26 @@ def _draw_centered_label(
     top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
     right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
     bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
-    cv2.rectangle(
     text_x = left_x + 4
     text_y = min(bottom_y - baseline - 2, img_h - 1)
-    cv2.putText(
 
 
-def _extract_frame_entities(
     if isinstance(store, dict):
         frame_entry = store.get(frame_idx, {})
     elif isinstance(store, list) and 0 <= frame_idx < len(store):

@@ -271,7 +308,9 @@ def render_sam_frames(
         continue
     color = _object_color_bgr(obj_id)
     alpha = 0.45
-    overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(
 
     annotated = np.clip(overlay, 0, 255).astype(np.uint8)
     frame_h, frame_w = annotated.shape[:2]

@@ -329,7 +368,9 @@ def render_vine_frame_sets(
     cat_label_lookup: Dict[int, Tuple[str, float]],
     unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
     binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
-    masks: Union[
     binary_confidence_threshold: float = 0.0,
 ) -> Dict[str, List[np.ndarray]]:
     frame_groups: Dict[str, List[np.ndarray]] = {

@@ -347,7 +388,9 @@ def render_vine_frame_sets(
     base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
     frame_h, frame_w = base_bgr.shape[:2]
     frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
-    frame_masks =
 
     objects_bgr = base_bgr.copy()
     unary_bgr = base_bgr.copy()

@@ -393,16 +436,36 @@ def render_vine_frame_sets(
     for obj_id, bbox in bbox_lookup.items():
         title = titles_lookup.get(obj_id)
         unary_lines = unary_lines_lookup.get(obj_id, [])
-        _draw_bbox_with_label(
-
         if unary_lines:
             anchor, direction = _label_anchor_and_direction(bbox, "bottom")
-            _draw_label_block(
-
-
         if unary_lines:
             anchor, direction = _label_anchor_and_direction(bbox, "bottom")
-            _draw_label_block(
 
     # First pass: collect all pairs above threshold and deduplicate bidirectional pairs
     pairs_to_draw = {}  # (min_id, max_id) -> (subj_id, obj_id, prob, relation)

@@ -432,15 +495,24 @@ def render_vine_frame_sets(
     subj_bbox = bbox_lookup.get(subj_id)
     obj_bbox = bbox_lookup.get(obj_id)
     start, end = relation_line(subj_bbox, obj_bbox)
-    color = tuple(
-        (
-
-
-
     label_text = f"{relation} {prob:.2f}"
     mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
     # Draw arrowed lines showing direction from subject to object (smaller arrow tip)
-    cv2.arrowedLine(
     cv2.arrowedLine(all_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05)
     _draw_centered_label(binary_bgr, label_text, mid_point, color)
     _draw_centered_label(all_bgr, label_text, mid_point, color)

@@ -459,7 +531,9 @@ def render_vine_frames(
     cat_label_lookup: Dict[int, Tuple[str, float]],
     unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
     binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
-    masks: Union[
     binary_confidence_threshold: float = 0.0,
 ) -> List[np.ndarray]:
     return render_vine_frame_sets(

@@ -471,11 +545,12 @@ def render_vine_frames(
         masks,
         binary_confidence_threshold,
     ).get("all", [])
-
 def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
     all_colors = []
     all_texts = []
-    for
         preds = obj_pred_dict.get(obj_id, [])
         if len(preds) == 0:
             top1 = "N/A"

@@ -485,143 +560,214 @@ def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
         topk_labels = [p[0] for p in preds[:topk_object]]
         # Compare cleaned labels.
         if top1.lower() == gt_label.lower():
-            box_color = (0, 255, 0)
         elif gt_label.lower() in [p.lower() for p in topk_labels]:
-            box_color = (0, 165, 255)
         else:
-            box_color = (0, 0, 255)
-
         label_text = f"ID:{obj_id}/P:{top1}/GT:{gt_label}"
         all_colors.append(box_color)
         all_texts.append(label_text)
     return all_colors, all_texts
 
 def plot_unary(frame_img, gt_labels, all_colors, all_texts):
-
-
         x1, y1, x2, y2 = map(int, bbox)
         cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
-        (tw, th), baseline = cv2.getTextSize(
-
-
-
-
     return frame_img
 
-
-
-
-
-
-
-
-
     white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)
-
     # --- Adjust pane split: make predictions column wider (60% vs. 40%) ---
     left_width = int(pane_width * 0.6)
     right_width = pane_width - left_width
     left_pane = white_pane[:, :left_width, :].copy()
     right_pane = white_pane[:, left_width:, :].copy()
-
-    cv2.putText(
-
-
-
     return white_pane
 
 # This is for ploting binary prediction results with frame-based scene graphs
-def plot_binary_sg(
-
-
-
-
-
-
-
     line_height = 30  # vertical spacing per line
-    x_text = 10
     y_text_left = header_height + 10  # starting y for left pane text
-    y_text_right = header_height + 10
-
     # Left section: top-k binary predictions.
     left_width = int(pane_width * 0.6)
     right_width = pane_width - left_width
     left_pane = white_pane[:, :left_width, :].copy()
     right_pane = white_pane[:, left_width:, :].copy()
-
-    for
-        correct = any(
-
         indicator_color = (0, 255, 0) if correct else (0, 0, 255)
-        cv2.rectangle(
-
         text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
-        cv2.putText(
-
         y_text_left += line_height
-
     # Right section: ground truth binary relations.
     for gt in gt_relations:
         if len(gt) != 3:
             continue
         text = f"{gt[0]} - {gt[2]} - {gt[1]}"
-        cv2.putText(
-
         y_text_right += line_height
-
     # Combine the two text panes and then with the frame image.
     combined_pane = np.hstack((left_pane, right_pane))
     combined_image = np.hstack((frame_img, combined_pane))
     return combined_image
 
-
-
-
-
-
-
-
-
-
-
     """Return the combined annotated frame for frame index i as an image (in BGR)."""
     # Get the frame image (assuming batched_data['batched_reshaped_raw_videos'] is a list of frames)
 
     # --- Process Object Predictions (for overlaying bboxes) ---
     if phase == "unary":
         objs = []
-        for (
             gt_label = clean_label(gt_label)
             objs.append((obj_id, bbox, gt_label))
-
         formatted_cate_preds = format_cate_preds(cate_preds)
-        all_colors, all_texts = color_for_cate_correctness(
         updated_frame_img = plot_unary(frame_img, gt_labels, all_colors, all_texts)
         return updated_frame_img
-
     else:
         # --- Process Binary Predictions & Ground Truth for the Text Pane ---
         formatted_binary_preds = format_binary_cate_preds(binary_preds)
-
         # Ground truth binary relations for the frame.
         # Clean ground truth relations.
-        gt_relations = [
-
         pane_width = 600  # increased pane width for more horizontal space
         pane_height = frame_img.shape[0]
-
         # --- Add header labels to each text pane with extra space ---
         header_height = 50  # increased header space
-        white_pane = get_white_pane(
-
-
-
         return combined_image
 
 def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
     # Ensure mask is a numpy array
     mask = np.array(mask)

@@ -644,7 +790,7 @@ def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
     color = list(cmap((cmap_idx * 47) % 256))
     color[3] = 0.5
     color = np.array(color)
-
     # Expand mask to (H, W, 1) for broadcasting
     mask_expanded = mask[..., None]
     mask_image = mask_expanded * color.reshape(1, 1, -1)

@@ -663,7 +809,7 @@ def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
     linewidth=1.5,
     edgecolor=color[:3],
     facecolor="none",
-    alpha=color[3]
 )
 ax.add_patch(rect)
 ax.text(

@@ -673,10 +819,11 @@ def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
     color="white",
     fontsize=6,
     backgroundcolor=np.array(color),
-    alpha=1
 )
 ax.imshow(mask_image)
 
 def save_mask_one_image(frame_image, masks, save_path):
     """Render masks on top of a frame and store the visualization on disk."""
     fig, ax = plt.subplots(1, figsize=(6, 6))

@@ -695,9 +842,7 @@ def save_mask_one_image(frame_image, masks, save_path):
 
     prepared_masks = {
         obj_id: (
-            mask.detach().cpu().numpy()
-            if torch.is_tensor(mask)
-            else np.asarray(mask)
         )
         for obj_id, mask in mask_iter
     }

@@ -711,54 +856,61 @@ def save_mask_one_image(frame_image, masks, save_path):
     fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
     plt.close(fig)
     return save_path
-
-
-
-
-
-
-
-
|
|
|
|
|
|
|
| 722 |
video_save_dir = os.path.join(video_save_base_dir, video_id)
|
| 723 |
if not os.path.exists(video_save_dir):
|
| 724 |
os.makedirs(video_save_dir, exist_ok=True)
|
| 725 |
-
|
| 726 |
for frame_id, image in enumerate(video_tensor):
|
| 727 |
if frame_id not in video_masks:
|
| 728 |
print("No mask for Frame", frame_id)
|
| 729 |
continue
|
| 730 |
-
|
| 731 |
masks = video_masks[frame_id]
|
| 732 |
save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
|
| 733 |
get_mask_one_image(image, masks, oid_class_pred)
|
| 734 |
|
|
|
|
| 735 |
def get_mask_one_image(frame_image, masks, oid_class_pred=None):
|
| 736 |
# Create a figure and axis
|
| 737 |
fig, ax = plt.subplots(1, figsize=(6, 6))
|
| 738 |
|
| 739 |
# Display the frame image
|
| 740 |
ax.imshow(frame_image)
|
| 741 |
-
ax.axis(
|
| 742 |
|
| 743 |
if type(masks) == list:
|
| 744 |
masks = {i: m for i, m in enumerate(masks)}
|
| 745 |
-
|
| 746 |
# Add the masks
|
| 747 |
for obj_id, mask in masks.items():
|
| 748 |
-
det_class =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 749 |
show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)
|
| 750 |
|
| 751 |
# Show the plot
|
| 752 |
return fig, ax
|
| 753 |
|
|
|
|
| 754 |
def save_video(frames, output_filename, output_fps):
|
| 755 |
-
|
| 756 |
# --- Create a video from all frames ---
|
| 757 |
num_frames = len(frames)
|
| 758 |
frame_h, frame_w = frames.shape[:2]
|
| 759 |
|
| 760 |
# Use a codec supported by VS Code (H.264 via 'avc1').
|
| 761 |
-
fourcc = cv2.VideoWriter_fourcc(*
|
| 762 |
out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))
|
| 763 |
|
| 764 |
print(f"Processing {num_frames} frames...")
|
|
@@ -766,23 +918,26 @@ def save_video(frames, output_filename, output_fps):
|
|
| 766 |
vis_frame = get_visualized_frame(i)
|
| 767 |
out.write(vis_frame)
|
| 768 |
if i % 10 == 0:
|
| 769 |
-
print(f"Processed frame {i+1}/{num_frames}")
|
| 770 |
|
| 771 |
out.release()
|
| 772 |
print(f"Video saved as {output_filename}")
|
| 773 |
-
|
| 774 |
|
| 775 |
def list_depth(lst):
|
| 776 |
"""Calculates the depth of a nested list."""
|
| 777 |
if not (isinstance(lst, list) or isinstance(lst, torch.Tensor)):
|
| 778 |
return 0
|
| 779 |
-
elif (isinstance(lst, torch.Tensor) and lst.shape == torch.Size([])) or (
|
|
|
|
|
|
|
| 780 |
return 1
|
| 781 |
else:
|
| 782 |
return 1 + max(list_depth(item) for item in lst)
|
| 783 |
-
|
|
|
|
| 784 |
def normalize_prompt(points, labels):
|
| 785 |
-
if list_depth(points) == 3:
|
| 786 |
points = torch.stack([p.unsqueeze(0) for p in points])
|
| 787 |
labels = torch.stack([l.unsqueeze(0) for l in labels])
|
| 788 |
return points, labels
|
|
@@ -791,36 +946,56 @@ def normalize_prompt(points, labels):
|
|
| 791 |
def show_box(box, ax, object_id):
|
| 792 |
if len(box) == 0:
|
| 793 |
return
|
| 794 |
-
|
| 795 |
cmap = plt.get_cmap("gist_rainbow")
|
| 796 |
cmap_idx = 0 if object_id is None else object_id
|
| 797 |
color = list(cmap((cmap_idx * 47) % 256))
|
| 798 |
-
|
| 799 |
x0, y0 = box[0], box[1]
|
| 800 |
w, h = box[2] - box[0], box[3] - box[1]
|
| 801 |
-
ax.add_patch(
|
| 802 |
-
|
|
|
|
|
|
|
|
|
|
| 803 |
def show_points(coords, labels, ax, object_id=None, marker_size=375):
|
| 804 |
if len(labels) == 0:
|
| 805 |
return
|
| 806 |
-
|
| 807 |
-
pos_points = coords[labels==1]
|
| 808 |
-
neg_points = coords[labels==0]
|
| 809 |
-
|
| 810 |
cmap = plt.get_cmap("gist_rainbow")
|
| 811 |
cmap_idx = 0 if object_id is None else object_id
|
| 812 |
color = list(cmap((cmap_idx * 47) % 256))
|
| 813 |
-
|
| 814 |
-
ax.scatter(
|
| 815 |
-
|
| 816 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 817 |
def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
|
| 818 |
# Create a figure and axis
|
| 819 |
fig, ax = plt.subplots(1, figsize=(6, 6))
|
| 820 |
|
| 821 |
# Display the frame image
|
| 822 |
ax.imshow(frame_image)
|
| 823 |
-
ax.axis(
|
| 824 |
|
| 825 |
points, labels = normalize_prompt(points, labels)
|
| 826 |
if type(boxes) == torch.Tensor:
|
|
@@ -837,40 +1012,50 @@ def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
|
|
| 837 |
pass
|
| 838 |
else:
|
| 839 |
raise Exception()
|
| 840 |
-
|
| 841 |
for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
|
| 842 |
if not len(point_ls) == 0:
|
| 843 |
show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)
|
| 844 |
-
|
| 845 |
# Show the plot
|
| 846 |
plt.savefig(save_path)
|
| 847 |
plt.close()
|
| 848 |
-
|
| 849 |
-
|
|
|
|
|
|
|
|
|
|
| 850 |
video_save_dir = os.path.join(video_save_base_dir, video_id)
|
| 851 |
if not os.path.exists(video_save_dir):
|
| 852 |
os.makedirs(video_save_dir, exist_ok=True)
|
| 853 |
-
|
| 854 |
for frame_id, image in enumerate(video_tensor):
|
| 855 |
boxes, points, labels = [], [], []
|
| 856 |
-
|
| 857 |
if frame_id in video_boxes:
|
| 858 |
boxes = video_boxes[frame_id]
|
| 859 |
-
|
| 860 |
if frame_id in video_points:
|
| 861 |
points = video_points[frame_id]
|
| 862 |
if frame_id in video_labels:
|
| 863 |
labels = video_labels[frame_id]
|
| 864 |
-
|
| 865 |
save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
|
| 866 |
save_prompts_one_image(image, boxes, points, labels, save_path)
|
| 867 |
-
|
| 868 |
|
| 869 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 870 |
video_save_dir = os.path.join(video_save_base_dir, video_id)
|
| 871 |
if not os.path.exists(video_save_dir):
|
| 872 |
os.makedirs(video_save_dir, exist_ok=True)
|
| 873 |
-
|
| 874 |
for frame_id, image in enumerate(video_tensor):
|
| 875 |
if random.random() > sample_rate:
|
| 876 |
continue
|
|
@@ -880,18 +1065,17 @@ def save_video_masks_visualization(video_tensor, video_masks, video_id, video_sa
|
|
| 880 |
masks = video_masks[frame_id]
|
| 881 |
save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
|
| 882 |
save_mask_one_image(image, masks, save_path)
|
| 883 |
-
|
| 884 |
|
| 885 |
|
| 886 |
-
def get_color(obj_id, cmap_name="gist_rainbow",alpha=0.5):
|
| 887 |
cmap = plt.get_cmap(cmap_name)
|
| 888 |
cmap_idx = 0 if obj_id is None else obj_id
|
| 889 |
color = list(cmap((cmap_idx * 47) % 256))
|
| 890 |
color[3] = 0.5
|
| 891 |
color = np.array(color)
|
| 892 |
return color
|
| 893 |
-
|
| 894 |
-
|
| 895 |
def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
|
| 896 |
return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0)
|
| 897 |
|
|
@@ -906,7 +1090,9 @@ def relation_line(
|
|
| 906 |
"""
|
| 907 |
center1 = _bbox_center(bbox1)
|
| 908 |
center2 = _bbox_center(bbox2)
|
| 909 |
-
if math.isclose(center1[0], center2[0], abs_tol=1e-3) and math.isclose(
|
|
|
|
|
|
|
| 910 |
offset = max(1.0, (bbox2[2] - bbox2[0]) * 0.05)
|
| 911 |
center2 = (center2[0] + offset, center2[1])
|
| 912 |
start = (int(round(center1[0])), int(round(center1[1])))
|
|
@@ -915,57 +1101,68 @@ def relation_line(
|
|
| 915 |
end = (end[0] + 1, end[1])
|
| 916 |
return start, end
|
| 917 |
|
|
|
|
| 918 |
def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
|
| 919 |
# Create a figure and axis
|
| 920 |
fig, ax = plt.subplots(1, figsize=(6, 6))
|
| 921 |
|
| 922 |
# Display the frame image
|
| 923 |
ax.imshow(frame_image)
|
| 924 |
-
ax.axis(
|
| 925 |
-
|
| 926 |
all_objs_to_show = set()
|
| 927 |
all_lines_to_show = []
|
| 928 |
-
|
| 929 |
# print(rel_pred_ls[0])
|
| 930 |
for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
|
| 931 |
-
all_objs_to_show.add(from_obj_id)
|
| 932 |
-
all_objs_to_show.add(to_obj_id)
|
| 933 |
-
|
| 934 |
from_mask = masks[from_obj_id]
|
| 935 |
bbox1 = mask_to_bbox(from_mask)
|
| 936 |
to_mask = masks[to_obj_id]
|
| 937 |
bbox2 = mask_to_bbox(to_mask)
|
| 938 |
-
|
| 939 |
c1, c2 = shortest_line_between_bboxes(bbox1, bbox2)
|
| 940 |
-
|
| 941 |
line_color = get_color(from_obj_id)
|
| 942 |
face_color = get_color(to_obj_id)
|
| 943 |
line = c1, c2, face_color, line_color, rel_text
|
| 944 |
all_lines_to_show.append(line)
|
| 945 |
-
|
| 946 |
masks_to_show = {}
|
| 947 |
for oid in all_objs_to_show:
|
| 948 |
masks_to_show[oid] = masks[oid]
|
| 949 |
-
|
| 950 |
# Add the masks
|
| 951 |
for obj_id, mask in masks_to_show.items():
|
| 952 |
show_mask(mask, ax, obj_id=obj_id, random_color=False)
|
| 953 |
|
| 954 |
-
for (from_pt_x, from_pt_y), (
|
| 955 |
-
|
| 956 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 957 |
mid_pt_x = (from_pt_x + to_pt_x) / 2
|
| 958 |
mid_pt_y = (from_pt_y + to_pt_y) / 2
|
| 959 |
ax.text(
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
|
| 963 |
-
|
| 964 |
-
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
|
| 968 |
-
)
|
| 969 |
-
|
|
|
|
|
|
|
| 970 |
# Show the plot
|
| 971 |
return fig, ax
|
|
|
|
| 54 |
vine_hf/vis_utils.py (updated):

# All rendered frames returned by functions are RGB np.ndarray images suitable for saving or video writing.
########################################################################################


def clean_label(label):
    """Replace underscores and slashes with spaces for uniformity."""
    return label.replace("_", " ").replace("/", " ")


# Should be performed somewhere else I believe
def format_cate_preds(cate_preds):
    # Group object predictions from the model output.

        obj_pred_dict[oid].sort(key=lambda x: x[1], reverse=True)
    return obj_pred_dict


def format_binary_cate_preds(binary_preds):
    frame_binary_preds = []
    for key, score in binary_preds.items():

    frame_binary_preds.sort(key=lambda x: x[3], reverse=True)
    return frame_binary_preds
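The helper above flattens per-frame relation scores into (subject, relation, object, score) tuples sorted by confidence. A minimal sketch of that shape with made-up data, assuming each key identifies a subject/relation/object triple (the exact key layout is elided in this hunk):

# Hypothetical input; the real keys come from the VINE model output.
binary_preds = {("person", "holding", "cup"): 0.91, ("person", "near", "table"): 0.42}
frame_binary_preds = [
    (subj, rel, obj, score) for (subj, rel, obj), score in binary_preds.items()
]
frame_binary_preds.sort(key=lambda x: x[3], reverse=True)  # highest-confidence triples first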
_FONT = cv2.FONT_HERSHEY_SIMPLEX


    return mask_np > 0


def _sanitize_bbox(
    bbox: Union[List[float], Tuple[float, ...], None], width: int, height: int
) -> Optional[Tuple[int, int, int, int]]:
    if bbox is None:
        return None
    if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:

        cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
        text_x = left_x + 4
        text_y = min(bottom_y - baseline - 2, img_h - 1)
        cv2.putText(
            image,
            text,
            (text_x, text_y),
            _FONT,
            font_scale,
            (0, 0, 0),
            thickness,
            cv2.LINE_AA,
        )
        y_cursor = bottom_y
    else:
        for text in lines:

            cv2.rectangle(image, (left_x, top_y), (right_x, bottom_y), bg_color, -1)
            text_x = left_x + 4
            text_y = min(bottom_y - baseline - 2, img_h - 1)
            cv2.putText(
                image,
                text,
                (text_x, text_y),
                _FONT,
                font_scale,
                (0, 0, 0),
                thickness,
                cv2.LINE_AA,
            )
            y_cursor = top_y


    top_y = int(np.clip(cy - th // 2 - baseline - 4, 0, img_h - 1))
    right_x = int(np.clip(left_x + tw + 8, 0, img_w - 1))
    bottom_y = int(np.clip(top_y + th + baseline + 6, 0, img_h - 1))
    cv2.rectangle(
        image, (left_x, top_y), (right_x, bottom_y), _background_color(color), -1
    )
    text_x = left_x + 4
    text_y = min(bottom_y - baseline - 2, img_h - 1)
    cv2.putText(
        image,
        text,
        (text_x, text_y),
        _FONT,
        font_scale,
        (0, 0, 0),
        thickness,
        cv2.LINE_AA,
    )


def _extract_frame_entities(
    store: Union[Dict[int, Dict[int, Any]], List, None], frame_idx: int
) -> Dict[int, Any]:
    if isinstance(store, dict):
        frame_entry = store.get(frame_idx, {})
    elif isinstance(store, list) and 0 <= frame_idx < len(store):

            continue
        color = _object_color_bgr(obj_id)
        alpha = 0.45
        overlay[mask_np] = (1.0 - alpha) * overlay[mask_np] + alpha * np.array(
            color, dtype=np.float32
        )

    annotated = np.clip(overlay, 0, 255).astype(np.uint8)
    frame_h, frame_w = annotated.shape[:2]

    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[
        Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None
    ] = None,
    binary_confidence_threshold: float = 0.0,
) -> Dict[str, List[np.ndarray]]:
    frame_groups: Dict[str, List[np.ndarray]] = {

        base_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        frame_h, frame_w = base_bgr.shape[:2]
        frame_bboxes = _extract_frame_entities(bboxes, frame_idx)
        frame_masks = (
            _extract_frame_entities(masks, frame_idx) if masks is not None else {}
        )

        objects_bgr = base_bgr.copy()
        unary_bgr = base_bgr.copy()

        for obj_id, bbox in bbox_lookup.items():
            title = titles_lookup.get(obj_id)
            unary_lines = unary_lines_lookup.get(obj_id, [])
            _draw_bbox_with_label(
                objects_bgr, bbox, obj_id, title=title, label_position="top"
            )
            _draw_bbox_with_label(
                unary_bgr, bbox, obj_id, title=title, label_position="top"
            )
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                _draw_label_block(
                    unary_bgr,
                    unary_lines,
                    anchor,
                    _object_color_bgr(obj_id),
                    direction=direction,
                )
            _draw_bbox_with_label(
                binary_bgr, bbox, obj_id, title=title, label_position="top"
            )
            _draw_bbox_with_label(
                all_bgr, bbox, obj_id, title=title, label_position="top"
            )
            if unary_lines:
                anchor, direction = _label_anchor_and_direction(bbox, "bottom")
                _draw_label_block(
                    all_bgr,
                    unary_lines,
                    anchor,
                    _object_color_bgr(obj_id),
                    direction=direction,
                )
        # First pass: collect all pairs above threshold and deduplicate bidirectional pairs
        pairs_to_draw = {}  # (min_id, max_id) -> (subj_id, obj_id, prob, relation)

            subj_bbox = bbox_lookup.get(subj_id)
            obj_bbox = bbox_lookup.get(obj_id)
            start, end = relation_line(subj_bbox, obj_bbox)
            color = tuple(
                int(c)
                for c in np.clip(
                    (
                        np.array(_object_color_bgr(subj_id), dtype=np.float32)
                        + np.array(_object_color_bgr(obj_id), dtype=np.float32)
                    )
                    / 2.0,
                    0,
                    255,
                )
            )
            label_text = f"{relation} {prob:.2f}"
            mid_point = (int((start[0] + end[0]) / 2), int((start[1] + end[1]) / 2))
            # Draw arrowed lines showing direction from subject to object (smaller arrow tip)
            cv2.arrowedLine(
                binary_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05
            )
            cv2.arrowedLine(all_bgr, start, end, color, 6, cv2.LINE_AA, tipLength=0.05)
            _draw_centered_label(binary_bgr, label_text, mid_point, color)
            _draw_centered_label(all_bgr, label_text, mid_point, color)
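The relation arrows above reuse the two endpoint object colors by averaging them. A small self-contained sketch of the same blend (the BGR tuples here are made up, not the module's palette):

import numpy as np

subj_color, obj_color = (255, 64, 0), (0, 128, 255)   # hypothetical BGR colors
blend = tuple(
    int(c)
    for c in np.clip(
        (np.array(subj_color, np.float32) + np.array(obj_color, np.float32)) / 2.0,
        0,
        255,
    )
)
print(blend)  # (127, 96, 127): the channel-wise midpoint, clipped to the valid 0-255 range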

    cat_label_lookup: Dict[int, Tuple[str, float]],
    unary_lookup: Dict[int, Dict[int, List[Tuple[float, str]]]],
    binary_lookup: Dict[int, List[Tuple[Tuple[int, int], List[Tuple[float, str]]]]],
    masks: Union[
        Dict[int, Dict[int, Union[np.ndarray, torch.Tensor]]], List, None
    ] = None,
    binary_confidence_threshold: float = 0.0,
) -> List[np.ndarray]:
    return render_vine_frame_sets(

        masks,
        binary_confidence_threshold,
    ).get("all", [])

def color_for_cate_correctness(obj_pred_dict, gt_labels, topk_object):
    all_colors = []
    all_texts = []
    for obj_id, bbox, gt_label in gt_labels:
        preds = obj_pred_dict.get(obj_id, [])
        if len(preds) == 0:
            top1 = "N/A"

            topk_labels = [p[0] for p in preds[:topk_object]]
        # Compare cleaned labels.
        if top1.lower() == gt_label.lower():
            box_color = (0, 255, 0)  # bright green for correct
        elif gt_label.lower() in [p.lower() for p in topk_labels]:
            box_color = (0, 165, 255)  # bright orange for partial match
        else:
            box_color = (0, 0, 255)  # bright red for incorrect

        label_text = f"ID:{obj_id}/P:{top1}/GT:{gt_label}"
        all_colors.append(box_color)
        all_texts.append(label_text)
    return all_colors, all_texts
return all_colors, all_texts
|
| 573 |
|
| 574 |
+
|
| 575 |
def plot_unary(frame_img, gt_labels, all_colors, all_texts):
|
| 576 |
+
for (obj_id, bbox, gt_label), box_color, label_text in zip(
|
| 577 |
+
gt_labels, all_colors, all_texts
|
| 578 |
+
):
|
| 579 |
x1, y1, x2, y2 = map(int, bbox)
|
| 580 |
cv2.rectangle(frame_img, (x1, y1), (x2, y2), color=box_color, thickness=2)
|
| 581 |
+
(tw, th), baseline = cv2.getTextSize(
|
| 582 |
+
label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
|
| 583 |
+
)
|
| 584 |
+
cv2.rectangle(
|
| 585 |
+
frame_img, (x1, y1 - th - baseline - 4), (x1 + tw, y1), box_color, -1
|
| 586 |
+
)
|
| 587 |
+
cv2.putText(
|
| 588 |
+
frame_img,
|
| 589 |
+
label_text,
|
| 590 |
+
(x1, y1 - 2),
|
| 591 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 592 |
+
0.5,
|
| 593 |
+
(0, 0, 0),
|
| 594 |
+
1,
|
| 595 |
+
cv2.LINE_AA,
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
return frame_img
|
| 599 |
|
| 600 |
+
|
| 601 |
+
def get_white_pane(
|
| 602 |
+
pane_height,
|
| 603 |
+
pane_width=600,
|
| 604 |
+
header_height=50,
|
| 605 |
+
header_font=cv2.FONT_HERSHEY_SIMPLEX,
|
| 606 |
+
header_font_scale=0.7,
|
| 607 |
+
header_thickness=2,
|
| 608 |
+
header_color=(0, 0, 0),
|
| 609 |
+
):
|
| 610 |
+
# Create an expanded white pane to display text info.
|
| 611 |
white_pane = 255 * np.ones((pane_height, pane_width, 3), dtype=np.uint8)
|
| 612 |
+
|
| 613 |
# --- Adjust pane split: make predictions column wider (60% vs. 40%) ---
|
| 614 |
left_width = int(pane_width * 0.6)
|
| 615 |
right_width = pane_width - left_width
|
| 616 |
left_pane = white_pane[:, :left_width, :].copy()
|
| 617 |
right_pane = white_pane[:, left_width:, :].copy()
|
| 618 |
+
|
| 619 |
+
cv2.putText(
|
| 620 |
+
left_pane,
|
| 621 |
+
"Binary Predictions",
|
| 622 |
+
(10, header_height - 30),
|
| 623 |
+
header_font,
|
| 624 |
+
header_font_scale,
|
| 625 |
+
header_color,
|
| 626 |
+
header_thickness,
|
| 627 |
+
cv2.LINE_AA,
|
| 628 |
+
)
|
| 629 |
+
cv2.putText(
|
| 630 |
+
right_pane,
|
| 631 |
+
"Ground Truth",
|
| 632 |
+
(10, header_height - 30),
|
| 633 |
+
header_font,
|
| 634 |
+
header_font_scale,
|
| 635 |
+
header_color,
|
| 636 |
+
header_thickness,
|
| 637 |
+
cv2.LINE_AA,
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
return white_pane
|
| 641 |
|
| 642 |
+
|
| 643 |
# This is for ploting binary prediction results with frame-based scene graphs
|
| 644 |
+
def plot_binary_sg(
|
| 645 |
+
frame_img,
|
| 646 |
+
white_pane,
|
| 647 |
+
bin_preds,
|
| 648 |
+
gt_relations,
|
| 649 |
+
topk_binary,
|
| 650 |
+
header_height=50,
|
| 651 |
+
indicator_size=20,
|
| 652 |
+
pane_width=600,
|
| 653 |
+
):
|
| 654 |
+
# Leave vertical space for the headers.
|
| 655 |
line_height = 30 # vertical spacing per line
|
| 656 |
+
x_text = 10 # left margin for text
|
| 657 |
y_text_left = header_height + 10 # starting y for left pane text
|
| 658 |
+
y_text_right = header_height + 10 # starting y for right pane text
|
| 659 |
+
|
| 660 |
# Left section: top-k binary predictions.
|
| 661 |
left_width = int(pane_width * 0.6)
|
| 662 |
right_width = pane_width - left_width
|
| 663 |
left_pane = white_pane[:, :left_width, :].copy()
|
| 664 |
right_pane = white_pane[:, left_width:, :].copy()
|
| 665 |
+
|
| 666 |
+
for subj, pred_rel, obj, score in bin_preds[:topk_binary]:
|
| 667 |
+
correct = any(
|
| 668 |
+
(subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
|
| 669 |
+
for gt in gt_relations
|
| 670 |
+
)
|
| 671 |
indicator_color = (0, 255, 0) if correct else (0, 0, 255)
|
| 672 |
+
cv2.rectangle(
|
| 673 |
+
left_pane,
|
| 674 |
+
(x_text, y_text_left - indicator_size + 5),
|
| 675 |
+
(x_text + indicator_size, y_text_left + 5),
|
| 676 |
+
indicator_color,
|
| 677 |
+
-1,
|
| 678 |
+
)
|
| 679 |
text = f"{subj} - {pred_rel} - {obj} :: {score:.2f}"
|
| 680 |
+
cv2.putText(
|
| 681 |
+
left_pane,
|
| 682 |
+
text,
|
| 683 |
+
(x_text + indicator_size + 5, y_text_left + 5),
|
| 684 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 685 |
+
0.6,
|
| 686 |
+
(0, 0, 0),
|
| 687 |
+
1,
|
| 688 |
+
cv2.LINE_AA,
|
| 689 |
+
)
|
| 690 |
y_text_left += line_height
|
| 691 |
+
|
| 692 |
# Right section: ground truth binary relations.
|
| 693 |
for gt in gt_relations:
|
| 694 |
if len(gt) != 3:
|
| 695 |
continue
|
| 696 |
text = f"{gt[0]} - {gt[2]} - {gt[1]}"
|
| 697 |
+
cv2.putText(
|
| 698 |
+
right_pane,
|
| 699 |
+
text,
|
| 700 |
+
(x_text, y_text_right + 5),
|
| 701 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 702 |
+
0.6,
|
| 703 |
+
(0, 0, 0),
|
| 704 |
+
1,
|
| 705 |
+
cv2.LINE_AA,
|
| 706 |
+
)
|
| 707 |
y_text_right += line_height
|
| 708 |
+
|
| 709 |
# Combine the two text panes and then with the frame image.
|
| 710 |
combined_pane = np.hstack((left_pane, right_pane))
|
| 711 |
combined_image = np.hstack((frame_img, combined_pane))
|
| 712 |
return combined_image
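Note that the ground-truth tuples are ordered (subject, object, relation), so the correctness check compares the predicted relation against gt[2]. A tiny illustration with made-up triples:

gt_relations = [("person", "cup", "holding")]        # (subject, object, relation)
subj, pred_rel, obj = "person", "Holding", "cup"     # one hypothetical prediction
correct = any(
    (subj == gt[0] and pred_rel.lower() == gt[2].lower() and obj == gt[1])
    for gt in gt_relations
)
print(correct)  # True: case-insensitive match on the relation, exact match on the endpoints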


def visualized_frame(
    frame_img,
    bboxes,
    object_ids,
    gt_labels,
    cate_preds,
    binary_preds,
    gt_relations,
    topk_object,
    topk_binary,
    phase="unary",
):
    """Return the combined annotated frame for frame index i as an image (in BGR)."""
    # Get the frame image (assuming batched_data['batched_reshaped_raw_videos'] is a list of frames)

    # --- Process Object Predictions (for overlaying bboxes) ---
    if phase == "unary":
        objs = []
        for (_, f_id, obj_id), bbox, gt_label in zip(object_ids, bboxes, gt_labels):
            gt_label = clean_label(gt_label)
            objs.append((obj_id, bbox, gt_label))

        formatted_cate_preds = format_cate_preds(cate_preds)
        all_colors, all_texts = color_for_cate_correctness(
            formatted_cate_preds, gt_labels, topk_object
        )
        updated_frame_img = plot_unary(frame_img, gt_labels, all_colors, all_texts)
        return updated_frame_img

    else:
        # --- Process Binary Predictions & Ground Truth for the Text Pane ---
        formatted_binary_preds = format_binary_cate_preds(binary_preds)

        # Ground truth binary relations for the frame.
        # Clean ground truth relations.
        gt_relations = [
            (clean_label(str(s)), clean_label(str(o)), clean_label(rel))
            for s, o, rel in gt_relations
        ]

        pane_width = 600  # increased pane width for more horizontal space
        pane_height = frame_img.shape[0]

        # --- Add header labels to each text pane with extra space ---
        header_height = 50  # increased header space
        white_pane = get_white_pane(
            pane_height, pane_width, header_height=header_height
        )

        combined_image = plot_binary_sg(
            frame_img, white_pane, formatted_binary_preds, gt_relations, topk_binary
        )

        return combined_image


def show_mask(mask, ax, obj_id=None, det_class=None, random_color=False):
    # Ensure mask is a numpy array
    mask = np.array(mask)

        color = list(cmap((cmap_idx * 47) % 256))
        color[3] = 0.5
        color = np.array(color)

    # Expand mask to (H, W, 1) for broadcasting
    mask_expanded = mask[..., None]
    mask_image = mask_expanded * color.reshape(1, 1, -1)

            linewidth=1.5,
            edgecolor=color[:3],
            facecolor="none",
            alpha=color[3],
        )
        ax.add_patch(rect)
        ax.text(

            color="white",
            fontsize=6,
            backgroundcolor=np.array(color),
            alpha=1,
        )
    ax.imshow(mask_image)

def save_mask_one_image(frame_image, masks, save_path):
    """Render masks on top of a frame and store the visualization on disk."""
    fig, ax = plt.subplots(1, figsize=(6, 6))


    prepared_masks = {
        obj_id: (
            mask.detach().cpu().numpy() if torch.is_tensor(mask) else np.asarray(mask)
        )
        for obj_id, mask in mask_iter
    }

    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    return save_path


def get_video_masks_visualization(
    video_tensor,
    video_masks,
    video_id,
    video_save_base_dir,
    oid_class_pred=None,
    sample_rate=1,
):
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(video_save_dir):
        os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        if frame_id not in video_masks:
            print("No mask for Frame", frame_id)
            continue

        masks = video_masks[frame_id]
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        get_mask_one_image(image, masks, oid_class_pred)


def get_mask_one_image(frame_image, masks, oid_class_pred=None):
    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Display the frame image
    ax.imshow(frame_image)
    ax.axis("off")

    if type(masks) == list:
        masks = {i: m for i, m in enumerate(masks)}

    # Add the masks
    for obj_id, mask in masks.items():
        det_class = (
            f"{obj_id}. {oid_class_pred[obj_id]}"
            if not oid_class_pred is None
            else None
        )
        show_mask(mask, ax, obj_id=obj_id, det_class=det_class, random_color=False)

    # Show the plot
    return fig, ax
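A minimal usage sketch with synthetic data, assuming masks are boolean H x W arrays keyed by object id (show_mask's elided portion handles the outline and text placement):

import numpy as np

frame = np.zeros((64, 64, 3), dtype=np.uint8)   # dummy frame
mask = np.zeros((64, 64), dtype=bool)
mask[10:30, 10:30] = True                        # one fake object region
fig, ax = get_mask_one_image(frame, {0: mask}, oid_class_pred={0: "cup"})
fig.savefig("mask_preview.jpg", bbox_inches="tight", pad_inches=0)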


def save_video(frames, output_filename, output_fps):
    # --- Create a video from all frames ---
    num_frames = len(frames)
    frame_h, frame_w = frames.shape[:2]

    # Use a codec supported by VS Code (H.264 via 'avc1').
    fourcc = cv2.VideoWriter_fourcc(*"avc1")
    out = cv2.VideoWriter(output_filename, fourcc, output_fps, (frame_w, frame_h))

    print(f"Processing {num_frames} frames...")

        vis_frame = get_visualized_frame(i)
        out.write(vis_frame)
        if i % 10 == 0:
            print(f"Processed frame {i + 1}/{num_frames}")

    out.release()
    print(f"Video saved as {output_filename}")


def list_depth(lst):
    """Calculates the depth of a nested list."""
    if not (isinstance(lst, list) or isinstance(lst, torch.Tensor)):
        return 0
    elif (isinstance(lst, torch.Tensor) and lst.shape == torch.Size([])) or (
        isinstance(lst, list) and len(lst) == 0
    ):
        return 1
    else:
        return 1 + max(list_depth(item) for item in lst)


def normalize_prompt(points, labels):
    if list_depth(points) == 3:
        points = torch.stack([p.unsqueeze(0) for p in points])
        labels = torch.stack([l.unsqueeze(0) for l in labels])
    return points, labels
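A quick illustration of the depth convention these helpers rely on, with made-up prompts; depth-3 point lists gain an extra axis so each object contributes one prompt row:

import torch

# One positive click per object, given as flat (x, y) tensors.
points = [torch.tensor([10.0, 20.0]), torch.tensor([30.0, 40.0])]
labels = [torch.tensor([1]), torch.tensor([1])]
print(list_depth(points))             # 3: list -> (x, y) tensor -> scalar entries
norm_points, norm_labels = normalize_prompt(points, labels)
print(norm_points.shape)              # torch.Size([2, 1, 2]): one prompt row per object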


def show_box(box, ax, object_id):
    if len(box) == 0:
        return

    cmap = plt.get_cmap("gist_rainbow")
    cmap_idx = 0 if object_id is None else object_id
    color = list(cmap((cmap_idx * 47) % 256))

    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0, 0, 0, 0), lw=2)
    )


def show_points(coords, labels, ax, object_id=None, marker_size=375):
    if len(labels) == 0:
        return

    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]

    cmap = plt.get_cmap("gist_rainbow")
    cmap_idx = 0 if object_id is None else object_id
    color = list(cmap((cmap_idx * 47) % 256))

    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="P",
        s=marker_size,
        edgecolor=color,
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="s",
        s=marker_size,
        edgecolor=color,
        linewidth=1.25,
    )

def save_prompts_one_image(frame_image, boxes, points, labels, save_path):
    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Display the frame image
    ax.imshow(frame_image)
    ax.axis("off")

    points, labels = normalize_prompt(points, labels)
    if type(boxes) == torch.Tensor:

        pass
    else:
        raise Exception()

    for object_id, (point_ls, label_ls) in enumerate(zip(points, labels)):
        if not len(point_ls) == 0:
            show_points(point_ls.cpu(), label_ls.cpu(), ax, object_id=object_id)

    # Show the plot
    plt.savefig(save_path)
    plt.close()


def save_video_prompts_visualization(
    video_tensor, video_boxes, video_points, video_labels, video_id, video_save_base_dir
):
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(video_save_dir):
        os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        boxes, points, labels = [], [], []

        if frame_id in video_boxes:
            boxes = video_boxes[frame_id]

        if frame_id in video_points:
            points = video_points[frame_id]
        if frame_id in video_labels:
            labels = video_labels[frame_id]

        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        save_prompts_one_image(image, boxes, points, labels, save_path)


def save_video_masks_visualization(
    video_tensor,
    video_masks,
    video_id,
    video_save_base_dir,
    oid_class_pred=None,
    sample_rate=1,
):
    video_save_dir = os.path.join(video_save_base_dir, video_id)
    if not os.path.exists(video_save_dir):
        os.makedirs(video_save_dir, exist_ok=True)

    for frame_id, image in enumerate(video_tensor):
        if random.random() > sample_rate:
            continue

        masks = video_masks[frame_id]
        save_path = os.path.join(video_save_dir, f"{frame_id}.jpg")
        save_mask_one_image(image, masks, save_path)


def get_color(obj_id, cmap_name="gist_rainbow", alpha=0.5):
    cmap = plt.get_cmap(cmap_name)
    cmap_idx = 0 if obj_id is None else obj_id
    color = list(cmap((cmap_idx * 47) % 256))
    color[3] = 0.5
    color = np.array(color)
    return color


def _bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
    return ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0)

    """
    center1 = _bbox_center(bbox1)
    center2 = _bbox_center(bbox2)
    if math.isclose(center1[0], center2[0], abs_tol=1e-3) and math.isclose(
        center1[1], center2[1], abs_tol=1e-3
    ):
        offset = max(1.0, (bbox2[2] - bbox2[0]) * 0.05)
        center2 = (center2[0] + offset, center2[1])
    start = (int(round(center1[0])), int(round(center1[1])))

        end = (end[0] + 1, end[1])
    return start, end
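The nudge above keeps a relation arrow from collapsing to a point when two boxes share a center. A small check of that edge case (relation_line's full signature sits in an elided hunk, but it takes the two boxes as (x1, y1, x2, y2) tuples):

bbox = (100, 100, 140, 160)            # hypothetical box used for both endpoints
start, end = relation_line(bbox, bbox)
print(start, end)                      # centers coincide, so the end point is offset slightly
assert start != end                    # the drawn arrow always has nonzero length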


def get_binary_mask_one_image(frame_image, masks, rel_pred_ls=None):
    # Create a figure and axis
    fig, ax = plt.subplots(1, figsize=(6, 6))

    # Display the frame image
    ax.imshow(frame_image)
    ax.axis("off")

    all_objs_to_show = set()
    all_lines_to_show = []

    # print(rel_pred_ls[0])
    for (from_obj_id, to_obj_id), rel_text in rel_pred_ls.items():
        all_objs_to_show.add(from_obj_id)
        all_objs_to_show.add(to_obj_id)

        from_mask = masks[from_obj_id]
        bbox1 = mask_to_bbox(from_mask)
        to_mask = masks[to_obj_id]
        bbox2 = mask_to_bbox(to_mask)

        c1, c2 = shortest_line_between_bboxes(bbox1, bbox2)

        line_color = get_color(from_obj_id)
        face_color = get_color(to_obj_id)
        line = c1, c2, face_color, line_color, rel_text
        all_lines_to_show.append(line)

    masks_to_show = {}
    for oid in all_objs_to_show:
        masks_to_show[oid] = masks[oid]

    # Add the masks
    for obj_id, mask in masks_to_show.items():
        show_mask(mask, ax, obj_id=obj_id, random_color=False)

    for (from_pt_x, from_pt_y), (
        to_pt_x,
        to_pt_y,
    ), face_color, line_color, rel_text in all_lines_to_show:
        plt.plot(
            [from_pt_x, to_pt_x],
            [from_pt_y, to_pt_y],
            color=line_color,
            linestyle="-",
            linewidth=3,
        )
        mid_pt_x = (from_pt_x + to_pt_x) / 2
        mid_pt_y = (from_pt_y + to_pt_y) / 2
        ax.text(
            mid_pt_x - 5,
            mid_pt_y,
            rel_text,
            color="white",
            fontsize=6,
            backgroundcolor=np.array(line_color),
            bbox=dict(
                facecolor=face_color, edgecolor=line_color, boxstyle="round,pad=1"
            ),
            alpha=1,
        )

    # Show the plot
    return fig, ax
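A minimal call sketch, assuming two boolean masks and one predicted relation between them (mask_to_bbox and shortest_line_between_bboxes are defined elsewhere in this module):

import numpy as np

frame = np.zeros((64, 64, 3), dtype=np.uint8)
mask_a = np.zeros((64, 64), dtype=bool)
mask_a[5:20, 5:20] = True
mask_b = np.zeros((64, 64), dtype=bool)
mask_b[40:60, 40:60] = True
fig, ax = get_binary_mask_one_image(frame, {0: mask_a, 1: mask_b}, {(0, 1): "next to"})
fig.savefig("relation_preview.jpg")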