---
license: apache-2.0
base_model:
- Qwen/Qwen2-VL-2B-Instruct
language:
- en
---
<div align="center">
<h1>
MedVLM-R1
</h1>
</div>
<div align="center">
<a href="https://arxiv.org/abs/2502.19634" target="_blank">Paper</a>
</div>
<div align="center">
<a href="https://github.com/JZPeterPan/MedVLM-R1" target="_blank">Code</a>
</div>
# <span id="Start">Introduction</span>
MedVLM-R1 is a medical vision-language model built upon [Qwen2-VL-2B](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) and fine-tuned with the [GRPO](https://arxiv.org/abs/2402.03300) reinforcement-learning framework. Trained on 600 MRI VQA samples from the [HuatuoGPT-Vision dataset](https://huggingface.co/datasets/FreedomIntelligence/Medical_Multimodal_Evaluation_Data), MedVLM-R1 achieves strong out-of-distribution performance on CT and X-ray VQA tasks. It also produces explicit medical reasoning rather than only final answers, which improves interpretability and trustworthiness in clinical applications.
# <span id="Start">Quick Start</span>
### 1. Load the model
```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, GenerationConfig
from qwen_vl_utils import process_vision_info
import torch

MODEL_PATH = 'JZPeterPan/MedVLM-R1'

# Load the checkpoint in bfloat16; `device_map="auto"` places the weights on
# the available GPU(s). `flash_attention_2` requires the flash-attn package.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)

# Greedy decoding: with `do_sample=False` the temperature has no effect.
temp_generation_config = GenerationConfig(
    max_new_tokens=1024,
    do_sample=False,
    temperature=1,
    num_return_sequences=1,
    pad_token_id=151643,  # Qwen2's <|endoftext|> token, used for padding
)
```
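If `flash-attn` is not installed, the load above will raise an error. A minimal fallback sketch, assuming PyTorch's built-in attention kernels are acceptable (this is our suggestion, not part of the original card):
```python
# Fallback sketch: PyTorch's native scaled-dot-product attention ("sdpa")
# works without extra packages, at some cost in speed.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
)
```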
### 2. Load the VQA Data
Pick one of the following examples. They are taken from the [OmniMedVQA](https://huggingface.co/datasets/foreverbeliever/OmniMedVQA) dataset and bundled in the [HuatuoGPT-Vision](https://huggingface.co/datasets/FreedomIntelligence/Medical_Multimodal_Evaluation_Data) evaluation data; a sketch for fetching the bundled images locally follows the code block.
```python
question = {"image": ['images/successful_cases/mdb146.png'], "problem": "What content appears in this image?\nA) Cardiac tissue\nB) Breast tissue\nC) Liver tissue\nD) Skin tissue", "solution": "B", "answer": "Breast tissue"}
question = {"image": ["images/successful_cases/person19_virus_50.jpeg"], "problem": "What content appears in this image?\nA) Lungs\nB) Bladder\nC) Brain\nD) Heart", "solution": "A", "answer": "Lungs"}
question = {"image":["images/successful_cases/abd-normal023599.png"],"problem":"Is any abnormality evident in this image?\nA) No\nB) Yes.","solution":"A","answer":"No"}
question = {"image":["images/successful_cases/foot089224.png"],"problem":"Which imaging technique was utilized for acquiring this image?\nA) MRI\nB) Electroencephalogram (EEG)\nC) Ultrasound\nD) Angiography","solution":"A","answer":"MRI"}
question = {"image":["images/successful_cases/knee031316.png"],"problem":"What can be observed in this image?\nA) Chondral abnormality\nB) Bone density loss\nC) Synovial cyst formation\nD) Ligament tear","solution":"A","answer":"Chondral abnormality"}
question = {"image":["images/successful_cases/shoulder045906.png"],"problem":"What can be visually detected in this picture?\nA) Bone fracture\nB) Soft tissue fluid\nC) Blood clot\nD) Tendon tear","solution":"B","answer":"Soft tissue fluid"}
question = {"image":["images/successful_cases/brain003631.png"],"problem":"What attribute can be observed in this image?\nA) Focal flair hyperintensity\nB) Bone fracture\nC) Vascular malformation\nD) Ligament tear","solution":"A","answer":"Focal flair hyperintensity"}
question = {"image":["images/successful_cases/mrabd005680.png"],"problem":"What can be observed in this image?\nA) Pulmonary embolism\nB) Pancreatic abscess\nC) Intraperitoneal mass\nD) Cardiac tamponade","solution":"C","answer":"Intraperitoneal mass"}
```
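The image paths above are relative to this model repository. One way to make them resolve locally is to download the bundled images with `huggingface_hub`; a minimal sketch, assuming the images ship with the repo under `images/` (the download call is standard `huggingface_hub` API, the pattern is our assumption):
```python
from huggingface_hub import snapshot_download

# Download only the example images from the model repo; `local_dir` is the
# root under which the relative paths above (images/...) will exist.
local_dir = snapshot_download(repo_id="JZPeterPan/MedVLM-R1", allow_patterns=["images/*"])
# Prepend `local_dir` when building the file:// URI in the next step, e.g.
# f"file://{local_dir}/{question['image'][0]}".
```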
### 3. Run the inference
```python
QUESTION_TEMPLATE = """
{Question}
Your task:
1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags.
2. Then provide the correct single-letter choice (A, B, C, D,...) inside <answer>...</answer> tags.
3. No extra information or text outside of these tags.
"""
message = [{
    "role": "user",
    "content": [
        {"type": "image", "image": f"file://{question['image'][0]}"},
        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=question['problem'])},
    ],
}]

# Build the chat prompt and preprocess the image.
text = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(message)
inputs = processor(
    text=text,
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to("cuda")

# Greedy decoding; generation settings come from `temp_generation_config`.
generated_ids = model.generate(**inputs, use_cache=True, generation_config=temp_generation_config)

# Strip the prompt tokens so only the newly generated reply is decoded.
generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(f'model output: {output_text[0]}')
```
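The template constrains the output to `<think>...</think>` followed by `<answer>...</answer>`, so the final choice can be extracted with a small regex. A minimal sketch; `extract_answer` is a helper name introduced here for illustration, not part of the original code:
```python
import re

def extract_answer(output):
    """Hypothetical helper: return the single-letter choice inside <answer>...</answer>."""
    match = re.search(r"<answer>\s*([A-Za-z])\b", output)
    return match.group(1).upper() if match else None

choice = extract_answer(output_text[0])
print(f"predicted: {choice}, ground truth: {question['solution']}")
```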
### Failure cases
MedVLM-R1's reasoning breaks down on harder VQA examples. In the cases below the model still outputs the correct choice, but its reasoning is superficial or contradicts itself.
```python
question = {"image":["images/failure_cases/mrabd021764.png"],"problem":"What is the observable finding in this image?\nA) Brain lesion\nB) Intestinal lesion\nC) Gallbladder lesion\nD) Pancreatic lesion","solution":"D","answer":"Pancreatic lesion"}
question = {"image":["images/failure_cases/spine010017.png"],"problem":"What can be observed in this image?\nA) Cystic lesions\nB) Fractured bones\nC) Inflamed tissue\nD) Nerve damage","solution":"A","answer":"Cystic lesions"}
question = {"image":["images/failure_cases/ankle056120.png"],"problem":"What attribute can be observed in this image?\nA) Bursitis\nB) Flexor pathology\nC) Tendonitis\nD) Joint inflammation","solution":"B","answer":"Flexor pathology"}
question = {"image":["images/failure_cases/lung067009.png"],"problem":"What is the term for the anomaly depicted in the image?\nA) Pulmonary embolism\nB) Airspace opacity\nC) Lung consolidation\nD) Atelectasis","solution":"B","answer":"Airspace opacity"}
```
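To inspect these cases one by one, the step-3 pipeline can be wrapped in a small function. A sketch assuming the model, processor, and `QUESTION_TEMPLATE` from the previous steps are in scope; `run_inference` is a name introduced here, not from the original card:
```python
def run_inference(q):
    """Hypothetical wrapper around the step-3 pipeline for a single VQA dict."""
    message = [{
        "role": "user",
        "content": [
            {"type": "image", "image": f"file://{q['image'][0]}"},
            {"type": "text", "text": QUESTION_TEMPLATE.format(Question=q['problem'])},
        ],
    }]
    text = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(message)
    inputs = processor(text=text, images=image_inputs, videos=video_inputs,
                       padding=True, return_tensors="pt").to("cuda")
    ids = model.generate(**inputs, use_cache=True, generation_config=temp_generation_config)
    trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, ids)]
    return processor.batch_decode(trimmed, skip_special_tokens=True)[0]

# The letter is often correct while the <think> block is not; print both.
for q in [question]:  # extend this list with the other failure-case dicts above
    print(q["image"][0], "->", run_inference(q))
```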
# <span id="Start">Acknowledgement</span>
We thank all the machine learning and medical researchers who have made their codebases and datasets publicly available 🫶🫶🫶
If you find our work helpful, please consider citing it:
```bibtex
@article{pan2025medvlm,
  title={MedVLM-R1: Incentivizing Medical Reasoning Capability of Vision-Language Models (VLMs) via Reinforcement Learning},
  author={Pan, Jiazhen and Liu, Che and Wu, Junde and Liu, Fenglin and Zhu, Jiayuan and Li, Hongwei Bran and Chen, Chen and Ouyang, Cheng and Rueckert, Daniel},
  journal={arXiv preprint arXiv:2502.19634},
  year={2025}
}
```