# *************************************************************************
# Grasp Any Region (GAR) - Gradio Demo
# Region-level Multimodal Understanding for Vision-Language Models
# *************************************************************************
# 🚨 CRITICAL: Import spaces FIRST before any CUDA-related packages
import spaces
# Now import CUDA-related packages
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import (
AutoModel,
AutoProcessor,
GenerationConfig,
SamModel,
SamProcessor,
)
import cv2
import sys
import os
# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
try:
from evaluation.eval_dataset import SingleRegionCaptionDataset
except ImportError:
print("Warning: Could not import SingleRegionCaptionDataset. Using simplified version.")
SingleRegionCaptionDataset = None
# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Global model variables (loaded once)
gar_model = None
gar_processor = None
sam_model = None
sam_processor = None
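# Populated lazily by load_models() so repeated @spaces.GPU calls reuse the same
# weights instead of reloading them for every request.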
def load_models():
"""Load models once at startup"""
global gar_model, gar_processor, sam_model, sam_processor
if gar_model is None:
print("Loading GAR model...")
model_path = "HaochenWang/GAR-1B"
gar_model = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="auto",
).eval()
gar_processor = AutoProcessor.from_pretrained(
model_path,
trust_remote_code=True,
)
print("GAR model loaded successfully!")
if sam_model is None:
print("Loading SAM model...")
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
print("SAM model loaded successfully!")
@spaces.GPU(duration=120)
def generate_mask_from_points(image, points_str):
"""Generate mask using SAM from point coordinates"""
try:
load_models()
if not points_str or points_str.strip() == "":
return None, "Please provide points in format: x1,y1;x2,y2"
# Parse points
points = []
labels = []
for point in points_str.split(';'):
point = point.strip()
if point:
x, y = map(float, point.split(','))
points.append([x, y])
labels.append(1) # Foreground point
if not points:
return None, "No valid points provided"
# Apply SAM
inputs = sam_processor(
image,
input_points=[points],
input_labels=[labels],
return_tensors="pt",
).to(device)
with torch.no_grad():
outputs = sam_model(**inputs)
masks = sam_processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu(),
)[0][0]
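# SAM predicts three candidate masks per prompt; keep the one with the highest
# predicted IoU score.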
scores = outputs.iou_scores[0, 0]
mask_selection_index = scores.argmax()
mask_np = masks[mask_selection_index].numpy()
# Visualize mask
mask_img = (mask_np * 255).astype(np.uint8)
return Image.fromarray(mask_img), "Mask generated successfully!"
except Exception as e:
return None, f"Error generating mask: {str(e)}"
@spaces.GPU(duration=120)
def generate_mask_from_box(image, box_str):
"""Generate mask using SAM from bounding box"""
try:
load_models()
if not box_str or box_str.strip() == "":
return None, "Please provide box in format: x1,y1,x2,y2"
# Parse box
box = list(map(float, box_str.split(',')))
if len(box) != 4:
return None, "Box must have 4 coordinates: x1,y1,x2,y2"
# Apply SAM
inputs = sam_processor(
image,
input_boxes=[[box]],
return_tensors="pt",
).to(device)
with torch.no_grad():
outputs = sam_model(**inputs)
masks = sam_processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu(),
)[0][0]
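# Same as the point-based path: keep the candidate mask with the highest
# predicted IoU score.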
scores = outputs.iou_scores[0, 0]
mask_selection_index = scores.argmax()
mask_np = masks[mask_selection_index].numpy()
# Visualize mask
mask_img = (mask_np * 255).astype(np.uint8)
return Image.fromarray(mask_img), "Mask generated successfully!"
except Exception as e:
return None, f"Error generating mask: {str(e)}"
@spaces.GPU(duration=120)
def describe_region(image, mask):
"""Generate description for a region defined by a mask"""
try:
load_models()
if image is None:
return "Please provide an image"
if mask is None:
return "Please provide a mask (upload or generate using SAM)"
# Convert mask to numpy
if isinstance(mask, Image.Image):
mask_np = np.array(mask.convert("L"))
else:
mask_np = np.array(mask)
# Ensure mask is binary
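# (threshold at 127 so hard 0/255 masks and soft grayscale masks both work)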
mask_np = (mask_np > 127).astype(np.uint8)
# Prepare data
prompt_number = gar_model.config.prompt_numbers
prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + ["<NO_Prompt>"]
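# These are the special visual-prompt tokens the GAR processor expects; the count
# comes from the checkpoint's config, and <NO_Prompt> presumably covers the
# no-region case.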
if SingleRegionCaptionDataset is not None:
dataset = SingleRegionCaptionDataset(
image=image,
mask=mask_np,
processor=gar_processor,
prompt_number=prompt_number,
visual_prompt_tokens=prompt_tokens,
data_dtype=torch.bfloat16,
)
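# The dataset packs image, mask, and prompt tokens into the tensor dict
# (floats in bfloat16, per data_dtype) that is unpacked into gar_model.generate below.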
data_sample = dataset[0]
else:
# Fallback: without SingleRegionCaptionDataset there is no way to build the
# model inputs here, so report the problem instead of guessing
return "Error: SingleRegionCaptionDataset not available. Please check installation."
# Generate description
with torch.no_grad():
generate_ids = gar_model.generate(
**data_sample,
generation_config=GenerationConfig(
max_new_tokens=1024,
do_sample=False,
eos_token_id=gar_processor.tokenizer.eos_token_id,
pad_token_id=gar_processor.tokenizer.pad_token_id,
),
return_dict=True,
)
output_caption = gar_processor.tokenizer.decode(
generate_ids.sequences[0], skip_special_tokens=True
).strip()
return output_caption
except Exception as e:
return f"Error generating description: {str(e)}"
def create_visualization(image, mask, points_str=None, box_str=None):
"""Create visualization with mask overlay"""
try:
if image is None or mask is None:
return None
img_np = np.array(image.convert("RGB")).astype(float) / 255.0
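# Work on a float image in [0, 1] (RGB order from PIL), so the OpenCV drawing
# calls below take colors as float triples, e.g. (1.0, 1.0, 0.0) for yellow.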
if isinstance(mask, Image.Image):
mask_np = np.array(mask.convert("L")) > 127
else:
mask_np = np.array(mask) > 127
# Draw contour
mask_uint8 = mask_np.astype(np.uint8) * 255
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
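# RETR_EXTERNAL keeps only the outermost contours, so just the region outline is
# drawn and any holes inside the mask are ignored.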
img_vis = img_np.copy()
cv2.drawContours(img_vis, contours, -1, (1.0, 1.0, 0.0), thickness=3)
# Draw points if provided
if points_str:
for point in points_str.split(';'):
point = point.strip()
if point:
x, y = map(float, point.split(','))
cv2.circle(img_vis, (int(x), int(y)), radius=8, color=(1.0, 0.0, 0.0), thickness=-1)
cv2.circle(img_vis, (int(x), int(y)), radius=8, color=(1.0, 1.0, 1.0), thickness=2)
# Draw box if provided
if box_str:
coords = list(map(float, box_str.split(',')))
if len(coords) == 4:
x1, y1, x2, y2 = map(int, coords)
cv2.rectangle(img_vis, (x1, y1), (x2, y2), color=(1.0, 1.0, 1.0), thickness=3)
cv2.rectangle(img_vis, (x1, y1), (x2, y2), color=(1.0, 0.0, 0.0), thickness=1)
img_pil = Image.fromarray((img_vis * 255.0).astype(np.uint8))
return img_pil
except Exception as e:
print(f"Error creating visualization: {str(e)}")
return None
# Create Gradio interface
with gr.Blocks(title="Grasp Any Region (GAR) Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🎯 Grasp Any Region (GAR)
**Region-level Multimodal Understanding for Vision-Language Models**
This demo showcases GAR's ability to understand and describe specific regions in images:
- 🎨 **Single Region Understanding**: Describe specific areas using points, boxes, or masks
- 🔍 **SAM Integration**: Generate masks interactively using the Segment Anything Model (SAM)
- 💡 **Detailed Descriptions**: Get comprehensive descriptions of any region
Built on top of Perception-LM with a RoI-aligned feature replay technique.
📄 [Paper](https://arxiv.org/abs/2510.18876) | 💻 [GitHub](https://github.com/Haochen-Wang409/Grasp-Any-Region) | 🤗 [Model](https://huggingface.co/HaochenWang/GAR-1B)
""")
with gr.Tabs():
# Tab 1: Points-based segmentation
with gr.Tab("🎯 Points β†’ Describe"):
gr.Markdown("### Click points on the image or enter coordinates to segment and describe a region")
with gr.Row():
with gr.Column():
img_points = gr.Image(label="Input Image", type="pil")
points_input = gr.Textbox(
label="Points (format: x1,y1;x2,y2;...)",
placeholder="e.g., 1172,812;1572,800",
value="1172,812;1572,800"
)
with gr.Row():
gen_mask_points_btn = gr.Button("Generate Mask", variant="primary")
describe_points_btn = gr.Button("Describe Region", variant="secondary")
with gr.Column():
mask_points = gr.Image(label="Generated Mask", type="pil")
vis_points = gr.Image(label="Visualization")
desc_points = gr.Textbox(label="Region Description", lines=5)
points_status = gr.Textbox(label="Status", visible=False)
gen_mask_points_btn.click(
fn=generate_mask_from_points,
inputs=[img_points, points_input],
outputs=[mask_points, points_status]
)
describe_points_btn.click(
fn=describe_region,
inputs=[img_points, mask_points],
outputs=desc_points
).then(
fn=create_visualization,
inputs=[img_points, mask_points, points_input, gr.Textbox(visible=False)],
outputs=vis_points
)
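# The inline gr.Textbox(visible=False) above is just a hidden placeholder that
# feeds an empty value for the unused box_str argument of create_visualization.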
gr.Examples(
examples=[
["assets/demo_image_2.jpg", "1172,812;1572,800"],
],
inputs=[img_points, points_input],
label="Example Images"
)
# Tab 2: Box-based segmentation
with gr.Tab("πŸ“¦ Box β†’ Describe"):
gr.Markdown("### Draw a bounding box or enter coordinates to segment and describe a region")
with gr.Row():
with gr.Column():
img_box = gr.Image(label="Input Image", type="pil")
box_input = gr.Textbox(
label="Bounding Box (format: x1,y1,x2,y2)",
placeholder="e.g., 800,500,1800,1000",
value="800,500,1800,1000"
)
with gr.Row():
gen_mask_box_btn = gr.Button("Generate Mask", variant="primary")
describe_box_btn = gr.Button("Describe Region", variant="secondary")
with gr.Column():
mask_box = gr.Image(label="Generated Mask", type="pil")
vis_box = gr.Image(label="Visualization")
desc_box = gr.Textbox(label="Region Description", lines=5)
box_status = gr.Textbox(label="Status", visible=False)
gen_mask_box_btn.click(
fn=generate_mask_from_box,
inputs=[img_box, box_input],
outputs=[mask_box, box_status]
)
describe_box_btn.click(
fn=describe_region,
inputs=[img_box, mask_box],
outputs=desc_box
).then(
fn=create_visualization,
inputs=[img_box, mask_box, gr.Textbox(visible=False), box_input],
outputs=vis_box
)
gr.Examples(
examples=[
["assets/demo_image_2.jpg", "800,500,1800,1000"],
],
inputs=[img_box, box_input],
label="Example Images"
)
# Tab 3: Direct mask upload
with gr.Tab("🎭 Mask β†’ Describe"):
gr.Markdown("### Upload a pre-made mask to describe a region")
with gr.Row():
with gr.Column():
img_mask = gr.Image(label="Input Image", type="pil")
mask_upload = gr.Image(label="Upload Mask", type="pil")
describe_mask_btn = gr.Button("Describe Region", variant="primary")
with gr.Column():
vis_mask = gr.Image(label="Visualization")
desc_mask = gr.Textbox(label="Region Description", lines=5)
describe_mask_btn.click(
fn=describe_region,
inputs=[img_mask, mask_upload],
outputs=desc_mask
).then(
fn=create_visualization,
inputs=[img_mask, mask_upload, gr.Textbox(visible=False), gr.Textbox(visible=False)],
outputs=vis_mask
)
gr.Examples(
examples=[
["assets/demo_image_1.png", "assets/demo_mask_1.png"],
],
inputs=[img_mask, mask_upload],
label="Example Images"
)
gr.Markdown("""
---
### 📖 How to Use:
1. **Points → Describe**: Enter point coordinates, generate a mask, then describe the region
2. **Box → Describe**: Enter a bounding box, generate a mask, then describe the region
3. **Mask → Describe**: Upload a pre-made mask and describe the region directly
### 🔧 Technical Details:
- **Model**: GAR-1B (1 billion parameters)
- **Base**: Facebook Perception-LM with RoI-aligned feature replay
- **Segmentation**: Segment Anything Model (SAM ViT-Huge)
- **Hardware**: Powered by ZeroGPU (NVIDIA H200, 70GB VRAM)
### 📚 Citation:
```bibtex
@article{wang2025grasp,
title={Grasp Any Region: Prompting MLLM to Understand the Dense World},
author={Haochen Wang and others},
journal={arXiv preprint arXiv:2510.18876},
year={2025}
}
```
""")
# Load models on startup
try:
load_models()
except Exception as e:
print(f"Warning: Could not pre-load models: {e}")
print("Models will be loaded on first use.")
if __name__ == "__main__":
demo.launch()