WMH Segmentation: Normal vs Abnormal Classification
Pre-trained models for white matter hyperintensity (WMH) segmentation with explicit distinction between normal periventricular changes and pathological lesions.
Model Description
This repository contains 8 pre-trained deep learning models (4 architectures × 2 training scenarios) for automated WMH segmentation from FLAIR MRI images. The models implement a novel three-class approach that distinguishes between:
- Class 0: Background
- Class 1: Normal WMH (aging-related periventricular changes)
- Class 2: Abnormal WMH (pathologically significant lesions)
This approach addresses a critical challenge, namely false-positive detections in periventricular regions, and improves the Dice coefficient by up to 0.271 in absolute terms (27.1 percentage points; U-Net: 0.497 binary vs 0.768 three-class) compared with traditional binary segmentation.
Model Architectures
| Architecture | Parameters | Best Dice (3-Class) | Binary Baseline Dice | Relative Improvement |
|---|---|---|---|---|
| U-Net (recommended) | 31.0M | 0.768 | 0.497 | +54.5% |
| Attention U-Net | 34.9M | 0.740 | 0.486 | +52.1% |
| TransUNet | 105.3M | 0.700 | 0.510 | +37.3% |
| DeepLabV3Plus | 40.3M | 0.586 | 0.374 | +56.7% |
Recommended: U-Net with Scenario 2 (three-class) for the best performance.
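The Improvement column is a relative gain, (three-class - binary) / binary, whereas the Δ values in the performance table further down are absolute Dice differences. A small sketch that recomputes both from the rounded table values; last-digit differences from the published percentages (e.g. 52.3% here vs 52.1% for Attention U-Net) are expected because the inputs are rounded.

```python
# Dice values copied from the table: (three-class Scenario 2, binary Scenario 1)
DICE = {
    "U-Net": (0.768, 0.497),
    "Attention U-Net": (0.740, 0.486),
    "TransUNet": (0.700, 0.510),
    "DeepLabV3Plus": (0.586, 0.374),
}


def improvements(three_class, binary):
    """Return (absolute gain in Dice, relative gain in percent of the binary baseline)."""
    absolute = three_class - binary
    relative = 100.0 * absolute / binary
    return absolute, relative


for name, (s2, s1) in DICE.items():
    abs_gain, rel_gain = improvements(s2, s1)
    print(f"{name:16s} +{abs_gain:.3f} Dice (+{rel_gain:.1f}% relative)")
```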
Repository Structure
```text
models/
├── unet/models/
│   ├── scenario1_binary_model.h5        # Binary: Background vs Abnormal
│   └── scenario2_multiclass_model.h5    # 3-Class: Background, Normal, Abnormal
├── attention_unet/models/
│   ├── scenario1_binary_model.h5
│   └── scenario2_multiclass_model.h5
├── deeplabv3plus/models/
│   ├── scenario1_binary_model.h5
│   └── scenario2_multiclass_model.h5
└── transunet/models/
    ├── scenario1_binary_model.h5
    └── scenario2_multiclass_model.h5
```
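With four architectures and two scenarios, every checkpoint path follows one pattern. The helper below is a hypothetical convenience (`model_filename` and `download_all` are not part of this repository). It assumes paths are relative to the Hub repository root, as in the `unet/models/...` filename used in the download example, i.e. without the `models/` label shown at the top of the tree.

```python
from typing import Dict

REPO_ID = "Bawil/wmh_leverage_normal_abnormal_segmentation"
ARCHITECTURES = ("unet", "attention_unet", "deeplabv3plus", "transunet")
SCENARIO_FILES = {
    1: "scenario1_binary_model.h5",      # Binary: Background vs Abnormal
    2: "scenario2_multiclass_model.h5",  # 3-Class: Background, Normal, Abnormal
}


def model_filename(architecture: str, scenario: int) -> str:
    """Map an (architecture, scenario) pair to its assumed path inside the Hub repo."""
    if architecture not in ARCHITECTURES:
        raise ValueError(f"unknown architecture: {architecture!r}")
    if scenario not in SCENARIO_FILES:
        raise ValueError(f"scenario must be 1 or 2, got {scenario!r}")
    return f"{architecture}/models/{SCENARIO_FILES[scenario]}"


def download_all() -> Dict[str, str]:
    """Download all 8 checkpoints; returns {repo path: local cache path}."""
    # Imported lazily so model_filename works without huggingface_hub installed
    from huggingface_hub import hf_hub_download

    paths = {}
    for arch in ARCHITECTURES:
        for scenario in SCENARIO_FILES:
            name = model_filename(arch, scenario)
            paths[name] = hf_hub_download(repo_id=REPO_ID, filename=name)
    return paths
```

`download_all()` needs network access and stores every `.h5` file in the local Hugging Face cache; call `model_filename` on its own to pick a single checkpoint.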
Quick Start
Installation
```bash
pip install huggingface_hub tensorflow numpy nibabel
```
Download Models
```python
from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model

# Download the best-performing model (U-Net, three-class Scenario 2)
model_path = hf_hub_download(
    repo_id="Bawil/wmh_leverage_normal_abnormal_segmentation",
    filename="unet/models/scenario2_multiclass_model.h5",
)

# Load the model
model = load_model(model_path)
```
Inference Example
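```python
import numpy as np
from tensorflow.keras.models import load_model

# Load the pre-trained three-class model
# (`model_path` is returned by hf_hub_download in the download step)
model = load_model(model_path)

# Prepare input: 256x256 single-channel FLAIR slices, intensity-normalized,
# with shape (batch_size, 256, 256, 1).
# `preprocess_flair` and `your_flair_image` are placeholders for your own
# preprocessing routine and image data; they are not defined in this snippet.
input_image = preprocess_flair(your_flair_image)

# Run inference: per-pixel class probabilities, class axis last
predictions = model.predict(input_image)

# Convert probabilities to per-pixel class labels
predicted_classes = np.argmax(predictions, axis=-1)
# 0: Background
# 1: Normal WMH (periventricular)
# 2: Abnormal WMH (pathological)

# Extract pathological lesions only (1 = abnormal WMH, 0 = everything else)
abnormal_mask = (predicted_classes == 2).astype(np.uint8)
```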
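The card does not specify what `preprocess_flair` does beyond producing normalized 256×256 single-channel input. One hypothetical version, assuming a center crop or zero-padding to 256×256 and per-slice min-max scaling to [0, 1]; if the training pipeline resized images or used z-score normalization instead, match that pipeline, or predictions may degrade.

```python
import numpy as np


def preprocess_flair(slices, size=256):
    """Hypothetical FLAIR preprocessing: (H, W) or (N, H, W) -> (N, size, size, 1) float32."""
    arr = np.asarray(slices, dtype=np.float32)
    if arr.ndim == 2:                            # single slice -> batch of one
        arr = arr[np.newaxis]
    n, h, w = arr.shape
    ch, cw = min(h, size), min(w, size)          # extent kept from each slice
    sy, sx = (h - ch) // 2, (w - cw) // 2        # crop offsets in the source slice
    dy, dx = (size - ch) // 2, (size - cw) // 2  # padding offsets in the output
    out = np.zeros((n, size, size), dtype=np.float32)
    out[:, dy:dy + ch, dx:dx + cw] = arr[:, sy:sy + ch, sx:sx + cw]
    # Per-slice min-max scaling to [0, 1]; the floor guards constant (e.g. empty) slices
    lo = out.min(axis=(1, 2), keepdims=True)
    hi = out.max(axis=(1, 2), keepdims=True)
    out = (out - lo) / np.maximum(hi - lo, 1e-8)
    return out[..., np.newaxis]                  # add the channel axis expected by the model
```

For a NIfTI volume, `nibabel.load(path).get_fdata()` returns the voxel array; with axial slices on the last axis (check your data's orientation), `preprocess_flair(np.moveaxis(volume, -1, 0))` yields one batch entry per slice.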
Training Data
Dataset Composition
Local Dataset: 100 MS patients (2,000 FLAIR MRI slices)
- Demographics: 26 males, 74 females
- Age range: 18-68 years
- Scanner: 1.5-Tesla TOSHIBA Vantage
Public Dataset: MSSEG2016 (15 patients, 750 FLAIR slices)
Annotations
- Expert annotations by board-certified neuroradiologists (20+ years of experience)
- Three-class labeling: Background, Normal WMH, Abnormal WMH
- Approved by the Ethics Committee (IR.TBZMED.REC.1402.902)
Data Split
- Training: 80% patients (local) + 60% patients (public)
- Validation: 10% patients (local) + 20% patients (public)
- Testing: 10% patients (local) + 20% patients (public)
- Strategy: Patient-level stratified split (no slice-level leakage)
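A patient-level split assigns every slice of a patient to the same subset, so no patient contributes slices to both training and testing. A minimal sketch using a plain random shuffle of patient IDs; the card does not state the stratification variable, so this version is not stratified, and the seed and function names are illustrative.

```python
import random
from typing import Dict, List, Mapping, Sequence, Tuple


def split_patients(patient_ids: Sequence[str],
                   fractions: Tuple[float, float, float] = (0.8, 0.1, 0.1),
                   seed: int = 42) -> Dict[str, List[str]]:
    """Assign whole patients (never individual slices) to train/val/test."""
    ids = sorted(set(patient_ids))       # deterministic order before shuffling
    random.Random(seed).shuffle(ids)
    n_train = round(len(ids) * fractions[0])
    n_val = round(len(ids) * fractions[1])
    return {
        "train": ids[:n_train],
        "val": ids[n_train:n_train + n_val],
        "test": ids[n_train + n_val:],   # remainder goes to test
    }


def slices_for(split_ids: Sequence[str], slice_index: Mapping[str, List[str]]) -> List[str]:
    """Collect slice keys for a group of patients (slice_index: patient -> slice keys)."""
    return [s for pid in split_ids for s in slice_index[pid]]
```

With the default fractions, 100 local patients split 80/10/10; 15 MSSEG2016 patients with `fractions=(0.6, 0.2, 0.2)` split 9/3/3.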
Model Training
Configuration
- Framework: TensorFlow 2.11, Keras
- Optimizer: Adam (learning rate: 0.0001)
- Loss Functions:
  - Scenario 1: Weighted binary cross-entropy
  - Scenario 2: Weighted categorical cross-entropy
- Epochs: Up to 50 (with early stopping)
- Batch Size: 8
- Input Size: 256×256×1
- Data Augmentation: Rotation, flipping, elastic deformation
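Weighted categorical cross-entropy scales each class's log-loss term by a class weight so that the rare lesion pixels are not swamped by background. A NumPy sketch of the per-pixel formula, averaged over pixels; the class weights actually used in training are not given in this card, so any values passed in are assumptions. A training loss would need the same math written in TensorFlow ops; the binary Scenario 1 loss is the two-class special case.

```python
import numpy as np


def weighted_categorical_crossentropy(y_true, y_pred, class_weights, eps=1e-7):
    """Mean over pixels of -sum_c w_c * y_c * log(p_c).

    y_true: one-hot labels, shape (..., C); y_pred: softmax probabilities, shape (..., C).
    """
    w = np.asarray(class_weights, dtype=np.float64)
    # Clip probabilities so log() never sees exactly 0
    p = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    per_pixel = -np.sum(w * np.asarray(y_true, dtype=np.float64) * np.log(p), axis=-1)
    return float(per_pixel.mean())
```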
Hardware
- GPU: NVIDIA RTX 3060 (12GB VRAM)
- Training Time: 2-3 hours per model
- Inference Time: ~35–40 ms per 256×256 slice
Model Performance
Dice Coefficient (Primary Metric)
| Model | Scenario 1 (Binary) | Scenario 2 (3-Class) | Δ Dice (absolute) | p-value | Cohen's d |
|---|---|---|---|---|---|
| U-Net | 0.497±0.145 | 0.768±0.124 | +0.271 | <0.0001 | 0.564 |
| Attention U-Net | 0.486±0.157 | 0.740±0.133 | +0.253 | <0.0001 | 0.442 |
| TransUNet | 0.510±0.116 | 0.700±0.097 | +0.190 | <0.0001 | 0.478 |
| DeepLabV3Plus | 0.374±0.110 | 0.586±0.092 | +0.212 | <0.0001 | 0.565 |
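The Dice coefficient for a class is 2|A ∩ B| / (|A| + |B|), where A and B are the predicted and reference pixels of that class (e.g. class 2 for pathological lesions). A minimal sketch; the card does not state whether Dice is aggregated per slice, per patient, or over the pooled test set, or how maps without the class are scored, so this version scores one label map and treats "absent in both" as 1.0 by convention.

```python
import numpy as np


def dice(pred_labels, true_labels, cls=2):
    """Dice = 2|A ∩ B| / (|A| + |B|) for one class of a label map."""
    a = np.asarray(pred_labels) == cls
    b = np.asarray(true_labels) == cls
    denom = a.sum() + b.sum()
    if denom == 0:
        return 1.0  # class absent from both maps: counted as perfect agreement (a convention)
    return float(2.0 * np.logical_and(a, b).sum() / denom)
```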
Additional Metrics
- Hausdorff Distance: 27.4 mm (U-Net, three-class) vs 29.8 mm (binary); lower is better
- Precision: Significant improvement in pathological lesion detection
- False Positive Reduction: Marked decrease in periventricular regions
- Clinical Feasibility: ~1.5 s total processing time per case (40 slices)
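The per-case figure is consistent with the per-slice inference time listed under Hardware, assuming slices are processed one after another and ignoring loading and preprocessing overhead:

```latex
t_{\text{case}} \approx N_{\text{slices}} \times t_{\text{slice}}
  = 40 \times (35\text{--}40)\,\text{ms}
  = 1400\text{--}1600\,\text{ms} \approx 1.5\,\text{s}
```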
Statistical Validation
- Paired t-tests confirm significant improvements (all p < 0.0001)
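- Effect sizes (Cohen's d) range from 0.442 to 0.565, i.e. around the "medium" benchmark (d = 0.5) in Cohen's conventions (0.2 small, 0.5 medium, 0.8 large)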
- 95% confidence intervals reported for all metrics
- Wilcoxon signed-rank test for non-parametric validation
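For paired per-case scores (e.g. each test case's Dice under both scenarios), the paired t statistic and a paired effect size follow directly from the differences. The card does not say which Cohen's d variant was used; this sketch computes d_z (mean difference divided by the SD of the differences), a common choice for paired designs that can differ substantially from pooled-SD variants. The function name is hypothetical.

```python
import math
import statistics
from typing import Sequence, Tuple


def paired_t_and_dz(baseline: Sequence[float], improved: Sequence[float]) -> Tuple[float, float]:
    """Paired t statistic (df = n - 1) and Cohen's d_z for per-case scores."""
    if len(baseline) != len(improved) or len(baseline) < 2:
        raise ValueError("need two equal-length samples with at least 2 pairs")
    diffs = [b - a for a, b in zip(baseline, improved)]
    mean_d = statistics.mean(diffs)
    sd_d = statistics.stdev(diffs)        # sample SD of the paired differences
    d_z = mean_d / sd_d                   # effect size in units of the difference SD
    t = d_z * math.sqrt(len(diffs))       # equals mean_d / (sd_d / sqrt(n))
    return t, d_z
```

For p-values, `scipy.stats.ttest_rel(improved, baseline)` and `scipy.stats.wilcoxon(improved, baseline)` run the paired t-test and the Wilcoxon signed-rank test on the same paired data.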
Use Cases
Clinical Applications
- MS Lesion Quantification: Accurate measurement of disease burden
- Differential Diagnosis: Distinguish pathological lesions from normal age-related changes
- Longitudinal Monitoring: Track disease progression over time
- Treatment Response: Evaluate therapeutic efficacy
- Radiological Reporting: Reduce false positive alerts
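Lesion burden is typically reported as a volume: the number of voxels labeled as a class times the physical volume of one voxel. A minimal sketch, assuming a 3D label volume (e.g. predicted slices restacked) and voxel spacing in millimetres; the function name is hypothetical.

```python
import numpy as np


def lesion_volume_ml(label_volume, voxel_spacing_mm, cls=2):
    """Volume of one class in millilitres: voxel count x voxel volume (mm^3) / 1000."""
    voxel_mm3 = float(np.prod(voxel_spacing_mm))
    n_voxels = int(np.count_nonzero(np.asarray(label_volume) == cls))
    return n_voxels * voxel_mm3 / 1000.0
```

With nibabel, the spacing of the original scan is available as `img.header.get_zooms()[:3]`; if slices were cropped, padded, or resized before inference, the in-plane spacing of the predicted map must be adjusted accordingly.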
Research Applications
- Baseline Comparisons: Standardized evaluation framework
- Method Development: Foundation for advanced segmentation approaches
- Multi-center Studies: Protocol for broader validation
- Reproducible Research: Complete implementation available
Limitations
- Single Modality: Trained on FLAIR MRI only
- Scanner Specificity: Primarily 1.5T TOSHIBA data
- Disease Focus: Optimized for MS patients
- 2D Segmentation: Slice-by-slice processing (no 3D context)
- Resolution: Fixed 256×256 input size
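Because the models are 2D, a 3D scan is segmented one slice at a time and the slice predictions are restacked; no information flows between neighboring slices. A sketch of that loop, assuming the slice axis is the last axis of the volume and every slice is already preprocessed to 256×256; `segment_volume` is a hypothetical helper, and `predict_fn` can be a Keras `model.predict`.

```python
import numpy as np


def segment_volume(volume, predict_fn, batch_size=8):
    """Apply a 2D model slice by slice: (H, W, S) volume -> (H, W, S) uint8 label volume."""
    slices = np.moveaxis(np.asarray(volume, dtype=np.float32), -1, 0)  # (S, H, W)
    labels = []
    for start in range(0, slices.shape[0], batch_size):
        batch = slices[start:start + batch_size, ..., np.newaxis]      # (B, H, W, 1)
        probs = predict_fn(batch)                                      # (B, H, W, C)
        labels.append(np.argmax(probs, axis=-1).astype(np.uint8))      # (B, H, W)
    return np.moveaxis(np.concatenate(labels, axis=0), 0, -1)          # back to (H, W, S)
```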
Model Card
Intended Use
- Primary: Automated WMH segmentation for research and clinical decision support
- Users: Radiologists, neurologists, researchers, AI developers
- Out-of-scope: Not FDA/CE approved; not for standalone clinical diagnosis
Ethical Considerations
- Privacy: All data anonymized per HIPAA/GDPR standards
- Bias: Limited scanner/protocol diversity may affect generalization
- Clinical Validation: Requires expert review before clinical use
- Transparency: Complete methodology and code openly available
Model Card Authors
Mahdi Bashiri Bawil, Mousa Shamsi, Ali Fahmi Jafargholkhanloo, Abolhassan Shakeri Bavil
Citation
```bibtex
@article{bawil2025wmh,
  title={Incorporating Normal Periventricular Changes for Enhanced Pathological
         White Matter Hyperintensity Segmentation: On Multi-Class Deep Learning Approaches},
  author={Bawil, Mahdi Bashiri and Shamsi, Mousa and Jafargholkhanloo, Ali Fahmi and
          Bavil, Abolhassan Shakeri},
  year={2025},
  note={Models: https://huggingface.co/Bawil/wmh_leverage_normal_abnormal_segmentation}
}
```
License
MIT License - See LICENSE
Additional Resources
- Paper: [Under Review]
- GitHub Repository: Mahdi-Bashiri/wmh-normal-abnormal-segmentation
- Contact: m_bashiri99@sut.ac.ir
- Institution: Sahand University of Technology & Tabriz University of Medical Sciences
Acknowledgments
- Golgasht Medical Imaging Center (Tabriz, Iran), for providing clinical data
- Expert neuroradiologists for manual annotations
- Ethics Committee approval: IR.TBZMED.REC.1402.902
Keywords: white matter hyperintensities, FLAIR MRI, medical imaging, deep learning, image segmentation, multiple sclerosis, U-Net, attention mechanisms, transformers, clinical AI