---
license: apache-2.0
base_model:
- Qwen/Qwen2-VL-2B-Instruct
language:
- en
---
<div align="center">
<h1>
  MedVLM-R1
</h1>
</div>

<div align="center">
 <a href="https://arxiv.org/abs/2502.19634" target="_blank">Paper</a>
</div>

<div align="center">
 <a href="https://github.com/JZPeterPan/MedVLM-R1" target="_blank">Code</a>
</div>

# <span id="Start">Introduction</span>
MedVLM-R1 is a medical Vision-Language Model built upon [Qwen2-VL-2B](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) and fine-tuned using the [GRPO](https://arxiv.org/abs/2402.03300) reinforcement learning framework. Trained on 600 MRI VQA samples from the [HuatuoGPT-Vision dataset](https://huggingface.co/datasets/FreedomIntelligence/Medical_Multimodal_Evaluation_Data), MedVLM-R1 excels in out-of-distribution performance on CT and X-ray VQA tasks. It also demonstrates explicit medical reasoning capabilities beyond merely providing final answers, ensuring greater interpretability and trustworthiness in clinical applications.

# <span id="Start">Quick Start</span>

### 1. Load the model
```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, GenerationConfig
from qwen_vl_utils import process_vision_info
import torch

MODEL_PATH = 'JZPeterPan/MedVLM-R1'

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(MODEL_PATH)

temp_generation_config = GenerationConfig(
    max_new_tokens=1024,
    do_sample=False,        # greedy decoding; temperature has no effect when sampling is off
    temperature=1,
    num_return_sequences=1,
    pad_token_id=151643,    # Qwen2 <|endoftext|> token id
)
```
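If `flash-attn` is not installed or the GPU does not support FlashAttention-2, the model can be loaded with PyTorch's built-in SDPA attention instead. A minimal fallback sketch (not part of the original instructions):

```python
# Fallback: load without FlashAttention-2, using PyTorch scaled-dot-product attention.
# Everything downstream (processor, generation config, inference) stays the same.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
)
```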
### 2. Load the VQA Data
Pick one of the following examples. They are samples from the [OmniMedVQA](https://huggingface.co/datasets/foreverbeliever/OmniMedVQA) dataset, bundled in the [HuatuoGPT-Vision](https://huggingface.co/datasets/FreedomIntelligence/Medical_Multimodal_Evaluation_Data) evaluation data.

```python
question = {"image": ['images/successful_cases/mdb146.png'], "problem": "What content appears in this image?\nA) Cardiac tissue\nB) Breast tissue\nC) Liver tissue\nD) Skin tissue", "solution": "B", "answer": "Breast tissue"}

question = {"image": ["images/successful_cases/person19_virus_50.jpeg"], "problem": "What content appears in this image?\nA) Lungs\nB) Bladder\nC) Brain\nD) Heart", "solution": "A", "answer": "Lungs"}

question = {"image":["images/successful_cases/abd-normal023599.png"],"problem":"Is any abnormality evident in this image?\nA) No\nB) Yes.","solution":"A","answer":"No"}

question = {"image":["images/successful_cases/foot089224.png"],"problem":"Which imaging technique was utilized for acquiring this image?\nA) MRI\nB) Electroencephalogram (EEG)\nC) Ultrasound\nD) Angiography","solution":"A","answer":"MRI"}

question = {"image":["images/successful_cases/knee031316.png"],"problem":"What can be observed in this image?\nA) Chondral abnormality\nB) Bone density loss\nC) Synovial cyst formation\nD) Ligament tear","solution":"A","answer":"Chondral abnormality"}

question = {"image":["images/successful_cases/shoulder045906.png"],"problem":"What can be visually detected in this picture?\nA) Bone fracture\nB) Soft tissue fluid\nC) Blood clot\nD) Tendon tear","solution":"B","answer":"Soft tissue fluid"}

question = {"image":["images/successful_cases/brain003631.png"],"problem":"What attribute can be observed in this image?\nA) Focal flair hyperintensity\nB) Bone fracture\nC) Vascular malformation\nD) Ligament tear","solution":"A","answer":"Focal flair hyperintensity"}

question = {"image":["images/successful_cases/mrabd005680.png"],"problem":"What can be observed in this image?\nA) Pulmonary embolism\nB) Pancreatic abscess\nC) Intraperitoneal mass\nD) Cardiac tamponade","solution":"C","answer":"Intraperitoneal mass"}
```
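The image paths above are relative. Assuming the example images are stored in the model repository under the same relative paths (an assumption; check the repo layout), they can be fetched with `huggingface_hub` before running inference:

```python
# Sketch: download one example image from the Hub and point the question dict at it.
# Assumes the images live in the JZPeterPan/MedVLM-R1 repo under these relative paths.
from huggingface_hub import hf_hub_download

local_image_path = hf_hub_download(
    repo_id="JZPeterPan/MedVLM-R1",
    filename=question["image"][0],  # e.g. "images/successful_cases/mdb146.png"
)
question["image"][0] = local_image_path
```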
### 3. Run the inference

```python
QUESTION_TEMPLATE = """
    {Question} 
    Your task: 
    1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags. 
    2. Then provide the correct single-letter choice (A, B, C, D,...) inside <answer>...</answer> tags.
    3. No extra information or text outside of these tags.
    """

message = [{
    "role": "user",
    "content": [{"type": "image", "image": f"file://{question['image'][0]}"}, {"type": "text","text": QUESTION_TEMPLATE.format(Question=question['problem'])}]
}]

text = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    
image_inputs, video_inputs = process_vision_info(message)
inputs = processor(
    text=text,
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to("cuda")

generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024, do_sample=False, generation_config=temp_generation_config)

generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]

output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(f'model output: {output_text[0]}')

```
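Since the prompt asks the model to wrap its reasoning in `<think>...</think>` and its final choice in `<answer>...</answer>`, the predicted letter can be pulled out with a small regex and compared against the example's `solution` field. A minimal sketch (the `extract_answer` helper is ours, not part of the original code):

```python
import re

def extract_answer(output: str):
    """Return the single-letter choice inside <answer>...</answer>, or None if absent."""
    match = re.search(r"<answer>\s*\(?([A-Za-z])\)?\s*</answer>", output)
    return match.group(1).upper() if match else None

predicted = extract_answer(output_text[0])
print(f"predicted: {predicted}, ground truth: {question['solution']}")
```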
### Failure cases
MedVLM-R1's reasoning can break down on more difficult VQA examples. Although it still outputs the correct choice in the following cases, the accompanying reasoning is either superficial or contradictory.
```python
question = {"image":["images/failure_cases/mrabd021764.png"],"problem":"What is the observable finding in this image?\nA) Brain lesion\nB) Intestinal lesion\nC) Gallbladder lesion\nD) Pancreatic lesion","solution":"D","answer":"Pancreatic lesion"}

question = {"image":["images/failure_cases/spine010017.png"],"problem":"What can be observed in this image?\nA) Cystic lesions\nB) Fractured bones\nC) Inflamed tissue\nD) Nerve damage","solution":"A","answer":"Cystic lesions"}

question = {"image":["images/failure_cases/ankle056120.png"],"problem":"What attribute can be observed in this image?\nA) Bursitis\nB) Flexor pathology\nC) Tendonitis\nD) Joint inflammation","solution":"B","answer":"Flexor pathology"}

question = {"image":["images/failure_cases/lung067009.png"],"problem":"What is the term for the anomaly depicted in the image?\nA) Pulmonary embolism\nB) Airspace opacity\nC) Lung consolidation\nD) Atelectasis","solution":"B","answer":"Airspace opacity"}

```

# <span id="Start">Acknowledgement</span>
We thank all the machine learning and medical researchers who make their codebases and datasets publicly available to the community 🫶🫶🫶

If you find our work helpful, feel free to cite us:

```
@article{pan2025medvlm,
  title={MedVLM-R1: Incentivizing Medical Reasoning Capability of Vision-Language Models (VLMs) via Reinforcement Learning},
  author={Pan, Jiazhen and Liu, Che and Wu, Junde and Liu, Fenglin and Zhu, Jiayuan and Li, Hongwei Bran and Chen, Chen and Ouyang, Cheng and Rueckert, Daniel},
  journal={arXiv preprint arXiv:2502.19634},
  year={2025}
}
```