WMH Segmentation: Normal vs Abnormal Classification

Pre-trained models for white matter hyperintensity (WMH) segmentation with explicit distinction between normal periventricular changes and pathological lesions.

Model Description

This repository contains 8 pre-trained deep learning models (4 architectures × 2 training scenarios) for automated WMH segmentation from FLAIR MRI images. The models implement a novel three-class approach that distinguishes between:

  • Class 0: Background
  • Class 1: Normal WMH (aging-related periventricular changes)
  • Class 2: Abnormal WMH (pathologically significant lesions)

This approach addresses the critical challenge of false-positive detection in periventricular regions, raising the Dice coefficient by up to 0.271 in absolute terms (from 0.497 to 0.768 for U-Net) compared with traditional binary segmentation.

Model Architectures

| Architecture | Parameters | Best Dice (3-Class) | Binary Baseline Dice | Relative Improvement |
|---|---|---|---|---|
| U-Net ⭐ | 31.0M | 0.768 | 0.497 | +54.5% |
| Attention U-Net | 34.9M | 0.740 | 0.486 | +52.1% |
| TransUNet | 105.3M | 0.700 | 0.510 | +37.3% |
| DeepLabV3Plus | 40.3M | 0.586 | 0.374 | +56.7% |

⭐ Recommended: U-Net with Scenario 2 (three-class) for optimal performance
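
The relative improvements in the table can be reproduced from the binary baseline Dice and the absolute gains reported under Model Performance. The snippet below is only an arithmetic check; the dictionary layout and function name are illustrative, not part of the repository.

```python
# Binary-baseline Dice and reported absolute gain, copied from the model card tables.
reported = {
    "U-Net": (0.497, 0.271),
    "Attention U-Net": (0.486, 0.253),
    "TransUNet": (0.510, 0.190),
    "DeepLabV3Plus": (0.374, 0.212),
}


def relative_improvement(baseline, absolute_gain):
    """Return the gain as a percentage of the baseline Dice."""
    return 100.0 * absolute_gain / baseline


relative = {
    name: round(relative_improvement(baseline, gain), 1)
    for name, (baseline, gain) in reported.items()
}
for name, pct in relative.items():
    print(f"{name}: +{pct}%")
```

The percentages are relative to the binary baseline, so a small baseline (as for DeepLabV3Plus) can give a large relative gain even when the final Dice is the lowest.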

Repository Structure

models/
├── unet/models/
│   ├── scenario1_binary_model.h5          # Binary: Background vs Abnormal
│   └── scenario2_multiclass_model.h5      # 3-Class: Background, Normal, Abnormal
├── attention_unet/models/
│   ├── scenario1_binary_model.h5
│   └── scenario2_multiclass_model.h5
├── deeplabv3plus/models/
│   ├── scenario1_binary_model.h5
│   └── scenario2_multiclass_model.h5
└── transunet/models/
    ├── scenario1_binary_model.h5
    └── scenario2_multiclass_model.h5

Quick Start

Installation

pip install huggingface_hub tensorflow numpy nibabel

Download Models

from huggingface_hub import hf_hub_download

# Download best performing model (U-Net Three-Class)
model_path = hf_hub_download(
    repo_id="Bawil/wmh_leverage_normal_abnormal_segmentation",
    filename="unet/models/scenario2_multiclass_model.h5"
)

# Load model
from tensorflow.keras.models import load_model
model = load_model(model_path)
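
To fetch all eight checkpoints rather than one, the file paths can be generated from the repository structure. This is a sketch that assumes every architecture follows the same `<architecture>/models/<scenario file>` pattern as the U-Net path used above; the helper names `model_filenames` and `download_all` are hypothetical.

```python
from itertools import product

REPO_ID = "Bawil/wmh_leverage_normal_abnormal_segmentation"
ARCHITECTURES = ["unet", "attention_unet", "deeplabv3plus", "transunet"]
SCENARIOS = ["scenario1_binary_model.h5", "scenario2_multiclass_model.h5"]


def model_filenames():
    """Build the 8 repo-relative paths (assumed pattern: <arch>/models/<file>)."""
    return [f"{arch}/models/{name}" for arch, name in product(ARCHITECTURES, SCENARIOS)]


def download_all():
    """Download every checkpoint and return a filename -> local path mapping."""
    # Imported lazily so the path helper works without network access.
    from huggingface_hub import hf_hub_download

    return {f: hf_hub_download(repo_id=REPO_ID, filename=f) for f in model_filenames()}
```

Calling `download_all()` requires network access and stores the files in the local Hugging Face cache.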

Inference Example

import numpy as np
from tensorflow.keras.models import load_model

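# Load pre-trained model (model_path comes from the download step above)
model = load_model(model_path)

# Prepare input (256x256 grayscale FLAIR MRI, normalized)
# preprocess_flair is a placeholder for your own preprocessing; it must return
# an array of shape (batch_size, 256, 256, 1)
input_image = preprocess_flair(your_flair_image)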

# Run inference
predictions = model.predict(input_image)

# Get class predictions
predicted_classes = np.argmax(predictions, axis=-1)
# 0: Background
# 1: Normal WMH (periventricular)
# 2: Abnormal WMH (pathological)

# Extract pathological lesions only
abnormal_mask = (predicted_classes == 2).astype(np.uint8)
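
The `preprocess_flair` helper used in the inference example is not defined in this card. One possible version is sketched below, assuming the models expect a centred 256x256 slice scaled to [0, 1]. The exact normalization used during training is not documented here, so both the pad/crop strategy and the min-max scaling are assumptions to verify against the training code.

```python
import numpy as np


def preprocess_flair(slice_2d, size=256):
    """Hypothetical preprocessing: pad/crop to size x size, min-max scale to [0, 1]."""
    img = np.asarray(slice_2d, dtype=np.float32)
    out = np.zeros((size, size), dtype=np.float32)

    # Centre the slice in the output canvas, cropping or zero-padding as needed.
    h, w = img.shape
    ch, cw = min(h, size), min(w, size)
    src_y, src_x = (h - ch) // 2, (w - cw) // 2
    dst_y, dst_x = (size - ch) // 2, (size - cw) // 2
    out[dst_y:dst_y + ch, dst_x:dst_x + cw] = img[src_y:src_y + ch, src_x:src_x + cw]

    # Min-max normalization (assumed); a constant image maps to all zeros.
    lo, hi = out.min(), out.max()
    if hi > lo:
        out = (out - lo) / (hi - lo)
    else:
        out = np.zeros_like(out)

    # Add batch and channel axes -> (1, size, size, 1).
    return out[np.newaxis, ..., np.newaxis]
```

A single slice taken from a volume loaded with `nibabel` can be passed in directly; stack the results along axis 0 to run a whole case in one `model.predict` call.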

Training Data

Dataset Composition

  • Local Dataset: 100 MS patients (2,000 FLAIR MRI slices)

    • Demographics: 26 males, 74 females
    • Age range: 18-68 years
    • Scanner: 1.5-Tesla TOSHIBA Vantage
  • Public Dataset: MSSEG2016 (15 patients, 750 FLAIR slices)

Annotations

  • Expert annotations by board-certified neuroradiologists (20+ years of experience)
  • Three-class labeling: Background, Normal WMH, Abnormal WMH
  • Approved by Ethics Committee (IR.TBZMED.REC.1402.902)

Data Split

  • Training: 80% patients (local) + 60% patients (public)
  • Validation: 10% patients (local) + 20% patients (public)
  • Testing: 10% patients (local) + 20% patients (public)
  • Strategy: Patient-level stratified split (no slice-level leakage)
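
A patient-level split can be sketched as follows: patient IDs, not slices, are shuffled and partitioned, so all slices from one patient land in the same subset. The ID format and the `split_patients` helper are hypothetical, and stratification is omitted because the card does not say which variable it was stratified on.

```python
import random


def split_patients(patient_ids, train_frac, val_frac, seed=0):
    """Shuffle unique patient IDs and cut them into train/val/test lists."""
    ids = sorted(set(patient_ids))
    random.Random(seed).shuffle(ids)
    n_train = round(len(ids) * train_frac)
    n_val = round(len(ids) * val_frac)
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]


# Placeholder IDs matching the cohort sizes given above.
local = [f"local_{i:03d}" for i in range(100)]
public = [f"public_{i:02d}" for i in range(15)]

local_train, local_val, local_test = split_patients(local, 0.8, 0.1)
public_train, public_val, public_test = split_patients(public, 0.6, 0.2)
```

Slices are then assigned to a subset by looking up their patient ID, which is what rules out slice-level leakage.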

Model Training

Configuration

  • Framework: TensorFlow 2.11, Keras
  • Optimizer: Adam (learning rate: 0.0001)
  • Loss Functions:
    • Scenario 1: Weighted binary cross-entropy
    • Scenario 2: Weighted categorical cross-entropy
  • Epochs: 50 (with early stopping)
  • Batch Size: 8
  • Input Size: 256×256×1
  • Data Augmentation: Rotation, flipping, elastic deformation
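
The Scenario 2 loss can be illustrated with a plain NumPy version of weighted categorical cross-entropy. This is a sketch of the general technique, not the repository's training code: the card does not give the class weights or the exact weighting scheme, so the weights in the usage note are placeholders.

```python
import numpy as np


def weighted_categorical_crossentropy(y_true, y_pred, class_weights, eps=1e-7):
    """Mean per-pixel cross-entropy, each pixel scaled by its true-class weight.

    y_true: one-hot labels, shape (..., n_classes)
    y_pred: softmax probabilities, same shape as y_true
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    weights = np.asarray(class_weights, dtype=np.float64)
    per_pixel = -np.sum(weights * y_true * np.log(y_pred), axis=-1)
    return float(per_pixel.mean())
```

For example, `weighted_categorical_crossentropy(y_true, y_pred, [1.0, 2.0, 4.0])` penalizes errors on abnormal WMH pixels four times as heavily as errors on background pixels, which is one common way to counter the class imbalance typical of lesion segmentation.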

Hardware

  • GPU: NVIDIA RTX 3060 (12GB VRAM)
  • Training Time: 2-3 hours per model
  • Inference Time: ~35-40 ms per image

Model Performance

Dice Coefficient (Primary Metric)

| Model | Scenario 1 | Scenario 2 | Δ Improvement | p-value | Cohen's d |
|---|---|---|---|---|---|
| U-Net | 0.497±0.145 | 0.768±0.124 | +0.271 | <0.0001 | 0.564 |
| Attention U-Net | 0.486±0.157 | 0.740±0.133 | +0.253 | <0.0001 | 0.442 |
| TransUNet | 0.510±0.116 | 0.700±0.097 | +0.190 | <0.0001 | 0.478 |
| DeepLabV3Plus | 0.374±0.110 | 0.586±0.092 | +0.212 | <0.0001 | 0.565 |
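
For reference, the Dice coefficient of the abnormal class can be computed from label maps as in the sketch below. The card does not state whether Dice was averaged per slice or per patient, or how empty masks were handled, so the smoothing term `eps` and the toy arrays are assumptions.

```python
import numpy as np


def dice_coefficient(pred_mask, true_mask, eps=1e-7):
    """Dice = 2|A ∩ B| / (|A| + |B|) for two binary masks."""
    pred = np.asarray(pred_mask).astype(bool)
    true = np.asarray(true_mask).astype(bool)
    intersection = np.logical_and(pred, true).sum()
    return float((2.0 * intersection + eps) / (pred.sum() + true.sum() + eps))


# Toy label maps: 0 = background, 1 = normal WMH, 2 = abnormal WMH.
pred_labels = np.array([[0, 1, 2], [2, 2, 0]])
true_labels = np.array([[0, 1, 2], [2, 0, 0]])

# Evaluate only the abnormal class, as in the comparison with the binary scenario.
abnormal_dice = dice_coefficient(pred_labels == 2, true_labels == 2)
```

With this smoothing, two empty masks score 1.0; other conventions skip such slices instead, which changes the reported average.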

Additional Metrics

  • Hausdorff Distance: 27.4 mm (U-Net 3-class) vs 29.8 mm (binary)
  • Precision: Significant improvement in pathological lesion detection
  • False Positive Reduction: Marked decrease in periventricular regions
  • Clinical Feasibility: 1.5 s total processing time per case (40 slices)

Statistical Validation

  • Paired t-tests confirm significant improvements (all p < 0.0001)
  • Effect sizes (Cohen's d) range from 0.44 to 0.56, around the conventional threshold for a medium effect (0.5)
  • 95% confidence intervals reported for all metrics
  • Wilcoxon signed-rank test for non-parametric validation
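
A paired comparison of per-case Dice scores can be sketched as below. The card does not say which Cohen's d formula was used, so the paired-samples variant (mean difference divided by the standard deviation of the differences) is an assumption, and the function name is illustrative.

```python
import math

import numpy as np


def paired_stats(scores_a, scores_b):
    """Return (mean difference, paired t statistic, Cohen's d_z) for b - a."""
    diff = np.asarray(scores_b, dtype=float) - np.asarray(scores_a, dtype=float)
    n = diff.size
    mean_diff = diff.mean()
    sd_diff = diff.std(ddof=1)  # sample standard deviation of the differences
    t_stat = mean_diff / (sd_diff / math.sqrt(n))
    d_z = mean_diff / sd_diff
    return float(mean_diff), float(t_stat), float(d_z)
```

For p-values, `scipy.stats.ttest_rel` and `scipy.stats.wilcoxon` accept the same paired arrays.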

Use Cases

Clinical Applications

  • MS Lesion Quantification: Accurate measurement of disease burden
  • Differential Diagnosis: Distinguish pathological from normal aging
  • Longitudinal Monitoring: Track disease progression over time
  • Treatment Response: Evaluate therapeutic efficacy
  • Radiological Reporting: Reduce false positive alerts
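
For lesion quantification, the predicted label maps for a case can be converted to a volume. The sketch below assumes the slices are stacked into one 3D array and that the voxel spacing is known (for example from the NIfTI header); the function name and the spacing in the usage note are placeholders.

```python
import numpy as np


def lesion_volume_ml(label_volume, voxel_size_mm, target_class=2):
    """Volume of one class in millilitres: voxel count x voxel volume (mm^3) / 1000."""
    voxel_volume_mm3 = float(np.prod(voxel_size_mm))
    n_voxels = int(np.count_nonzero(np.asarray(label_volume) == target_class))
    return n_voxels * voxel_volume_mm3 / 1000.0
```

For example, `lesion_volume_ml(predicted_volume, (1.0, 1.0, 5.0))` reports the abnormal WMH burden, while `target_class=1` gives the volume labelled as normal periventricular change.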

Research Applications

  • Baseline Comparisons: Standardized evaluation framework
  • Method Development: Foundation for advanced segmentation approaches
  • Multi-center Studies: Protocol for broader validation
  • Reproducible Research: Complete implementation available

Limitations

  • Single Modality: Trained on FLAIR MRI only
  • Scanner Specificity: Primarily 1.5T TOSHIBA data
  • Disease Focus: Optimized for MS patients
  • 2D Segmentation: Slice-by-slice processing (no 3D context)
  • Resolution: Fixed 256×256 input size

Model Card

Intended Use

  • Primary: Automated WMH segmentation for research and clinical decision support
  • Users: Radiologists, neurologists, researchers, AI developers
  • Out-of-scope: Not FDA/CE approved; not for standalone clinical diagnosis

Ethical Considerations

  • Privacy: All data anonymized per HIPAA/GDPR standards
  • Bias: Limited scanner/protocol diversity may affect generalization
  • Clinical Validation: Requires expert review before clinical use
  • Transparency: Complete methodology and code openly available

Model Card Authors

Mahdi Bashiri Bawil, Mousa Shamsi, Ali Fahmi Jafargholkhanloo, Abolhassan Shakeri Bavil

Citation

@article{bawil2025wmh,
  title={Incorporating Normal Periventricular Changes for Enhanced Pathological 
         White Matter Hyperintensity Segmentation: On Multi-Class Deep Learning Approaches},
  author={Bawil, Mahdi Bashiri and Shamsi, Mousa and Jafargholkhanloo, Ali Fahmi and 
          Bavil, Abolhassan Shakeri},
  year={2025},
  note={Models: https://huggingface.co/Bawil/wmh_leverage_normal_abnormal_segmentation}
}

License

MIT License - See LICENSE

Acknowledgments

  • Golgasht Medical Imaging Center, Tabriz, Iran for providing clinical data
  • Expert neuroradiologists for manual annotations
  • Ethics Committee approval: IR.TBZMED.REC.1402.902

Keywords: white matter hyperintensities, FLAIR MRI, medical imaging, deep learning, image segmentation, multiple sclerosis, U-Net, attention mechanisms, transformers, clinical AI
