# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

**Grasp Any Region (GAR)** is a research project for region-level multimodal understanding in vision-language models. It enables:

1. **Single Region Understanding**: Detailed description of specific image/video regions via points/boxes/scribbles/masks
2. **Multi-Region Reasoning**: Complex relationship modeling and reasoning across multiple regions simultaneously
3. **Advanced Compositional Reasoning**: Active dialogue about regions rather than passive description

The model is built on top of Facebook's Perception-LM architecture and is trained with the xTuner framework using PyTorch distributed training.

## Architecture

### Core Components

**Model Architecture** (`projects/grasp_any_region/models/grasp_any_region.py:GraspAnyRegion`):
- Wraps `PerceptionLMForConditionalGeneration` from HuggingFace
- Key innovation: **RoI-aligned feature replay technique** using `torchvision.ops.roi_align`
- Adds `mask_patch_embedding` layer (Conv2d) for region mask encoding
- Supports 15 visual prompt tokens (`<Prompt0>` through `<Prompt14>`) plus `<NO_Prompt>`
- The forward pass implements the feature-replay mechanism at `grasp_any_region.py:291-377`

**Visual Prompt System**:
- Masks are encoded with prompt IDs (0-14) where each ID represents a different region
- Special value (15 = `<NO_Prompt>`) indicates background/non-region areas
- RoI features are extracted using bounding boxes and replayed into the sequence at crop token positions

**Training Pipeline**:
- Uses xTuner framework (built on MMEngine)
- Dataset: Arrow format with three subsets (Seed, Fine-Grained, Relation)
- Custom collate function handles variable-length sequences and multi-region inputs
- Flash Attention 2 required for efficiency

### Directory Structure

```
projects/grasp_any_region/     # Main model code
  ├── configs/                 # Training configs (gar_1b.py, gar_8b.py)
  ├── models/
  │   ├── grasp_any_region.py  # Main model wrapper
  │   └── modeling/            # Custom PerceptionLM implementations
  ├── datasets/                # Dataset and data loading
  └── hf_models/               # HuggingFace conversion utilities

demo/                          # Inference demos
  ├── gar_with_mask.py         # Direct mask input
  ├── gar_with_sam.py          # SAM-based region selection
  ├── gar_relationship.py      # Multi-region reasoning
  └── gradio/                  # Web demo

evaluation/                    # Benchmarks
  ├── GAR-Bench/               # Custom benchmark (Caption-Simple, Caption-Detailed, VQA)
  ├── DLC-Bench/               # Detailed localized captioning
  ├── Ferret-Bench/            # Region description
  └── MDVP-Bench/              # Multi-domain visual perception

tools/
  ├── train.py                 # Training entry point
  ├── test.py                  # Testing entry point
  └── dist.sh                  # Distributed training launcher
```

## Common Commands

### Environment Setup

```bash
# Create environment (requires Python 3.11.2)
conda create -n gar python=3.11.2 -y
conda activate gar

# Install dependencies
pip3 install xtuner==0.2.0rc0
pip3 install -r requirements.txt
pip3 install flash-attn==2.7.4.post1 --no-build-isolation -v
```

### Training

```bash
# Single-node distributed training (8 GPUs)
bash tools/dist.sh train projects/grasp_any_region/configs/gar_1b.py 8

# The dist.sh script uses torchrun with:
# - Configurable MASTER_ADDR, PORT, NNODES, NODE_RANK
# - DeepSpeed Zero2 by default (set DEEPSPEED env var to override)
# - 5-hour timeout (TORCHELASTIC_TIMEOUT=18000)
```

**Config Files**:
- `projects/grasp_any_region/configs/gar_1b.py` - 1B model
- `projects/grasp_any_region/configs/gar_8b.py` - 8B model

Key training settings (`gar_1b.py`; an illustrative config sketch follows this list):
- Base model: `facebook/Perception-LM-1B`
- Batch size: 1 per device × 2 accumulation × 32 GPUs = 64 global
- Learning rate: 1e-5 (AdamW), warmup: 3%, cosine annealing
- Max length: 16384 tokens
- Saves every 5000 steps, keeps last 2 checkpoints
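
The authoritative values live in `projects/grasp_any_region/configs/gar_1b.py`; the fragment below is only an illustrative MMEngine/xTuner-style sketch of the settings listed above, not a copy of the repo's config (field names and the scheduler layout are assumptions):

```python
# Illustrative sketch only -- not the repository's actual config file.
from torch.optim import AdamW
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR

max_epochs = 1
lr = 1e-5
warmup_ratio = 0.03        # 3% linear warmup
accumulative_counts = 2    # gradient-accumulation steps per device

optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(type=AdamW, lr=lr),
    accumulative_counts=accumulative_counts,
    dtype='bfloat16',
)

param_scheduler = [
    # linear warmup for the first 3% of training, then cosine decay
    dict(type=LinearLR, start_factor=1e-5, by_epoch=True,
         begin=0, end=warmup_ratio * max_epochs, convert_to_iter_based=True),
    dict(type=CosineAnnealingLR, eta_min=0.0, by_epoch=True,
         begin=warmup_ratio * max_epochs, end=max_epochs,
         convert_to_iter_based=True),
]
```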

### Dataset Preparation

```bash
# Download dataset from HuggingFace
hf download HaochenWang/Grasp-Any-Region-Dataset --local-dir data --repo-type dataset

# Expected structure:
# data/
#   ├── Seed-Dataset/data-*.arrow
#   ├── Fine-Grained-Dataset/data-*.arrow
#   └── Relation-Dataset/data-*.arrow
```

### Inference Demos

**Single Region with Mask**:
```bash
torchrun --nproc-per-node=1 --master-port=8119 \
  demo/gar_with_mask.py \
  --image_path assets/demo_image_1.png \
  --mask_path assets/demo_mask_1.png
```

**Single Region with SAM** (points or box):
```bash
# Using points
torchrun --nproc-per-node=1 --master-port=8119 \
  demo/gar_with_sam.py \
  --image_path assets/demo_image_2.jpg \
  --points '[[1172, 812], [1572, 800]]'

# Using bounding box
torchrun --nproc-per-node=1 --master-port=8119 \
  demo/gar_with_sam.py \
  --image_path assets/demo_image_2.jpg \
  --box '[800, 500, 1800, 1000]' \
  --use_box
```

**Multi-Region Relationship**:
```bash
torchrun --nproc-per-node=1 --master-port=8119 \
  demo/gar_relationship.py \
  --image_path assets/demo_image_3.png \
  --mask_paths "['assets/demo_mask_3_0.png', 'assets/demo_mask_3_1.png', 'assets/demo_mask_3_2.png']" \
  --question_str 'Question: What is the relationship between <Prompt0>, <Prompt1>, and <Prompt2>?'
```

**Gradio Demo**:
```bash
cd demo/gradio
pip install -r requirements.txt
python app.py
```

### Evaluation

All evaluation scripts follow the same pattern: inference → evaluation with an LLM judge (GPT-4o or Llama).

**GARBench-Caption-Simple**:
```bash
# Inference
torchrun --nproc-per-node=1 --master-port=9811 \
  evaluation/GAR-Bench/inference.py \
  --model_name_or_path HaochenWang/GAR-8B \
  --anno_file evaluation/GAR-Bench/annotations/GAR-Bench-Caption-Simple.json \
  --mode simple \
  --cache_name my_test \
  --data_type bf16 \
  --seed 42

# Evaluation (requires Azure OpenAI)
export AZURE_OPENAI_ENDPOINT=YOUR_ENDPOINT
export AZURE_OPENAI_KEY=YOUR_KEY
python3 evaluation/GAR-Bench/eval_simple.py \
  --pred evaluation/GAR-Bench/model_outputs/my_test_simple.json
```

**GARBench-VQA** (multi-region reasoning):
```bash
torchrun --nproc-per-node=1 --master-port=9811 \
  evaluation/GAR-Bench/inference.py \
  --model_name_or_path HaochenWang/GAR-8B \
  --anno_file evaluation/GAR-Bench/annotations/GAR-Bench-VQA.json \
  --mode vqa \
  --cache_name my_test \
  --data_type bf16
# VQA evaluation is automatic (no LLM judge)
```

**DLC-Bench** (detailed localized captioning):
```bash
# Download images first
cd evaluation/DLC-Bench/annotations
hf download nvidia/DLC-Bench --repo-type dataset --include "images/*" --local-dir ./
cd ../../..

# Inference
torchrun --nproc-per-node=1 --master-port=8841 \
  evaluation/DLC-Bench/inference.py \
  --model_name_or_path HaochenWang/GAR-8B \
  --cache_name my_test \
  --data_type bf16

# Evaluation with GPT-4o
python3 evaluation/DLC-Bench/eval_gpt_with_image.py \
  --pred evaluation/DLC-Bench/model_outputs/my_test.json

# Alternative: Evaluation with Llama3.1-8B (requires vLLM server)
bash evaluation/DLC-Bench/serve_judge.sh  # in one terminal
python3 evaluation/DLC-Bench/eval_llama_without_image.py \
  --pred evaluation/DLC-Bench/model_outputs/my_test.json \
  --base_url http://localhost:8007/v1
```

### Model Conversion

```bash
# Convert trained checkpoint to HuggingFace format
python3 projects/grasp_any_region/hf_models/convert_to_hf.py \
  projects/grasp_any_region/configs/gar_1b.py \
  --pth-model PATH_TO_PTH_MODEL \
  --save-path PATH_TO_SAVE_FOLDER

# Note: manually copy the required .py files into the save folder after conversion
```

## Key Implementation Details

### RoI Feature Replay Mechanism

The core innovation is at `grasp_any_region.py:291-377`; a code sketch follows the numbered steps:

1. Image features are extracted as tiles (16×16 patches per tile)
2. Tiles are merged into a full-resolution feature map
3. For each `<PromptN>` token in the input:
   - Extract the RoI bounding box from `data["bboxes"]`
   - Apply `torchvision.ops.roi_align` to extract 16×16 features
   - Replace prompt tokens in the sequence with the RoI features
4. This allows attending to region-specific features together with the global context
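
Below is a self-contained illustration of these steps (shapes and variable names are invented, not the repository's code; the 16×16 RoI grid follows the description above):

```python
# Sketch of RoI-aligned feature replay (illustrative, not the repo implementation).
import torch
from torchvision.ops import roi_align

hidden_dim, grid = 1024, 16                       # assumed feature width and RoI grid size
feature_map = torch.randn(1, hidden_dim, 48, 48)  # merged full-resolution feature map (1, C, Hf, Wf)

# Normalized bbox for one <PromptN> region, scaled to feature-map coordinates
x1, y1, x2, y2 = 0.25, 0.10, 0.75, 0.60
boxes = torch.tensor([[0.0, x1 * 48, y1 * 48, x2 * 48, y2 * 48]])  # (batch_idx, x1, y1, x2, y2)

# Pool a fixed 16x16 grid of features from the box
roi_feats = roi_align(feature_map, boxes, output_size=(grid, grid), aligned=True)
roi_tokens = roi_feats.flatten(2).transpose(1, 2)  # (1, 256, C): one token per RoI cell

# Splice the RoI tokens into the embedding sequence at the crop-token positions
inputs_embeds = torch.randn(1, 512, hidden_dim)        # stand-in for the token embeddings
crop_positions = torch.arange(100, 100 + grid * grid)  # wherever the crop tokens sit
inputs_embeds[0, crop_positions] = roi_tokens[0]
```

The `aligned=True` flag and the exact splicing step are guesses at the implementation; the point is that a fixed-size grid of region features is pooled from an arbitrary box and written into the token sequence, so the LLM sees them alongside the global image tokens.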

### Mask Encoding

Masks are provided as 3-channel images where pixel values encode prompt IDs (a sketch follows this list):
- Values 0-14: Different region prompts
- Value 15 (or `prompt_numbers`): Background (no prompt)
- `mask_patch_embedding` (Conv2d) encodes binary masks into feature space
- Masks are processed at patch level matching vision encoder stride
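
A rough sketch of that encoding, assuming the Conv2d kernel and stride equal the vision encoder's patch size (the patch size, image size, and fusion step below are assumptions):

```python
# Illustrative sketch of mask-patch encoding (not the repo implementation).
import torch
import torch.nn as nn

patch_size, hidden_dim = 14, 1024     # assumed patch stride and feature width
no_prompt_id = 15                     # <NO_Prompt>: background / non-region

# Conv2d that embeds a binary region mask at the same stride as the patch embedding
mask_patch_embedding = nn.Conv2d(1, hidden_dim, kernel_size=patch_size, stride=patch_size)

# Prompt-ID mask for one image: values 0-14 mark regions, 15 marks background
prompt_id_mask = torch.full((448, 448), no_prompt_id, dtype=torch.long)
prompt_id_mask[100:220, 80:300] = 0   # pretend <Prompt0> covers this rectangle

for pid in range(15):
    binary = (prompt_id_mask == pid).float()[None, None]   # (1, 1, H, W)
    if binary.sum() == 0:
        continue                                            # prompt unused in this sample
    mask_feats = mask_patch_embedding(binary)               # (1, hidden_dim, H/ps, W/ps)
    # mask_feats would then be fused with the corresponding image patch features
```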

### Data Format

The dataset uses Arrow format with the following fields (a loading sketch follows this list):
- `pixel_values`: (num_tiles, 3, H, W) image tiles
- `input_ids`: Token sequence with special image/prompt tokens
- `labels`: Target sequence (-100 for non-loss positions)
- `global_mask_values`: Region masks with prompt IDs
- `aspect_ratios`: (ncw, nch) tile arrangement
- `bboxes`: Dict mapping crop tokens to normalized bbox coordinates
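
To inspect a sample, the standard `datasets` Arrow loader should work on the shards (whether the repo loads them this way is an assumption):

```python
# Peek at one Arrow shard of the Seed subset.
import glob
from datasets import Dataset

shard = sorted(glob.glob("data/Seed-Dataset/data-*.arrow"))[0]
ds = Dataset.from_file(shard)

sample = ds[0]
print(ds.column_names)            # expect the fields listed above
print(len(sample["input_ids"]))   # token sequence length
print(sample["aspect_ratios"])    # (ncw, nch) tile arrangement
print(sample["bboxes"])           # crop-token -> normalized bbox mapping
```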

### Special Tokens

The model extends the base tokenizer with the following tokens (see the sketch after this list):
- `<Prompt0>` through `<Prompt14>`: Region identifiers in text
- `<NO_Prompt>`: Background/non-region marker
- `<|reserved_special_token_{pid+2}|>`: Internal crop tokens for feature replay
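
A hedged sketch of registering the text-side tokens with a standard `transformers` tokenizer (whether the repo uses `AutoTokenizer` or its own processor class is an assumption):

```python
# Illustrative: adding the prompt tokens to the base tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/Perception-LM-1B")

prompt_tokens = [f"<Prompt{i}>" for i in range(15)] + ["<NO_Prompt>"]
num_added = tokenizer.add_tokens(prompt_tokens, special_tokens=True)
print(num_added, tokenizer.convert_tokens_to_ids("<Prompt0>"))

# The embedding table must grow to cover any newly added tokens:
# model.resize_token_embeddings(len(tokenizer))

# The <|reserved_special_token_N|> slots typically already exist in Llama-3-style
# vocabularies, so the internal crop tokens can reuse them without adding new entries.
```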

## Important Notes

- **Flash Attention 2 is required** - training will fail without it
- **Python 3.11.2 specifically** - later versions may have compatibility issues
- **Single batch size only** - code asserts `batch_size=1` at grasp_any_region.py:270
- **Distributed training required** - single-GPU training not well supported
- **DeepSpeed Zero2** - default optimization for memory efficiency
- **torchrun vs torch.distributed.launch** - dist.sh tries torchrun first, then falls back to `torch.distributed.launch`
- **xTuner framework** - all training uses xTuner's runner, not native PyTorch
- **Evaluation randomness** - LLM judges have variance even with temperature=0

## HuggingFace Models

Pre-trained models available:
- `HaochenWang/GAR-1B` - 1 billion parameter model
- `HaochenWang/GAR-8B` - 8 billion parameter model

Base architecture:
- `facebook/Perception-LM-1B` - Base vision-language model
- `facebook/Perception-LM-8B` - Larger variant

## Citation

```bibtex
@article{wang2025grasp,
  title={Grasp Any Region: Prompting MLLM to Understand the Dense World},
  author={Haochen Wang and Yuhao Wang and Tao Zhang and Yikang Zhou and Yanwei Li and Jiacong Wang and Ye Tian and Jiahao Meng and Zilong Huang and Guangcan Mai and Anran Wang and Yunhai Tong and Zhuochen Wang and Xiangtai Li and Zhaoxiang Zhang},
  journal={arXiv preprint arXiv:2510.18876},
  year={2025}
}
```

## License

Apache-2.0 License