---
base_model:
  - Qwen/Qwen2-VL-2B-Instruct
language:
  - en
license: apache-2.0
pipeline_tag: image-text-to-text
---

# MedVLM-R1

## Introduction

MedVLM-R1 is a medical vision-language model built upon Qwen2-VL-2B and fine-tuned with the GRPO reinforcement learning framework. Trained on only 600 MRI VQA samples from the HuatuoGPT-Vision dataset, MedVLM-R1 generalizes well out of distribution to CT and X-ray VQA tasks. It also produces explicit medical reasoning rather than only final answers, improving interpretability and trustworthiness in clinical applications.

**Paper abstract:**

Reasoning is a critical frontier for advancing medical image analysis, where transparency and trustworthiness play a central role in both clinician trust and regulatory approval. Although Medical Visual Language Models (VLMs) show promise for radiological tasks, most existing VLMs merely produce final answers without revealing the underlying reasoning. To address this gap, we introduce MedVLM-R1, a medical VLM that explicitly generates natural language reasoning to enhance transparency and trustworthiness. Instead of relying on supervised fine-tuning (SFT), which often suffers from overfitting to training distributions and fails to foster genuine reasoning, MedVLM-R1 employs a reinforcement learning framework that incentivizes the model to discover human-interpretable reasoning paths without using any reasoning references. Despite limited training data (600 visual question answering samples) and model parameters (2B), MedVLM-R1 boosts accuracy from 55.11% to 78.22% across MRI, CT, and X-ray benchmarks, outperforming larger models trained on over a million samples. It also demonstrates robust domain generalization under out-of-distribution tasks. By unifying medical image analysis with explicit reasoning, MedVLM-R1 marks a pivotal step toward trustworthy and interpretable AI in clinical practice.

GitHub repo: https://github.com/jzpan/MedVLM-R1

## Quick Start

### 1. Load the model

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, GenerationConfig
from qwen_vl_utils import process_vision_info
import torch

MODEL_PATH = 'JZPeterPan/MedVLM-R1'

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(MODEL_PATH)

temp_generation_config = GenerationConfig(
    max_new_tokens=1024,
    do_sample=False,  # greedy decoding; temperature has no effect when sampling is off
    temperature=1,
    num_return_sequences=1,
    pad_token_id=151643,
)
```

### 2. Question Template

```python
QUESTION_TEMPLATE = """
    {Question} 
    Your task: 
    1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags. 
    2. Then provide the correct single-letter choice (A, B, C, D,...) inside <answer>...</answer> tags.
    3. No extra information or text outside of these tags.
    """
```

### 3. Load the VQA Data

Pick one of the following examples. These are samples from the OmniMedVQA dataset, bundled with HuatuoGPT-Vision.

question = {"image": ['images/successful_cases/mdb146.png'], "problem": "What content appears in this image?
A) Cardiac tissue
B) Breast tissue
C) Liver tissue
D) Skin tissue", "solution": "B", "answer": "Breast tissue"}

question = {"image": ["images/successful_cases/person19_virus_50.jpeg"], "problem": "What content appears in this image?
A) Lungs
B) Bladder
C) Brain
D) Heart", "solution": "A", "answer": "Lungs"}

# ... other example questions
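
To iterate over many samples instead of hard-coding single dicts, a small helper along these lines can be used (the file name `vqa_samples.jsonl` and its one-JSON-object-per-line format are assumptions for illustration, not files shipped with this repo):

```python
import json

def load_questions(path="vqa_samples.jsonl"):
    """Load question dicts (image/problem/solution/answer) from a hypothetical JSONL file."""
    with open(path) as f:
        return [json.loads(line) for line in f]
```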

### 4. Run the inference

```python
message = [{
    "role": "user",
    "content": [
        {"type": "image", "image": f"file://{question['image'][0]}"},
        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=question['problem'])},
    ],
}]

text = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    
image_inputs, video_inputs = process_vision_info(message)
inputs = processor(
    text=text,
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to("cuda")

generated_ids = model.generate(**inputs, use_cache=True, generation_config=temp_generation_config)

generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]

output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(f'model output: {output_text[0]}')
```
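
Since the template asks for `<think>...</think>` and `<answer>...</answer>` tags, the reasoning trace and the predicted choice can be pulled out with a simple regex (a minimal post-processing sketch, not part of the original pipeline):

```python
import re

output = output_text[0]

# Extract the reasoning trace and the single-letter answer from the tagged output.
think = re.search(r"<think>(.*?)</think>", output, re.DOTALL)
answer = re.search(r"<answer>\s*([A-Z])", output)

reasoning = think.group(1).strip() if think else None
predicted = answer.group(1) if answer else None

print("reasoning:", reasoning)
print("predicted choice:", predicted)
print("correct:", predicted == question["solution"])
```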

## Failure cases

MedVLM-R1's reasoning breaks down on more difficult VQA examples. Although it still outputs the correct choice in the following cases, its reasoning is either superficial or contradictory.

```python
# ... failure case examples
```

## Acknowledgement

We thank all machine learning and medical researchers who make their codebases and datasets publicly available to the community 🫶🫶🫶

If you find our work helpful, please consider citing it:

```bibtex
@article{pan2025medvlm,
  title={MedVLM-R1: Incentivizing Medical Reasoning Capability of Vision-Language Models (VLMs) via Reinforcement Learning},
  author={Pan, Jiazhen and Liu, Che and Wu, Junde and Liu, Fenglin and Zhu, Jiayuan and Li, Hongwei Bran and Chen, Chen and Ouyang, Cheng and Rueckert, Daniel},
  journal={arXiv preprint arXiv:2502.19634},
  year={2025}
}
```