# Diffusion-Based ECG Augmentation Model
A disentangled diffusion model (DiffStyleTS) for generating synthetic ECG signals via style transfer between Atrial Fibrillation (AFib) and Normal Sinus Rhythm classes. Designed for training data augmentation in AFib detection systems.
## Model Description
This model separates ECG signals into class-invariant content (beat morphology) and class-specific style (rhythm characteristics), then generates synthetic ECGs by transferring the style of one class onto the content of another.
## Architecture
| Component | Parameters | Description |
|---|---|---|
| Content Encoder (VAE) | 4.3M | Extracts class-invariant temporal patterns using BatchNorm |
| Style Encoder (CNN) | 567K | Captures class-discriminative features using InstanceNorm |
| Conditional UNet | 14.3M | Denoises with FiLM conditioning on content + style |
| **Total** | **19.1M** | |
## Training
- Dataset: MIMIC-IV ECG (~140,000 segments, Lead II, 250 Hz, 10-second windows)
- Stage 1 (Epochs 1-50): Reconstruction (MSE + Cross-Entropy + KL divergence)
- Stage 2 (Epochs 51-100): Style transfer (MSE + Flip loss + Similarity loss)
- Hardware: NVIDIA RTX 6000 Ada (48 GB VRAM)
- Training time: ~22 hours
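The Stage 1 objective (MSE reconstruction + cross-entropy + KL divergence) can be sketched as below. The loss weighting (`beta`), latent dimension, and tensor shapes are illustrative placeholders, not values from the actual training configuration:

```python
import torch
import torch.nn.functional as F

def stage1_loss(x_hat, x, logits, labels, mu, logvar, beta=1e-3):
    """Stage 1 objective: reconstruction + classification + KL.

    beta is an illustrative KL weight, not the paper's value.
    """
    recon = F.mse_loss(x_hat, x)                     # MSE reconstruction
    ce = F.cross_entropy(logits, labels)             # class prediction
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + ce + beta * kl

# Toy shapes: batch of 4 single-lead, 2500-sample ECGs, 2 classes
x = torch.randn(4, 1, 2500)
x_hat = x + 0.01 * torch.randn_like(x)               # near-perfect reconstruction
logits = torch.randn(4, 2)
labels = torch.randint(0, 2, (4,))
mu, logvar = torch.zeros(4, 64), torch.zeros(4, 64)  # unit-Gaussian posterior
loss = stage1_loss(x_hat, x, logits, labels, mu, logvar)
```

Stage 2 swaps the classification and KL terms for the flip and similarity losses while keeping the same reconstruction backbone.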
## Generation
- Method: SDEdit (60% noise addition, 50 DDIM denoising steps, CFG scale 3.0)
- Filtering: Clinical plausibility validator (morphological + physiological checks, threshold 0.7)
- Output: 7,784 accepted synthetic ECGs from test set
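The SDEdit procedure can be illustrated with a minimal sketch: forward-diffuse the source ECG to 60% of the noise schedule, then run deterministic DDIM steps back to t = 0. Here `denoise_fn` is a stand-in for the conditional UNet (the real model also applies FiLM conditioning and CFG), and the linear beta schedule is generic, not the trained one:

```python
import torch

def sdedit_transfer(x_src, denoise_fn, alphas_cumprod, noise_frac=0.6, steps=50):
    """Schematic SDEdit: partial noising, then DDIM denoising.

    denoise_fn(x_t, t) must return the predicted clean signal x0.
    """
    T = len(alphas_cumprod)
    t_start = int(noise_frac * (T - 1))
    a = alphas_cumprod[t_start]
    # Forward-diffuse the source signal to timestep t_start (60% noise)
    x_t = a.sqrt() * x_src + (1 - a).sqrt() * torch.randn_like(x_src)
    # Deterministic DDIM steps back down to t = 0
    ts = torch.linspace(t_start, 0, steps).long()
    for t, t_prev in zip(ts[:-1], ts[1:]):
        a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t_prev]
        x0 = denoise_fn(x_t, t)
        eps = (x_t - a_t.sqrt() * x0) / (1 - a_t).sqrt()
        x_t = a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps
    return x_t

# Toy run with a dummy "denoiser" that always predicts a flat target signal
alphas = torch.cumprod(1 - torch.linspace(1e-4, 0.02, 1000), dim=0)
target = torch.zeros(1, 1, 2500)
out = sdedit_transfer(torch.randn(1, 1, 2500), lambda x, t: target, alphas)
```

In the full pipeline each generated signal is then passed through the plausibility validator before being accepted.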
## Key Results
### Augmentation Viability (5-Fold Cross-Validation)
| Condition | Training Data | Accuracy | F1 Score |
|---|---|---|---|
| A (Real only) | 18,681 original ECGs | 95.63 ± 0.33% | 95.65 ± 0.35% |
| B (Synthetic only) | 7,784 generated (×3) | 85.94 ± 1.32% | 86.70 ± 1.24% |
| C (Augmented) | 67% real + 33% synthetic | 95.05 ± 0.50% | 95.09 ± 0.46% |
A TOST equivalence test confirms A ≈ C (p = 0.007, equivalence margin ±2%), indicating that replacing a third of the real training data with synthetic ECGs does not degrade classifier performance.
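The TOST procedure can be sketched directly from per-fold scores. The fold accuracies below are illustrative stand-ins (only the means and standard deviations are reported above), and the implementation is a plain paired-samples TOST with the ±2% margin:

```python
import numpy as np
from scipy import stats

def tost_paired(a, b, margin):
    """Two one-sided t-tests for equivalence of paired means within ±margin."""
    d = np.asarray(a) - np.asarray(b)
    n = len(d)
    se = d.std(ddof=1) / np.sqrt(n)
    t_lower = (d.mean() + margin) / se          # H0: mean diff <= -margin
    t_upper = (d.mean() - margin) / se          # H0: mean diff >= +margin
    p_lower = 1 - stats.t.cdf(t_lower, df=n - 1)
    p_upper = stats.t.cdf(t_upper, df=n - 1)
    return max(p_lower, p_upper)                # equivalent if max p < alpha

# Illustrative per-fold accuracies (NOT the study's actual fold scores)
acc_a = [95.4, 95.9, 95.5, 95.7, 95.6]          # condition A: real only
acc_c = [94.8, 95.6, 94.9, 95.3, 94.7]          # condition C: augmented
p = tost_paired(acc_a, acc_c, margin=2.0)
```

Equivalence is concluded when the larger of the two one-sided p-values falls below the significance level.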
### Signal Quality
| Metric | Value |
|---|---|
| PSNR | 12.58 ± 2.09 dB |
| SSIM | 0.471 ± 0.110 |
| MSE | 0.005 ± 0.004 |
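These metrics can be computed per signal pair with a few lines of NumPy. The sketch below uses a simplified whole-signal SSIM (no sliding window, illustrative stabilization constants), so it will not exactly match windowed SSIM implementations:

```python
import numpy as np

def mse(x, y):
    return float(np.mean((x - y) ** 2))

def psnr(x, y):
    """PSNR in dB, relative to the reference signal's dynamic range."""
    rng = x.max() - x.min()
    return float(10 * np.log10(rng ** 2 / mse(x, y)))

def global_ssim(x, y, c1=1e-4, c2=9e-4):
    """Whole-signal SSIM (simplified: one global window, toy constants)."""
    mx, my = x.mean(), y.mean()
    vx, vy = x.var(), y.var()
    cov = ((x - mx) * (y - my)).mean()
    return float(((2 * mx * my + c1) * (2 * cov + c2)) /
                 ((mx ** 2 + my ** 2 + c1) * (vx + vy + c2)))

# Toy pair: an ECG-length reference wave and a noisy copy
t = np.linspace(0, 10, 2500)
ref = np.sin(2 * np.pi * 1.2 * t)
gen = ref + 0.1 * np.random.default_rng(0).normal(size=ref.shape)
```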
## Files
- `diffusion_model.pth` — Trained diffusion model (Stage 2, Epoch 100)
- `classifier_model.pth` — ResNet-BiLSTM AFib classifier
- `model_metadata.json` — Training configuration and final metrics
## Usage
### Option 1: Interactive Demo (Easiest)
Try the model directly in your browser — no code needed:
👉 Launch Demo
Upload an ECG (.npy or .csv, 2500 samples at 250 Hz) or browse pre-loaded examples.
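A minimal loader that enforces the expected input shape (2500 samples, single lead) might look like this; the `load_ecg` helper and filenames are illustrative, not part of the released code:

```python
import numpy as np

FS = 250          # sampling rate in Hz
N_SAMPLES = 2500  # 10-second window at 250 Hz

def load_ecg(path):
    """Load a single-lead ECG from .npy or .csv and validate its shape."""
    if str(path).endswith(".npy"):
        sig = np.load(path)
    else:
        sig = np.loadtxt(path, delimiter=",")
    sig = np.asarray(sig, dtype=np.float32).squeeze()
    if sig.ndim != 1 or sig.shape[0] != N_SAMPLES:
        raise ValueError(f"expected {N_SAMPLES} samples, got shape {sig.shape}")
    return sig

# Round-trip check with a synthetic 10-second signal
demo = np.random.default_rng(0).normal(size=N_SAMPLES).astype(np.float32)
np.save("demo_ecg.npy", demo)
sig = load_ecg("demo_ecg.npy")
```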
### Option 2: Download & Use in Python
```python
from huggingface_hub import hf_hub_download
import torch

# Download model files from Hugging Face
diffusion_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="diffusion_model.pth"
)
classifier_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="classifier_model.pth"
)

# Load the diffusion model checkpoint
checkpoint = torch.load(diffusion_path, map_location="cpu")
# The checkpoint contains:
# - checkpoint['content_encoder'] → Content Encoder state dict
# - checkpoint['style_encoder']   → Style Encoder state dict
# - checkpoint['unet']            → UNet state dict
# - checkpoint['config']          → Training config with all hyperparameters

# Load the classifier checkpoint
cls_checkpoint = torch.load(classifier_path, map_location="cpu")
# - cls_checkpoint['model_state_dict'] → AFibResLSTM state dict

# To use the full pipeline, clone the repository:
#   git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
# See diffusion_pipeline/final_pipeline/ for the model architectures
```
### Option 3: Clone the Full Pipeline
```bash
# Clone the full codebase with all model architectures
git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
cd PERA_AF_Detection/diffusion_pipeline/final_pipeline/

# Download weights
pip install huggingface_hub
python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'diffusion_model.pth', local_dir='.')
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'classifier_model.pth', local_dir='.')
"
```
**Note:** The model architectures (DiffStyleTS, AFibResLSTM) are defined in the repository code. You need the architecture classes to instantiate the models before loading the state dicts.
## Citation
```bibtex
@misc{pera_af_detection_2025,
  title={Diffusion-Based Data Augmentation for Atrial Fibrillation Detection},
  author={Dilshan, D.M.T. and Karunarathne, K.N.P.},
  year={2025},
  institution={University of Peradeniya, Sri Lanka},
  collaboration={SimulaMet, Oslo, Norway}
}
```