CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
Project Overview
Grasp Any Region (GAR) is a research project for region-level multimodal understanding in vision-language models. It enables:
- Single Region Understanding: Detailed description of specific image/video regions via points/boxes/scribbles/masks
- Multi-Region Reasoning: Complex relationship modeling and reasoning across multiple regions simultaneously
- Advanced Compositional Reasoning: Active dialogue about regions rather than passive description
The model is built on top of Facebook's Perception-LM architecture and uses xTuner training framework with PyTorch distributed training.
Architecture
Core Components
Model Architecture (projects/grasp_any_region/models/grasp_any_region.py:GraspAnyRegion):
- Wraps `PerceptionLMForConditionalGeneration` from HuggingFace
- Key innovation: RoI-aligned feature replay technique using `torchvision.ops.roi_align`
- Adds a `mask_patch_embedding` layer (Conv2d) for region mask encoding
- Supports 15 visual prompt tokens (`<Prompt0>` through `<Prompt14>`) plus `<NO_Prompt>`
- Forward pass implements the feature replay mechanism at grasp_any_region.py:291-377
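A minimal structural sketch of that wrapper, under assumed sizes (the hidden size, patch size, single-channel mask input, and attribute names other than `mask_patch_embedding` are placeholders; the actual module lives in projects/grasp_any_region/models/grasp_any_region.py):

```python
import torch
import torch.nn as nn


class GraspAnyRegionSketch(nn.Module):
    """Illustrative wrapper: a Perception-LM backbone plus a mask patch embedder."""

    def __init__(self, base_model: nn.Module, hidden_size: int = 2048, patch_size: int = 14):
        super().__init__()
        # Wrapped PerceptionLMForConditionalGeneration instance (loaded elsewhere).
        self.base_model = base_model
        # Encodes a binary region mask into feature space, one feature per vision
        # patch (stride == patch size, matching the vision encoder).
        self.mask_patch_embedding = nn.Conv2d(
            in_channels=1, out_channels=hidden_size,
            kernel_size=patch_size, stride=patch_size, bias=False,
        )

    def embed_mask(self, mask: torch.Tensor) -> torch.Tensor:
        # mask: (B, 1, H, W) binary region mask -> (B, num_patches, hidden_size)
        feats = self.mask_patch_embedding(mask)   # (B, C, H/p, W/p)
        return feats.flatten(2).transpose(1, 2)   # (B, H/p * W/p, C)
```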
Visual Prompt System:
- Masks are encoded with prompt IDs (0-14) where each ID represents a different region
- Special value (15 = `<NO_Prompt>`) indicates background/non-region areas
- RoI features are extracted using bounding boxes and replayed into the sequence at crop token positions
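A minimal sketch of how per-region binary masks could be combined into a single prompt-ID map with this convention (the helper below is illustrative, not taken from the repo's data pipeline):

```python
import numpy as np

NO_PROMPT_ID = 15  # background, corresponds to <NO_Prompt>


def build_prompt_id_map(region_masks: list[np.ndarray]) -> np.ndarray:
    """Combine up to 15 binary masks of shape (H, W) into one prompt-ID map.

    Pixels covered by region i get value i (0-14); uncovered pixels get 15.
    Later regions overwrite earlier ones where masks overlap.
    """
    assert len(region_masks) <= 15, "at most 15 prompts (<Prompt0>..<Prompt14>)"
    h, w = region_masks[0].shape
    prompt_ids = np.full((h, w), NO_PROMPT_ID, dtype=np.uint8)
    for pid, mask in enumerate(region_masks):
        prompt_ids[mask.astype(bool)] = pid
    return prompt_ids
```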
Training Pipeline:
- Uses xTuner framework (built on MMEngine)
- Dataset: Arrow format with three subsets (Seed, Fine-Grained, Relation)
- Custom collate function handles variable-length sequences and multi-region inputs (see the sketch after this list)
- Flash Attention 2 required for efficiency
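A rough sketch of what such a collate function could look like, assuming the field names from the Data Format section below; the padding values and the choice to keep image/mask tensors in Python lists are assumptions, and the real code additionally asserts `batch_size=1`:

```python
import torch


def collate_fn_sketch(batch: list[dict], pad_token_id: int = 0) -> dict:
    """Pad variable-length token sequences; keep per-sample image/mask tensors in lists."""
    max_len = max(len(s["input_ids"]) for s in batch)
    input_ids, labels = [], []
    for s in batch:
        pad = max_len - len(s["input_ids"])
        input_ids.append(torch.cat([s["input_ids"], s["input_ids"].new_full((pad,), pad_token_id)]))
        labels.append(torch.cat([s["labels"], s["labels"].new_full((pad,), -100)]))  # -100 = no loss
    return {
        "input_ids": torch.stack(input_ids),
        "labels": torch.stack(labels),
        # Tile counts and region counts differ per sample, so these stay as lists.
        "pixel_values": [s["pixel_values"] for s in batch],
        "global_mask_values": [s["global_mask_values"] for s in batch],
        "bboxes": [s["bboxes"] for s in batch],
    }
```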
Directory Structure
projects/grasp_any_region/       # Main model code
├── configs/                     # Training configs (gar_1b.py, gar_8b.py)
├── models/
│   ├── grasp_any_region.py      # Main model wrapper
│   └── modeling/                # Custom PerceptionLM implementations
├── datasets/                    # Dataset and data loading
└── hf_models/                   # HuggingFace conversion utilities
demo/                            # Inference demos
├── gar_with_mask.py             # Direct mask input
├── gar_with_sam.py              # SAM-based region selection
├── gar_relationship.py          # Multi-region reasoning
└── gradio/                      # Web demo
evaluation/                      # Benchmarks
├── GAR-Bench/                   # Custom benchmark (Caption-Simple, Caption-Detailed, VQA)
├── DLC-Bench/                   # Detailed localized captioning
├── Ferret-Bench/                # Region description
└── MDVP-Bench/                  # Multi-domain visual perception
tools/
├── train.py                     # Training entry point
├── test.py                      # Testing entry point
└── dist.sh                      # Distributed training launcher
Common Commands
Environment Setup
# Create environment (requires Python 3.11.2)
conda create -n gar python=3.11.2 -y
conda activate gar
# Install dependencies
pip3 install xtuner==0.2.0rc0
pip3 install -r requirements.txt
pip3 install flash-attn==2.7.4.post1 --no-build-isolation -v
Training
# Single-node distributed training (8 GPUs)
bash tools/dist.sh train projects/grasp_any_region/configs/gar_1b.py 8
# The dist.sh script uses torchrun with:
# - Configurable MASTER_ADDR, PORT, NNODES, NODE_RANK
# - DeepSpeed Zero2 by default (set DEEPSPEED env var to override)
# - 5-hour timeout (TORCHELASTIC_TIMEOUT=18000)
Config Files:
- `projects/grasp_any_region/configs/gar_1b.py` - 1B model
- `projects/grasp_any_region/configs/gar_8b.py` - 8B model
Key training settings (gar_1b.py):
- Base model: `facebook/Perception-LM-1B`
- Batch size: 1 per device × 2 accumulation × 32 GPUs = 64 global
- Learning rate: 1e-5 (AdamW), warmup: 3%, cosine annealing
- Max length: 16384 tokens
- Saves every 5000 steps, keeps last 2 checkpoints
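For orientation, a hedged sketch of how the optimizer, schedule, and checkpoint settings above are typically expressed in an xTuner (MMEngine) config; values not listed above (epoch count, weight decay, AMP dtype) are placeholders, not copied from gar_1b.py:

```python
from torch.optim import AdamW
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from mmengine.hooks import CheckpointHook

lr = 1e-5
warmup_ratio = 0.03
max_epochs = 1            # placeholder
batch_size = 1            # per device
accumulative_counts = 2   # gradient accumulation

optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(type=AdamW, lr=lr, weight_decay=0.0),
    accumulative_counts=accumulative_counts,
    dtype='bfloat16',
)

param_scheduler = [
    # Linear warmup over the first 3% of training, then cosine annealing to 0.
    dict(type=LinearLR, start_factor=1e-5, by_epoch=True,
         begin=0, end=warmup_ratio * max_epochs, convert_to_iter_based=True),
    dict(type=CosineAnnealingLR, eta_min=0.0, by_epoch=True,
         begin=warmup_ratio * max_epochs, end=max_epochs, convert_to_iter_based=True),
]

default_hooks = dict(
    # Save every 5000 steps, keep only the last 2 checkpoints.
    checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=5000, max_keep_ckpts=2),
)
```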
Dataset Preparation
# Download dataset from HuggingFace
hf download HaochenWang/Grasp-Any-Region-Dataset --local-dir data --repo-type dataset
# Expected structure:
# data/
# ├── Seed-Dataset/data-*.arrow
# ├── Fine-Grained-Dataset/data-*.arrow
# └── Relation-Dataset/data-*.arrow
Inference Demos
Single Region with Mask:
torchrun --nproc-per-node=1 --master-port=8119 \
demo/gar_with_mask.py \
--image_path assets/demo_image_1.png \
--mask_path assets/demo_mask_1.png
Single Region with SAM (points or box):
# Using points
torchrun --nproc-per-node=1 --master-port=8119 \
demo/gar_with_sam.py \
--image_path assets/demo_image_2.jpg \
--points '[[1172, 812], [1572, 800]]'
# Using bounding box
torchrun --nproc-per-node=1 --master-port=8119 \
demo/gar_with_sam.py \
--image_path assets/demo_image_2.jpg \
--box '[800, 500, 1800, 1000]' \
--use_box
Multi-Region Relationship:
torchrun --nproc-per-node=1 --master-port=8119 \
demo/gar_relationship.py \
--image_path assets/demo_image_3.png \
--mask_paths "['assets/demo_mask_3_0.png', 'assets/demo_mask_3_1.png', 'assets/demo_mask_3_2.png']" \
--question_str 'Question: What is the relationship between <Prompt0>, <Prompt1>, and <Prompt2>?'
Gradio Demo:
cd demo/gradio
pip install -r requirements.txt
python app.py
Evaluation
All evaluation scripts follow the same pattern: inference → evaluation with LLM judge (GPT-4o or Llama).
GAR-Bench-Caption-Simple:
# Inference
torchrun --nproc-per-node=1 --master-port=9811 \
evaluation/GAR-Bench/inference.py \
--model_name_or_path HaochenWang/GAR-8B \
--anno_file evaluation/GAR-Bench/annotations/GAR-Bench-Caption-Simple.json \
--mode simple \
--cache_name my_test \
--data_type bf16 \
--seed 42
# Evaluation (requires Azure OpenAI)
export AZURE_OPENAI_ENDPOINT=YOUR_ENDPOINT
export AZURE_OPENAI_KEY=YOUR_KEY
python3 evaluation/GAR-Bench/eval_simple.py \
--pred evaluation/GAR-Bench/model_outputs/my_test_simple.json
GAR-Bench-VQA (multi-region reasoning):
torchrun --nproc-per-node=1 --master-port=9811 \
evaluation/GAR-Bench/inference.py \
--model_name_or_path HaochenWang/GAR-8B \
--anno_file evaluation/GAR-Bench/annotations/GAR-Bench-VQA.json \
--mode vqa \
--cache_name my_test \
--data_type bf16
# VQA evaluation is automatic (no LLM judge)
DLC-Bench (detailed localized captioning):
# Download images first
cd evaluation/DLC-Bench/annotations
hf download nvidia/DLC-Bench --repo-type dataset --include "images/*" --local-dir ./
cd ../../..
# Inference
torchrun --nproc-per-node=1 --master-port=8841 \
evaluation/DLC-Bench/inference.py \
--model_name_or_path HaochenWang/GAR-8B \
--cache_name my_test \
--data_type bf16
# Evaluation with GPT-4o
python3 evaluation/DLC-Bench/eval_gpt_with_image.py \
--pred evaluation/DLC-Bench/model_outputs/my_test.json
# Alternative: Evaluation with Llama3.1-8B (requires vLLM server)
bash evaluation/DLC-Bench/serve_judge.sh # in one terminal
python3 evaluation/DLC-Bench/eval_llama_without_image.py \
--pred evaluation/DLC-Bench/model_outputs/my_test.json \
--base_url http://localhost:8007/v1
Model Conversion
# Convert trained checkpoint to HuggingFace format
python3 projects/grasp_any_region/hf_models/convert_to_hf.py \
projects/grasp_any_region/configs/gar_1b.py \
--pth-model PATH_TO_PTH_MODEL \
--save-path PATH_TO_SAVE_FOLDER
# Note: Manually copy required .py files to save folder after conversion
Key Implementation Details
RoI Feature Replay Mechanism
The core innovation is at grasp_any_region.py:291-377:
- Image features are extracted as tiles (16×16 patches per tile)
- Tiles are merged into a full-resolution feature map
- For each `<PromptN>` token in the input:
  - Extract the RoI bounding box from `data["bboxes"]`
  - Apply `torchvision.ops.roi_align` to extract 16×16 features
  - Replace the prompt tokens in the sequence with the RoI features
- This allows attending to region-specific features with global context
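A simplified sketch of the replay step under assumed shapes (a single image in the batch, 256 crop tokens per prompt forming a 16×16 grid, normalized xyxy boxes); the actual implementation at grasp_any_region.py:291-377 handles tiling, multiple prompts, and batching details not shown here:

```python
import torch
from torchvision.ops import roi_align


def replay_roi_features(
    feature_map: torch.Tensor,    # (1, C, Hf, Wf) merged full-resolution feature map
    inputs_embeds: torch.Tensor,  # (1, seq_len, C) token embeddings fed to the LLM
    crop_positions: torch.Tensor, # (256,) sequence indices of this prompt's crop tokens
    bbox_norm: torch.Tensor,      # (4,) normalized [x1, y1, x2, y2] for this prompt
) -> torch.Tensor:
    _, c, hf, wf = feature_map.shape
    # Scale the normalized bbox to feature-map coordinates; batch index 0 in column 0.
    scale = torch.tensor([wf, hf, wf, hf], device=feature_map.device, dtype=feature_map.dtype)
    box = torch.cat([
        torch.zeros(1, device=feature_map.device, dtype=feature_map.dtype),
        bbox_norm.to(feature_map) * scale,
    ]).unsqueeze(0)                                                            # (1, 5)
    roi = roi_align(feature_map, box, output_size=(16, 16), aligned=True)      # (1, C, 16, 16)
    roi_tokens = roi.flatten(2).transpose(1, 2)                                # (1, 256, C)
    # Overwrite the crop-token embeddings with the pooled RoI features.
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[0, crop_positions] = roi_tokens[0]
    return inputs_embeds
```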
Mask Encoding
Masks are provided as 3-channel images where pixel values encode prompt IDs:
- Values 0-14: Different region prompts
- Value 15 (or `prompt_numbers`): Background (no prompt)
- `mask_patch_embedding` (Conv2d) encodes binary masks into feature space
- Masks are processed at patch level, matching the vision encoder stride
Data Format
Dataset uses Arrow format with fields:
- `pixel_values`: (num_tiles, 3, H, W) image tiles
- `input_ids`: Token sequence with special image/prompt tokens
- `labels`: Target sequence (-100 for non-loss positions)
- `global_mask_values`: Region masks with prompt IDs
- `aspect_ratios`: (ncw, nch) tile arrangement
- `bboxes`: Dict mapping crop tokens to normalized bbox coordinates
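An illustrative single sample with these fields; all shapes, token IDs, and coordinates below are made up to show the layout, not dumped from the dataset:

```python
import torch

sample = {
    "pixel_values": torch.zeros(4, 3, 448, 448),            # 4 image tiles (tile size illustrative)
    "input_ids": torch.tensor([101, 128002, 128017]),       # dummy token IDs incl. special tokens
    "labels": torch.tensor([-100, -100, 128017]),           # -100 masks out non-answer positions
    "global_mask_values": torch.full((448, 448), 15),       # prompt-ID map, 15 = <NO_Prompt>
    "aspect_ratios": (2, 2),                                 # tiles arranged as 2 x 2
    "bboxes": {"<|reserved_special_token_2|>": [0.1, 0.2, 0.6, 0.8]},  # normalized x1, y1, x2, y2
}
```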
Special Tokens
The model extends base tokenizer with:
- `<Prompt0>` through `<Prompt14>`: Region identifiers in text
- `<NO_Prompt>`: Background/non-region marker
- `<|reserved_special_token_{pid+2}|>`: Internal crop tokens for feature replay
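A small sketch of what that vocabulary extension amounts to on the tokenizer side (assuming the released checkpoints expose a standard HuggingFace tokenizer that already contains these tokens, so this is illustrative only):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HaochenWang/GAR-1B")  # or GAR-8B

# Region prompt tokens plus the background marker.
prompt_tokens = [f"<Prompt{i}>" for i in range(15)] + ["<NO_Prompt>"]
# add_tokens is a no-op for tokens already present in the vocabulary.
num_added = tokenizer.add_tokens(prompt_tokens, special_tokens=True)
print(num_added, tokenizer.convert_tokens_to_ids("<Prompt0>"))
```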
Important Notes
- Flash Attention 2 is required - training will fail without it
- Python 3.11.2 specifically - later versions may have compatibility issues
- Single batch size only - code asserts `batch_size=1` at grasp_any_region.py:270
- Distributed training required - single-GPU training not well supported
- DeepSpeed Zero2 - default optimization for memory efficiency
- torchrun vs torch.distributed.launch - dist.sh tries torchrun first, falls back to launch
- xTuner framework - all training uses xTuner's runner, not native PyTorch
- Evaluation randomness - LLM judges have variance even with temperature=0
HuggingFace Models
Pre-trained models available:
- `HaochenWang/GAR-1B` - 1 billion parameter model
- `HaochenWang/GAR-8B` - 8 billion parameter model
Base architecture:
- `facebook/Perception-LM-1B` - Base vision-language model
- `facebook/Perception-LM-8B` - Larger variant
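A hedged sketch of loading a released checkpoint; the exact Auto class, the need for trust_remote_code (the conversion step copies custom .py files into the save folder), and dtype handling are assumptions, and the scripts under demo/ show the intended entry points:

```python
import torch
from transformers import AutoModel

# Load the released 8B checkpoint in bfloat16 for inference.
model = AutoModel.from_pretrained(
    "HaochenWang/GAR-8B",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()
```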
Citation
@article{wang2025grasp,
title={Grasp Any Region: Prompting MLLM to Understand the Dense World},
author={Haochen Wang and Yuhao Wang and Tao Zhang and Yikang Zhou and Yanwei Li and Jiacong Wang and Ye Tian and Jiahao Meng and Zilong Huang and Guangcan Mai and Anran Wang and Yunhai Tong and Zhuochen Wang and Xiangtai Li and Zhaoxiang Zhang},
journal={arXiv preprint arXiv:2510.18876},
year={2025}
}
License
Apache-2.0 License