
CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

Project Overview

Grasp Any Region (GAR) is a research project for region-level multimodal understanding in vision-language models. It enables:

  1. Single Region Understanding: Detailed description of specific image/video regions via points/boxes/scribbles/masks
  2. Multi-Region Reasoning: Complex relationship modeling and reasoning across multiple regions simultaneously
  3. Advanced Compositional Reasoning: Active dialogue about regions rather than passive description

The model is built on top of Facebook's Perception-LM architecture and uses the xTuner training framework with PyTorch distributed training.

Architecture

Core Components

Model Architecture (projects/grasp_any_region/models/grasp_any_region.py:GraspAnyRegion):

  • Wraps PerceptionLMForConditionalGeneration from HuggingFace
  • Key innovation: RoI-aligned feature replay technique using torchvision.ops.roi_align
  • Adds mask_patch_embedding layer (Conv2d) for region mask encoding
  • Supports 15 visual prompt tokens (<Prompt0> through <Prompt14>) plus <NO_Prompt>
  • Forward pass implements feature replay mechanism at grasp_any_region.py:291-377

Visual Prompt System:

  • Masks are encoded with prompt IDs (0-14) where each ID represents a different region
  • Special value (15 = <NO_Prompt>) indicates background/non-region areas
  • RoI features are extracted using bounding boxes and replayed into the sequence at crop token positions
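As an illustration of the ID scheme above, the sketch below merges several per-region binary masks into one prompt-ID map, with 15 marking background. It is a hypothetical helper, not the repository's code. It assumes that where regions overlap, the later mask wins.

```python
# Hypothetical helper (not from the repository): merge per-region binary masks
# into a single prompt-ID map. Assumption: where regions overlap, the later mask wins.
from typing import List

MAX_PROMPTS = 15   # <Prompt0> .. <Prompt14>
NO_PROMPT_ID = 15  # background value, i.e. <NO_Prompt>


def encode_prompt_ids(masks: List[List[List[int]]]) -> List[List[int]]:
    """Return an H x W map holding a prompt ID (0-14) per pixel, or 15 for background."""
    if not masks:
        raise ValueError("need at least one region mask")
    if len(masks) > MAX_PROMPTS:
        raise ValueError(f"at most {MAX_PROMPTS} regions are supported")
    height, width = len(masks[0]), len(masks[0][0])
    id_map = [[NO_PROMPT_ID] * width for _ in range(height)]
    for prompt_id, mask in enumerate(masks):
        for y in range(height):
            for x in range(width):
                if mask[y][x]:
                    id_map[y][x] = prompt_id
    return id_map


# Two regions on a 2 x 3 image: region 0 top-left, region 1 bottom-right.
print(encode_prompt_ids([[[1, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 1]]]))
# [[0, 0, 15], [15, 1, 1]]
```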

Training Pipeline:

  • Uses xTuner framework (built on MMEngine)
  • Dataset: Arrow format with three subsets (Seed, Fine-Grained, Relation)
  • Custom collate function handles variable-length sequences and multi-region inputs
  • Flash Attention 2 required for efficiency
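The repository's collate function is not reproduced here. This is a minimal sketch of what padding variable-length samples into one batch usually involves. Right-padding, the pad token ID, and the attention_mask field are assumptions. Because the model asserts a per-device batch size of 1, padding like this only matters if that restriction is lifted.

```python
# Hypothetical padding collate (pure-Python sketch, not the repository's function).
from typing import Dict, List

IGNORE_INDEX = -100  # label value excluded from the loss


def pad_collate(batch: List[Dict[str, List[int]]],
                pad_token_id: int) -> Dict[str, List[List[int]]]:
    """Right-pad input_ids/labels to the longest sample and build an attention mask."""
    max_len = max(len(sample["input_ids"]) for sample in batch)
    input_ids, labels, attention_mask = [], [], []
    for sample in batch:
        pad = max_len - len(sample["input_ids"])
        input_ids.append(sample["input_ids"] + [pad_token_id] * pad)
        labels.append(sample["labels"] + [IGNORE_INDEX] * pad)  # padding never contributes to loss
        attention_mask.append([1] * len(sample["input_ids"]) + [0] * pad)
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}


print(pad_collate([{"input_ids": [5, 6, 7], "labels": [-100, 6, 7]},
                   {"input_ids": [8], "labels": [8]}], pad_token_id=0))
```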

Directory Structure

projects/grasp_any_region/     # Main model code
  ├── configs/                 # Training configs (gar_1b.py, gar_8b.py)
  ├── models/
  │   ├── grasp_any_region.py  # Main model wrapper
  │   └── modeling/            # Custom PerceptionLM implementations
  ├── datasets/                # Dataset and data loading
  └── hf_models/               # HuggingFace conversion utilities

demo/                          # Inference demos
  ├── gar_with_mask.py        # Direct mask input
  ├── gar_with_sam.py         # SAM-based region selection
  ├── gar_relationship.py     # Multi-region reasoning
  └── gradio/                 # Web demo

evaluation/                    # Benchmarks
  ├── GAR-Bench/              # Custom benchmark (Caption-Simple, Caption-Detailed, VQA)
  ├── DLC-Bench/              # Detailed localized captioning
  ├── Ferret-Bench/           # Region description
  └── MDVP-Bench/             # Multi-domain visual perception

tools/
  ├── train.py                # Training entry point
  ├── test.py                 # Testing entry point
  └── dist.sh                 # Distributed training launcher

Common Commands

Environment Setup

# Create environment (requires Python 3.11.2)
conda create -n gar python=3.11.2 -y
conda activate gar

# Install dependencies
pip3 install xtuner==0.2.0rc0
pip3 install -r requirements.txt
pip3 install flash-attn==2.7.4.post1 --no-build-isolation -v

Training

# Single-node distributed training (8 GPUs)
bash tools/dist.sh train projects/grasp_any_region/configs/gar_1b.py 8

# The dist.sh script uses torchrun with:
# - Configurable MASTER_ADDR, PORT, NNODES, NODE_RANK
# - DeepSpeed Zero2 by default (set DEEPSPEED env var to override)
# - 5-hour timeout (TORCHELASTIC_TIMEOUT=18000)
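dist.sh itself is not shown in this file. The sketch below is a hypothetical Python equivalent of how such a launcher might assemble its torchrun call from the environment variables listed above. The default values (127.0.0.1, 29500, one node), the `--launcher pytorch` flag, and the `deepspeed_zero2` name are assumptions modeled on common xTuner usage. They are not read from the script.

```python
# Hypothetical re-creation of a dist.sh-style launcher; defaults and the
# --launcher/--deepspeed flags are assumptions, not taken from tools/dist.sh.
import os
import shlex
from typing import Dict, List, Optional


def build_launch_command(mode: str, config: str, gpus: int,
                         env: Optional[Dict[str, str]] = None) -> List[str]:
    """Return the argv a call like `bash tools/dist.sh <mode> <config> <gpus>` might run."""
    env = dict(os.environ) if env is None else env
    cmd = [
        "torchrun",
        f"--nnodes={env.get('NNODES', '1')}",
        f"--node_rank={env.get('NODE_RANK', '0')}",
        f"--master_addr={env.get('MASTER_ADDR', '127.0.0.1')}",
        f"--master_port={env.get('PORT', '29500')}",
        f"--nproc_per_node={gpus}",
        f"tools/{mode}.py",
        config,
        "--launcher", "pytorch",
    ]
    if mode == "train":
        # DeepSpeed ZeRO-2 unless the DEEPSPEED environment variable overrides it.
        cmd += ["--deepspeed", env.get("DEEPSPEED", "deepspeed_zero2")]
    return cmd


print(shlex.join(build_launch_command(
    "train", "projects/grasp_any_region/configs/gar_1b.py", 8, env={})))
```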

Config Files:

  • projects/grasp_any_region/configs/gar_1b.py - 1B model
  • projects/grasp_any_region/configs/gar_8b.py - 8B model

Key training settings (gar_1b.py):

  • Base model: facebook/Perception-LM-1B
  • Batch size: 1 per device × 2 accumulation steps × 32 GPUs = 64 global
  • Learning rate: 1e-5 (AdamW), warmup: 3%, cosine annealing
  • Max length: 16384 tokens
  • Saves every 5000 steps, keeps last 2 checkpoints
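The batch and learning-rate settings above can be checked with a few lines of arithmetic. This is a sketch, not the MMEngine schedulers used by the config. It assumes a linear warmup from near zero over the first 3% of optimizer steps, followed by cosine decay to 0. The real start factor and minimum LR may differ.

```python
# Sketch of the gar_1b.py batch/LR arithmetic. Assumption: linear warmup from ~0,
# then cosine decay to 0; MMEngine's schedulers may use other start/end values.
import math


def global_batch_size(per_device: int, accumulation: int, num_gpus: int) -> int:
    return per_device * accumulation * num_gpus


def lr_at(step: int, total_steps: int, base_lr: float = 1e-5,
          warmup_ratio: float = 0.03) -> float:
    """Learning rate at a 0-indexed optimizer step."""
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


print(global_batch_size(1, 2, 32))  # 64
print(lr_at(0, 10_000), lr_at(300, 10_000), lr_at(10_000, 10_000))
```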

Dataset Preparation

# Download dataset from HuggingFace
hf download HaochenWang/Grasp-Any-Region-Dataset --local-dir data --repo-type dataset

# Expected structure:
# data/
#   ├── Seed-Dataset/data-*.arrow
#   ├── Fine-Grained-Dataset/data-*.arrow
#   └── Relation-Dataset/data-*.arrow

Inference Demos

Single Region with Mask:

torchrun --nproc-per-node=1 --master-port=8119 \
  demo/gar_with_mask.py \
  --image_path assets/demo_image_1.png \
  --mask_path assets/demo_mask_1.png

Single Region with SAM (points or box):

# Using points
torchrun --nproc-per-node=1 --master-port=8119 \
  demo/gar_with_sam.py \
  --image_path assets/demo_image_2.jpg \
  --points '[[1172, 812], [1572, 800]]'

# Using bounding box
torchrun --nproc-per-node=1 --master-port=8119 \
  demo/gar_with_sam.py \
  --image_path assets/demo_image_2.jpg \
  --box '[800, 500, 1800, 1000]' \
  --use_box

Multi-Region Relationship:

torchrun --nproc-per-node=1 --master-port=8119 \
  demo/gar_relationship.py \
  --image_path assets/demo_image_3.png \
  --mask_paths "['assets/demo_mask_3_0.png', 'assets/demo_mask_3_1.png', 'assets/demo_mask_3_2.png']" \
  --question_str 'Question: What is the relationship between <Prompt0>, <Prompt1>, and <Prompt2>?'
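For a different number of regions, the `--mask_paths` literal and the `<PromptN>` enumeration must stay in sync. The helper below is hypothetical. It rebuilds both in the format of the command above. It assumes the demo parses `--mask_paths` as a Python list literal and that the N-th mask corresponds to `<PromptN>`.

```python
# Hypothetical helper that builds demo/gar_relationship.py arguments for N masks.
import shlex
from typing import List, Sequence

MAX_PROMPTS = 15


def prompt_list(n: int) -> str:
    """'<Prompt0>, <Prompt1>, and <Prompt2>'-style enumeration of n region tokens."""
    tokens = [f"<Prompt{i}>" for i in range(n)]
    if n <= 2:
        return " and ".join(tokens)
    return ", ".join(tokens[:-1]) + f", and {tokens[-1]}"


def relationship_args(image_path: str, mask_paths: Sequence[str]) -> List[str]:
    if not 2 <= len(mask_paths) <= MAX_PROMPTS:
        raise ValueError(f"need 2-{MAX_PROMPTS} masks, got {len(mask_paths)}")
    question = f"Question: What is the relationship between {prompt_list(len(mask_paths))}?"
    return [
        "--image_path", image_path,
        "--mask_paths", repr(list(mask_paths)),  # Python list literal, as in the demo command
        "--question_str", question,
    ]


masks = ["assets/demo_mask_3_0.png", "assets/demo_mask_3_1.png", "assets/demo_mask_3_2.png"]
print(shlex.join(["torchrun", "--nproc-per-node=1", "--master-port=8119",
                  "demo/gar_relationship.py",
                  *relationship_args("assets/demo_image_3.png", masks)]))
```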

Gradio Demo:

cd demo/gradio
pip install -r requirements.txt
python app.py

Evaluation

Evaluation runs in two steps: inference → scoring. Most benchmarks are scored by an LLM judge (GPT-4o or Llama); GAR-Bench-VQA is scored automatically.

GARBench-Caption-Simple:

# Inference
torchrun --nproc-per-node=1 --master-port=9811 \
  evaluation/GAR-Bench/inference.py \
  --model_name_or_path HaochenWang/GAR-8B \
  --anno_file evaluation/GAR-Bench/annotations/GAR-Bench-Caption-Simple.json \
  --mode simple \
  --cache_name my_test \
  --data_type bf16 \
  --seed 42

# Evaluation (requires Azure OpenAI)
export AZURE_OPENAI_ENDPOINT=YOUR_ENDPOINT
export AZURE_OPENAI_KEY=YOUR_KEY
python3 evaluation/GAR-Bench/eval_simple.py \
  --pred evaluation/GAR-Bench/model_outputs/my_test_simple.json

GARBench-VQA (multi-region reasoning):

torchrun --nproc-per-node=1 --master-port=9811 \
  evaluation/GAR-Bench/inference.py \
  --model_name_or_path HaochenWang/GAR-8B \
  --anno_file evaluation/GAR-Bench/annotations/GAR-Bench-VQA.json \
  --mode vqa \
  --cache_name my_test \
  --data_type bf16
# VQA evaluation is automatic (no LLM judge)
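Automatic VQA scoring usually reduces to comparing a predicted option against the ground truth. The scorer below is a hypothetical sketch, not the repository's evaluator. The `pred`/`answer` field names and the letter-choice answer format are assumptions about the output JSON.

```python
# Hypothetical multiple-choice scorer; field names and answer format are assumptions.
import re
from typing import Iterable, Mapping, Optional

# Optional "(" then a single letter not followed by another letter, e.g. "B", "(B) ...", "b."
_LEADING_CHOICE = re.compile(r"\s*\(?([A-Za-z])(?![A-Za-z])")


def extract_choice(text: str) -> Optional[str]:
    """Return the option letter a prediction starts with, e.g. '(B) a dog' -> 'B'."""
    match = _LEADING_CHOICE.match(text)
    return match.group(1).upper() if match else None


def accuracy(records: Iterable[Mapping[str, str]]) -> float:
    """Fraction of records whose leading predicted letter equals the answer letter."""
    records = list(records)
    if not records:
        return 0.0
    correct = sum(extract_choice(r["pred"]) == r["answer"].strip().upper() for r in records)
    return correct / len(records)
```

A prediction that opens with a word rather than a letter (for example "Bird") yields no choice and counts as wrong.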

DLC-Bench (detailed localized captioning):

# Download images first
cd evaluation/DLC-Bench/annotations
hf download nvidia/DLC-Bench --repo-type dataset --include "images/*" --local-dir ./
cd ../../..

# Inference
torchrun --nproc-per-node=1 --master-port=8841 \
  evaluation/DLC-Bench/inference.py \
  --model_name_or_path HaochenWang/GAR-8B \
  --cache_name my_test \
  --data_type bf16

# Evaluation with GPT-4o
python3 evaluation/DLC-Bench/eval_gpt_with_image.py \
  --pred evaluation/DLC-Bench/model_outputs/my_test.json

# Alternative: Evaluation with Llama3.1-8B (requires vLLM server)
bash evaluation/DLC-Bench/serve_judge.sh  # in one terminal
python3 evaluation/DLC-Bench/eval_llama_without_image.py \
  --pred evaluation/DLC-Bench/model_outputs/my_test.json \
  --base_url http://localhost:8007/v1

Model Conversion

# Convert trained checkpoint to HuggingFace format
python3 projects/grasp_any_region/hf_models/convert_to_hf.py \
  projects/grasp_any_region/configs/gar_1b.py \
  --pth-model PATH_TO_PTH_MODEL \
  --save-path PATH_TO_SAVE_FOLDER

# Note: After conversion, manually copy the required .py files into the save folder

Key Implementation Details

RoI Feature Replay Mechanism

The core innovation is at grasp_any_region.py:291-377:

  1. Image features are extracted as tiles (16×16 patches per tile)
  2. Tiles are merged into a full-resolution feature map
  3. For each <PromptN> token in input:
    • Extract RoI bounding box from data["bboxes"]
    • Apply torchvision.ops.roi_align to extract 16×16 features
    • Replace prompt tokens in sequence with RoI features
  4. This lets the model attend to region-specific features while retaining global context
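The steps above can be sketched end to end in pure Python, without torch. The sketch covers tile stitching, RoI Align on a normalized box, and writing the pooled vectors into the crop-token positions of the sequence. It is an illustrative reimplementation, not the code at grasp_any_region.py:291-377. Assumptions: tiles arrive in row-major order; boxes are normalized (x1, y1, x2, y2); RoI Align takes one bilinear sample per bin and applies the half-pixel shift that `torchvision.ops.roi_align` uses with `aligned=True`.

```python
# Pure-Python sketch of tile merging, RoI Align and feature replay (not the repository's code).
from typing import List, Sequence

Vec = List[float]
Grid = List[List[Vec]]  # H x W x C feature map


def merge_tiles(tiles: List[Grid], ncw: int, nch: int) -> Grid:
    """Stitch nch x ncw tiles (each P x P x C, row-major) into one (nch*P) x (ncw*P) x C map."""
    p = len(tiles[0])
    merged: Grid = []
    for ty in range(nch):
        for y in range(p):
            row: List[Vec] = []
            for tx in range(ncw):
                row.extend(tiles[ty * ncw + tx][y])
            merged.append(row)
    return merged


def bilinear(fmap: Grid, y: float, x: float) -> Vec:
    """Bilinearly interpolate fmap at continuous (y, x); zeros well outside the map."""
    h, w, c = len(fmap), len(fmap[0]), len(fmap[0][0])
    if y < -1.0 or y > h or x < -1.0 or x > w:
        return [0.0] * c
    y = min(max(y, 0.0), h - 1.0)
    x = min(max(x, 0.0), w - 1.0)
    y0, x0 = int(y), int(x)
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    ly, lx = y - y0, x - x0
    return [
        (1 - ly) * (1 - lx) * fmap[y0][x0][k] + (1 - ly) * lx * fmap[y0][x1][k]
        + ly * (1 - lx) * fmap[y1][x0][k] + ly * lx * fmap[y1][x1][k]
        for k in range(c)
    ]


def roi_align(fmap: Grid, box: Sequence[float], out_size: int = 16) -> Grid:
    """Pool a normalized box into an out_size x out_size x C grid, one sample per bin."""
    h, w = len(fmap), len(fmap[0])
    x1, y1 = box[0] * w - 0.5, box[1] * h - 0.5  # aligned=True half-pixel shift
    x2, y2 = box[2] * w - 0.5, box[3] * h - 0.5
    bin_h, bin_w = (y2 - y1) / out_size, (x2 - x1) / out_size
    return [[bilinear(fmap, y1 + (i + 0.5) * bin_h, x1 + (j + 0.5) * bin_w)
             for j in range(out_size)] for i in range(out_size)]


def replay(seq_embeds: List[Vec], token_ids: List[int], crop_token_id: int,
           roi_feats: Grid) -> List[Vec]:
    """Overwrite each crop-token position, in order, with one pooled RoI feature vector."""
    flat = [vec for row in roi_feats for vec in row]
    positions = [i for i, t in enumerate(token_ids) if t == crop_token_id]
    if len(positions) != len(flat):
        raise ValueError(f"{len(positions)} crop tokens but {len(flat)} RoI features")
    out = list(seq_embeds)
    for pos, vec in zip(positions, flat):
        out[pos] = vec
    return out
```

In this sketch a 16×16 output produces 256 vectors per region, so the sequence must contain exactly 256 crop tokens for that region.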

Mask Encoding

Masks are provided as 3-channel images where pixel values encode prompt IDs:

  • Values 0-14: Different region prompts
  • Value 15 (or prompt_numbers): Background (no prompt)
  • mask_patch_embedding (Conv2d) encodes binary masks into feature space
  • Masks are processed at patch level matching vision encoder stride
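This file does not spell out the exact channel layout or how mask_patch_embedding consumes masks. The hypothetical helpers below make two assumptions: all three channels carry the same ID, and a fixed averaging filter with kernel_size = stride = patch size can stand in for the learned Conv2d. The filter shows what "patch level" means: one value per encoder patch.

```python
# Hypothetical helpers: channel layout and the averaging stand-in for the learned
# Conv2d are assumptions, not the repository's implementation.
from typing import List, Sequence, Tuple

Pixel = Tuple[int, int, int]


def decode_mask_image(pixels: Sequence[Sequence[Pixel]]) -> List[List[int]]:
    """Read prompt IDs from a 3-channel mask, assuming all channels hold the same value."""
    id_map = []
    for row in pixels:
        ids = []
        for r, g, b in row:
            if not r == g == b:
                raise ValueError(f"channels disagree: {(r, g, b)}")
            ids.append(r)
        id_map.append(ids)
    return id_map


def binary_mask(id_map: List[List[int]], prompt_id: int) -> List[List[int]]:
    """1 where the pixel belongs to prompt_id, else 0."""
    return [[1 if v == prompt_id else 0 for v in row] for row in id_map]


def patch_coverage(mask: List[List[int]], patch: int) -> List[List[float]]:
    """Fraction of each patch x patch cell covered by the mask (kernel = stride = patch)."""
    h, w = len(mask), len(mask[0])
    if h % patch or w % patch:
        raise ValueError("mask size must be a multiple of the patch size")
    area = patch * patch
    return [
        [sum(mask[y][x] for y in range(py, py + patch) for x in range(px, px + patch)) / area
         for px in range(0, w, patch)]
        for py in range(0, h, patch)
    ]
```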

Data Format

Dataset uses Arrow format with fields:

  • pixel_values: (num_tiles, 3, H, W) image tiles
  • input_ids: Token sequence with special image/prompt tokens
  • labels: Target sequence (-100 for non-loss positions)
  • global_mask_values: Region masks with prompt IDs
  • aspect_ratios: (ncw, nch) tile arrangement
  • bboxes: Dict mapping crop tokens to normalized bbox coordinates
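To make the bboxes and labels fields concrete, the hypothetical helpers below derive each region's normalized box from a prompt-ID map. Each box is keyed by that region's crop token. Label building copies input_ids and masks the prompt part with -100. The (x1, y1, x2, y2) box order, the exclusive right/bottom edge, and masking by a single prompt length are assumptions.

```python
# Hypothetical helpers: box format and prompt-length label masking are assumptions.
from typing import Dict, List, Optional, Tuple

IGNORE_INDEX = -100
BBox = Tuple[float, float, float, float]


def normalized_bbox(id_map: List[List[int]], prompt_id: int) -> Optional[BBox]:
    """Tight (x1, y1, x2, y2) box of a region, normalized by the map's width and height."""
    h, w = len(id_map), len(id_map[0])
    coords = [(x, y) for y in range(h) for x in range(w) if id_map[y][x] == prompt_id]
    if not coords:
        return None
    xs, ys = [c[0] for c in coords], [c[1] for c in coords]
    # +1 makes the right/bottom edge exclusive, so a full-width region ends at 1.0
    return (min(xs) / w, min(ys) / h, (max(xs) + 1) / w, (max(ys) + 1) / h)


def build_bboxes(id_map: List[List[int]], num_prompts: int) -> Dict[str, BBox]:
    """Key each present region's box by its crop token <|reserved_special_token_{pid+2}|>."""
    boxes: Dict[str, BBox] = {}
    for pid in range(num_prompts):
        box = normalized_bbox(id_map, pid)
        if box is not None:
            boxes[f"<|reserved_special_token_{pid + 2}|>"] = box
    return boxes


def make_labels(input_ids: List[int], prompt_len: int) -> List[int]:
    """Use input_ids as targets, masking the first prompt_len positions out of the loss."""
    return [IGNORE_INDEX] * prompt_len + input_ids[prompt_len:]
```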

Special Tokens

The model extends base tokenizer with:

  • <Prompt0> through <Prompt14>: Region identifiers in text
  • <NO_Prompt>: Background/non-region marker
  • <|reserved_special_token_{pid+2}|>: Internal crop tokens for feature replay
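The token inventory above can be generated programmatically. This is a sketch, not the repository's tokenizer setup. It assumes the crop tokens reuse reserved special tokens that already exist in the base vocabulary, so only the 16 region tokens would need to be added. The commented-out lines use the Hugging Face `add_special_tokens` and `resize_token_embeddings` APIs on objects not created here.

```python
# Sketch of the special-token inventory (not the repository's tokenizer code).
from typing import List

NUM_PROMPTS = 15  # <Prompt0> .. <Prompt14>


def region_tokens() -> List[str]:
    """Text-side region identifiers plus the background marker."""
    return [f"<Prompt{i}>" for i in range(NUM_PROMPTS)] + ["<NO_Prompt>"]


def crop_token(prompt_id: int) -> str:
    """Internal crop token used for feature replay of region prompt_id."""
    if not 0 <= prompt_id < NUM_PROMPTS:
        raise ValueError(f"prompt_id must be in [0, {NUM_PROMPTS - 1}]")
    return f"<|reserved_special_token_{prompt_id + 2}|>"


# With a Hugging Face tokenizer `tok` and model `model` (not created here):
# tok.add_special_tokens({"additional_special_tokens": region_tokens()})
# model.resize_token_embeddings(len(tok))  # only needed if the vocabulary grew
print(len(region_tokens()), crop_token(0), crop_token(14))
```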

Important Notes

  • Flash Attention 2 is required - training will fail without it
  • Python 3.11.2 specifically - later versions may have compatibility issues
  • Batch size of 1 per device only - code asserts batch_size=1 at grasp_any_region.py:270
  • Distributed training required - single-GPU training not well supported
  • DeepSpeed Zero2 - default optimization for memory efficiency
  • torchrun vs torch.distributed.launch - dist.sh tries torchrun first, falls back to launch
  • xTuner framework - all training uses xTuner's runner, not native PyTorch
  • Evaluation randomness - LLM judges have variance even with temperature=0

HuggingFace Models

Pre-trained models available:

  • HaochenWang/GAR-1B - 1 billion parameter model
  • HaochenWang/GAR-8B - 8 billion parameter model

Base architecture:

  • facebook/Perception-LM-1B - Base vision-language model
  • facebook/Perception-LM-8B - Larger variant

Citation

@article{wang2025grasp,
  title={Grasp Any Region: Prompting MLLM to Understand the Dense World},
  author={Haochen Wang and Yuhao Wang and Tao Zhang and Yikang Zhou and Yanwei Li and Jiacong Wang and Ye Tian and Jiahao Meng and Zilong Huang and Guangcan Mai and Anran Wang and Yunhai Tong and Zhuochen Wang and Xiangtai Li and Zhaoxiang Zhang},
  journal={arXiv preprint arXiv:2510.18876},
  year={2025}
}

License

Apache-2.0 License