# Source: LASER / vine_hf / example_sam2_masks.py
# (Hugging Face file-viewer header: uploaded by ASethi04, commit f9a6349 "updates", 12.9 kB)
"""
Example demonstrating SAM2 mask generation in VINE HuggingFace interface
This script shows how to use both SAM2-only and Grounding DINO + SAM2
segmentation methods with the VINE model.
"""
import os
import sys
import torch
import numpy as np
from transformers.pipelines import PIPELINE_REGISTRY
# Add the parent directory to the path to import vine_hf
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Optional: either uncomment the line below or set the OPENAI_API_KEY
# environment variable; it is not needed to run this example.
#os.environ['OPENAI_API_KEY'] = 'dummy-key'
from vine_hf import VineConfig, VineModel, VinePipeline
from laser.loading import load_video
def example_sam2_only_segmentation():
    """Run the VINE pipeline on the demo video with SAM2-only segmentation.

    Builds a ``VineConfig`` with ``segmentation_method="sam2"``, registers
    ``VinePipeline`` under the ``"vine-video-understanding"`` task, creates
    the model and pipeline, and then processes ``../demo/videos/output.mp4``
    (resolved relative to this file) if it exists.

    Returns:
        The pipeline output on success; ``None`` when the demo video is
        missing or the pipeline call raises.
    """
    print("=== SAM2-Only Segmentation Example ===")

    # Configuration selecting the SAM2-only segmentation method.
    sam2_config = VineConfig(
        use_hf_repo=True,
        model_repo="video-fm/vine_v0",
        segmentation_method="sam2",  # Use SAM2 only
        target_fps=1,
        debug_visualizations=True,
    )

    # Make the custom pipeline class known to the transformers registry.
    PIPELINE_REGISTRY.register_pipeline(
        "vine-video-understanding",
        pipeline_class=VinePipeline,
        pt_model=VineModel,
        type="multimodal",
    )

    # These look like placeholder locations; point them at local files.
    checkpoint_paths = {
        "sam_config_path": "path/to/your/sam2/sam_config.yaml",
        "sam_checkpoint_path": "path/to/your/sam2/sam_checkpoint.pth",
        "gd_config_path": "path/to/your/groundingdino/config.py",
        "gd_checkpoint_path": "path/to/your/groundingdino/checkpoint.pth",
    }
    pipe = VinePipeline(
        model=VineModel(sam2_config),
        tokenizer=None,
        **checkpoint_paths,
    )

    # Bail out early when the demo video is not available.
    video_path = os.path.join(os.path.dirname(__file__), "../demo/videos/output.mp4")
    if not os.path.exists(video_path):
        print(f"Demo video not found: {video_path}")
        return None

    print(f"Processing video: {video_path}")

    # Keywords (SAM2 will find all objects, then classify them).
    categories = ['human', 'dog', 'frisbee', 'object', 'person', 'animal']
    actions = ['running', 'jumping', 'sitting', 'standing', 'moving', 'static']
    relations = ['behind', 'in front of', 'next to', 'chasing', 'following']
    pairs = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (0, 4)]

    print("Using SAM2 automatic mask generation...")
    print("This will find all objects in the video automatically")

    try:
        # The reporting stays inside the try so that a malformed result is
        # reported through the same failure path as a pipeline error.
        outcome = pipe(
            video_path,
            categorical_keywords=categories,
            unary_keywords=actions,
            binary_keywords=relations,
            object_pairs=pairs,
            segmentation_method="sam2",
            return_top_k=3,
            debug_visualizations=True,
            debug_visualization_path=os.path.join(os.getcwd(), "sam2_debug_masks.png"),
        )
        print("\n✓ SAM2 segmentation completed!")
        print("Results summary:")
        summary = outcome['summary']
        for label, key in (
            ("Objects detected", 'num_objects_detected'),
            ("Top categories", 'top_categories'),
            ("Top actions", 'top_actions'),
        ):
            print(f" {label}: {summary[key]}")
        return outcome
    except Exception as exc:
        print(f"SAM2 segmentation failed: {exc}")
        print("Make sure SAM2 models are properly installed")
        return None
def example_grounding_dino_sam2_segmentation():
    """Run the VINE pipeline with Grounding DINO + SAM2 text-guided masks.

    Builds a ``VineConfig`` with ``segmentation_method="grounding_dino_sam2"``
    and the box/text thresholds, creates the model and pipeline, and then
    processes ``../demo/videos/output.mp4`` (resolved relative to this file)
    if it exists.

    Returns:
        The pipeline output on success; ``None`` when the demo video is
        missing or the pipeline call raises.
    """
    print("\n=== Grounding DINO + SAM2 Segmentation Example ===")

    # Configuration selecting text-guided segmentation.
    gd_config = VineConfig(
        use_hf_repo=True,
        model_repo="video-fm/vine_v0",
        segmentation_method="grounding_dino_sam2",  # Use text-guided segmentation
        box_threshold=0.35,
        text_threshold=0.25,
        target_fps=1,
        debug_visualizations=True,
    )

    # SAM2 and GroundingDINO locations; these look like placeholders.
    checkpoint_paths = {
        "sam_config_path": "path/to/sam2/configs/sam2.1_hiera_b+.yaml",
        "sam_checkpoint_path": "path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
        "gd_config_path": "path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
        "gd_checkpoint_path": "path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
    }
    # NOTE(review): device=0 presumably selects the first accelerator --
    # confirm how VinePipeline treats it on hosts without one.
    pipe = VinePipeline(
        model=VineModel(gd_config),
        tokenizer=None,
        **checkpoint_paths,
        device=0,
    )

    # Bail out early when the demo video is not available.
    video_path = os.path.join(os.path.dirname(__file__), "../demo/videos/output.mp4")
    if not os.path.exists(video_path):
        print(f"Demo video not found: {video_path}")
        return None

    print(f"Processing video: {video_path}")

    # Keywords (Grounding DINO will look specifically for these).
    categories = ['human', 'dog', 'frisbee']  # Specific objects to find
    actions = ['running', 'jumping', 'catching', 'throwing']
    relations = ['behind', 'chasing', 'next to', 'throwing to']
    pairs = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (0, 4)]

    print("Using Grounding DINO + SAM2 text-guided segmentation...")
    print(f"Looking specifically for: {categories}")

    try:
        # The reporting stays inside the try so that a malformed result is
        # reported through the same failure path as a pipeline error.
        outcome = pipe(
            video_path,
            categorical_keywords=categories,
            unary_keywords=actions,
            binary_keywords=relations,
            object_pairs=pairs,
            segmentation_method="grounding_dino_sam2",
            box_threshold=0.35,
            text_threshold=0.25,
            return_top_k=3,
            debug_visualizations=True,
        )
        print("\n✓ Grounding DINO + SAM2 segmentation completed!")
        print("Results summary:")
        summary = outcome['summary']
        for label, key in (
            ("Objects detected", 'num_objects_detected'),
            ("Top categories", 'top_categories'),
            ("Top actions", 'top_actions'),
            ("Top relations", 'top_relations'),
        ):
            print(f" {label}: {summary[key]}")
        return outcome
    except Exception as exc:
        print(f"Grounding DINO + SAM2 segmentation failed: {exc}")
        print("Make sure both Grounding DINO and SAM2 models are properly installed")
        return None
def compare_segmentation_methods():
    """Print a pros/cons comparison of the two segmentation approaches.

    Writes a banner followed by one section per approach (SAM2-only and
    Grounding DINO + SAM2) to standard output. Returns ``None``.
    """
    # Each entry is (section heading, bullet lines printed beneath it).
    sections = (
        (
            "SAM2-Only Approach:",
            (
                "✓ Finds all objects automatically",
                "✓ No need to specify what to look for",
                "✓ Good for exploratory analysis",
                "✗ May find too many irrelevant objects",
                "✗ Less precise for specific object types",
            ),
        ),
        (
            "Grounding DINO + SAM2 Approach:",
            (
                "✓ Finds specific objects based on text prompts",
                "✓ More precise and targeted",
                "✓ Better for known object categories",
                "✓ Integrates object detection with segmentation",
                "✗ Limited to specified categories",
                "✗ Requires knowing what objects to look for",
            ),
        ),
    )

    print("\n=== Comparing Segmentation Methods ===")
    for heading, bullets in sections:
        # A blank line separates each section from what precedes it.
        print("\n" + heading)
        for bullet in bullets:
            print(bullet)
def demonstrate_mask_processing():
    """Show the pipeline's preprocessing output for the first two frames.

    Loads ``../demo/videos/output.mp4`` (resolved relative to this file) at
    1 fps, builds a SAM2-configured pipeline, calls its ``preprocess`` method
    directly on the first two frames, and prints the per-frame masks and
    bounding boxes it returns. Prints a notice and does nothing else when
    the demo video is missing. Returns ``None``.
    """
    print("\n=== Mask Processing Demonstration ===")

    video_path = os.path.join(os.path.dirname(__file__), "../demo/videos/output.mp4")
    if not os.path.exists(video_path):
        print(f"Demo video not found: {video_path}")
        return

    print("Loading video for mask processing demo...")

    # Materialise the loaded frames as an array so they can be sliced below.
    frames = np.asarray(load_video(video_path, target_fps=1))
    print(f"Video shape: {frames.shape}")

    # SAM2 and GroundingDINO locations; these look like placeholders.
    checkpoint_paths = {
        "sam_config_path": "path/to/sam2/configs/sam2.1_hiera_b+.yaml",
        "sam_checkpoint_path": "path/to/sam2/checkpoints/sam2.1_hiera_base_plus.pt",
        "gd_config_path": "path/to/GroundingDINO/config/GroundingDINO_SwinT_OGC.py",
        "gd_checkpoint_path": "path/to/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
    }
    model = VineModel(VineConfig(segmentation_method="sam2"))
    pipe = VinePipeline(model=model, tokenizer=None, **checkpoint_paths)

    try:
        print("\nProcessing first 2 frames with SAM2...")

        # Call the preprocessing step by hand to expose its intermediate output.
        prepared = pipe.preprocess(
            frames[:2],  # Just first 2 frames
            segmentation_method="sam2",
            categorical_keywords=['object'],
        )

        print("Mask processing results:")
        print(f" Number of frames processed: {prepared['num_frames']}")
        masks_by_frame = prepared['masks']
        print(f" Frames with masks: {list(masks_by_frame.keys())}")

        # Per-frame, per-object mask details.
        for frame_idx, object_masks in masks_by_frame.items():
            print(f" Frame {frame_idx}: {len(object_masks)} objects detected")
            for object_idx, object_mask in object_masks.items():
                print(f" Object {object_idx}: mask shape {object_mask.shape}")

        print("\nBounding box extraction:")
        for frame_idx, object_boxes in prepared['bboxes'].items():
            print(f" Frame {frame_idx}: {len(object_boxes)} bounding boxes")
            for object_idx, box in object_boxes.items():
                print(f" Object {object_idx}: bbox {box}")
    except Exception as exc:
        print(f"Mask processing failed: {exc}")
        print("This is expected if SAM2 models are not properly set up")
def test_mask_formats():
    """Exercise a few mask representations and bounding-box extraction.

    Creates a random boolean NumPy mask, converts it to a PyTorch tensor,
    adds a trailing singleton dimension, prints the shape/dtype of each, and
    finally passes the 2-D tensor mask to ``mask_to_bbox``. Returns ``None``.
    """
    print("\n=== Testing Mask Formats ===")

    # Dummy mask dimensions.
    mask_shape = (224, 224)

    print("Testing mask format conversions...")

    # Format 1: NumPy boolean array.
    numpy_mask = np.random.rand(*mask_shape) > 0.5
    print(f"NumPy mask: {numpy_mask.shape}, dtype: {numpy_mask.dtype}")

    # Format 2: PyTorch tensor created from the NumPy array.
    tensor_mask = torch.from_numpy(numpy_mask)
    print(f"PyTorch mask: {tensor_mask.shape}, dtype: {tensor_mask.dtype}")

    # Format 3: 3D mask with a trailing singleton dimension.
    tensor_mask_3d = tensor_mask[..., None]
    print(f"3D mask: {tensor_mask_3d.shape}")

    # Imported here, after the format prints, exactly where the demo needs it.
    from laser.preprocess.mask_generation_grounding_dino import mask_to_bbox

    try:
        box = mask_to_bbox(tensor_mask)
        print(f"Extracted bbox: {box}")
        print("✓ Mask format testing successful")
    except Exception as exc:
        print(f"Mask format testing failed: {exc}")
if __name__ == "__main__":
    # Script entry point: run every example in turn. Each call that can fail
    # is isolated so that one failing example does not stop the others.
    print("VINE SAM2 Mask Generation Examples")
    print("=" * 50)

    # SAM2-only approach.
    try:
        sam2_results = example_sam2_only_segmentation()
    except Exception as exc:
        print(f"SAM2-only example failed: {exc}")

    # Grounding DINO + SAM2 approach.
    try:
        gd_sam2_results = example_grounding_dino_sam2_segmentation()
    except Exception as exc:
        print(f"Grounding DINO + SAM2 example failed: {exc}")

    # Side-by-side comparison (only prints, so it is not guarded).
    compare_segmentation_methods()

    # Mask processing walkthrough.
    try:
        demonstrate_mask_processing()
    except Exception as exc:
        print(f"Mask processing demo failed: {exc}")

    # Mask format checks.
    try:
        test_mask_formats()
    except Exception as exc:
        print(f"Mask format testing failed: {exc}")

    print("\n" + "=" * 50)
    print("Examples completed!")
    print("\nKey takeaways:")
    print(
        "\n".join(
            (
                "1. SAM2-only: Automatic object detection and segmentation",
                "2. Grounding DINO + SAM2: Text-guided object detection and segmentation",
                "3. Both methods provide masks and bounding boxes for VINE model",
                "4. Choose method based on whether you know what objects to look for",
            )
        )
    )