Diffusion-Based ECG Augmentation Model

A disentangled diffusion model (DiffStyleTS) for generating synthetic ECG signals via style transfer between Atrial Fibrillation (AFib) and Normal Sinus Rhythm classes. Designed for training data augmentation in AFib detection systems.

Model Description

This model separates ECG signals into class-invariant content (beat morphology) and class-specific style (rhythm characteristics), then generates synthetic ECGs by transferring the style of one class onto the content of another.

Architecture

| Component | Parameters | Description |
|---|---|---|
| Content Encoder (VAE) | 4.3M | Extracts class-invariant temporal patterns using BatchNorm |
| Style Encoder (CNN) | 567K | Captures class-discriminative features using InstanceNorm |
| Conditional UNet | 14.3M | Denoises with FiLM conditioning on content + style |
| Total | 19.1M | |
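FiLM (feature-wise linear modulation) conditions the UNet by predicting a per-channel scale and shift from a conditioning vector. The sketch below assumes that the content and style embeddings are concatenated into one vector and projected by plain linear maps (`w_gamma` and `w_beta`, both hypothetical). The actual projection layers, and where FiLM is applied inside the UNet, are defined in the repository code.

```python
import numpy as np

def film(features: np.ndarray, content_emb: np.ndarray, style_emb: np.ndarray,
         w_gamma: np.ndarray, w_beta: np.ndarray) -> np.ndarray:
    """Apply FiLM: out[b, c, t] = gamma[b, c] * features[b, c, t] + beta[b, c].

    features:        (batch, channels, time) UNet activations
    content_emb:     (batch, d_content)
    style_emb:       (batch, d_style)
    w_gamma, w_beta: (d_content + d_style, channels), hypothetical projection weights
    """
    # Assumption: content and style conditioning are concatenated before projection
    cond = np.concatenate([content_emb, style_emb], axis=1)
    gamma = cond @ w_gamma  # (batch, channels) per-channel scale
    beta = cond @ w_beta    # (batch, channels) per-channel shift
    # Broadcast the per-channel parameters across the time axis
    return gamma[:, :, None] * features + beta[:, :, None]
```

In this sketch, gamma and beta depend on the style embedding. Substituting the style vector of the other class therefore changes how every channel is modulated while the content input stays fixed.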

Training

  • Dataset: MIMIC-IV ECG (~140,000 segments, Lead II, 250 Hz, 10-second windows)
  • Stage 1 (Epochs 1-50): Reconstruction (MSE + Cross-Entropy + KL divergence)
  • Stage 2 (Epochs 51-100): Style transfer (MSE + Flip loss + Similarity loss)
  • Hardware: NVIDIA RTX 6000 Ada (48 GB VRAM)
  • Training time: ~22 hours
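The two-stage schedule can be written as a loss that switches terms at epoch 51. The sketch below assumes unit weights by default and takes the MSE, cross-entropy, flip, and similarity terms as precomputed scalars, because their exact definitions and weightings are in the repository. Only the KL term is written out, using the standard closed form for a diagonal-Gaussian VAE posterior against a unit-Gaussian prior.

```python
from typing import Dict, Optional

import numpy as np

STAGE1_LAST_EPOCH = 50  # Epochs 1-50: reconstruction; 51-100: style transfer

def vae_kl(mu: np.ndarray, logvar: np.ndarray) -> float:
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims and averaged over the batch."""
    kl_per_sample = -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=1)
    return float(np.mean(kl_per_sample))

def stage_loss(epoch: int, terms: Dict[str, float],
               weights: Optional[Dict[str, float]] = None) -> float:
    """Sum the loss terms active in the current stage (weights are hypothetical)."""
    if epoch <= STAGE1_LAST_EPOCH:
        active = ("mse", "cross_entropy", "kl")
    else:
        active = ("mse", "flip", "similarity")
    weights = weights or {}
    # Missing weights default to 1.0
    return sum(weights.get(name, 1.0) * terms[name] for name in active)
```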

Generation

  • Method: SDEdit (60% noise addition, 50 DDIM denoising steps, CFG scale 3.0)
  • Filtering: Clinical plausibility validator (morphological + physiological checks, threshold 0.7)
  • Output: 7,784 accepted synthetic ECGs from the test set
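SDEdit starts denoising from a partially noised copy of a real ECG rather than from pure noise. As a result, the output keeps the source's coarse structure while the conditioning steers the rest. The sketch below makes several assumptions:

- a linear beta schedule with 1,000 training steps;
- "60% noise addition" means starting at timestep 0.6·T;
- the 50 DDIM steps (eta = 0) are spread over that truncated range;
- a hypothetical `eps_model(x, t, cond)` predicts noise, and `cond=None` selects the unconditional branch used by classifier-free guidance.

The repository's scheduler and conditioning interface may differ.

```python
import numpy as np

def make_schedule(num_train_steps: int = 1000, beta_start: float = 1e-4,
                  beta_end: float = 0.02) -> np.ndarray:
    """Cumulative products alpha_bar_t for a linear beta schedule (assumed, not from the repo)."""
    betas = np.linspace(beta_start, beta_end, num_train_steps)
    return np.cumprod(1.0 - betas)

def sdedit_ddim(x_source, eps_model, cond, alphas_bar, strength=0.6,
                num_steps=50, guidance_scale=3.0, rng=None):
    """SDEdit: noise a real ECG to an intermediate timestep, then denoise with DDIM + CFG."""
    rng = np.random.default_rng() if rng is None else rng
    t_start = int(strength * len(alphas_bar)) - 1
    # Forward-diffuse the source directly to t_start: x_t = sqrt(ab) x_0 + sqrt(1 - ab) noise
    ab = alphas_bar[t_start]
    x = np.sqrt(ab) * x_source + np.sqrt(1.0 - ab) * rng.standard_normal(x_source.shape)
    # num_steps deterministic DDIM steps (eta = 0) from t_start down to 0
    timesteps = np.linspace(t_start, 0, num_steps).round().astype(int)
    for i, t in enumerate(timesteps):
        # Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction
        eps_uncond = eps_model(x, t, None)
        eps_cond = eps_model(x, t, cond)
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
        ab_t = alphas_bar[t]
        x0_pred = (x - np.sqrt(1.0 - ab_t) * eps) / np.sqrt(ab_t)
        # alpha_bar of the next (less noisy) timestep; 1.0 after the final step
        ab_prev = alphas_bar[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        x = np.sqrt(ab_prev) * x0_pred + np.sqrt(1.0 - ab_prev) * eps
    return x
```

Lower strength values keep more of the source ECG. A guidance scale above 1 pushes the noise prediction past the conditional estimate, which strengthens the influence of the target-class conditioning.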

Key Results

Augmentation Viability (5-Fold Cross-Validation)

| Condition | Training Data | Accuracy | F1 Score |
|---|---|---|---|
| A (Real only) | 18,681 original ECGs | 95.63 ± 0.33% | 95.65 ± 0.35% |
| B (Synthetic only) | 7,784 generated (×3) | 85.94 ± 1.32% | 86.70 ± 1.24% |
| C (Augmented) | 67% real + 33% synthetic | 95.05 ± 0.50% | 95.09 ± 0.46% |

A TOST equivalence test indicates that A and C are equivalent (p = 0.007, margin ±2%): replacing 33% of the real training data with synthetic ECGs did not degrade classifier performance beyond that margin.
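TOST (two one-sided tests) declares equivalence when the mean difference is shown to lie both above −margin and below +margin. The sketch below makes three assumptions: a paired design over the five cross-validation folds, a margin expressed in the same units as the inputs (percentage points here), and α = 0.05. To stay dependency-free, it compares t statistics against a hard-coded critical value, the one-sided 95% quantile of Student's t with 4 degrees of freedom (about 2.132). The authors' exact pairing and test implementation are not stated here.

```python
import math
import statistics
from typing import Sequence

T_CRIT_DF4 = 2.132  # one-sided 95% quantile of Student's t with 4 df (5 folds)

def tost_paired(real: Sequence[float], augmented: Sequence[float],
                margin: float = 2.0, t_crit: float = T_CRIT_DF4) -> bool:
    """Return True if mean(augmented - real) is shown to lie within (-margin, +margin)."""
    if len(real) != len(augmented):
        raise ValueError("paired inputs must have the same length")
    diffs = [c - a for a, c in zip(real, augmented)]
    mean_d = statistics.mean(diffs)
    se = statistics.stdev(diffs) / math.sqrt(len(diffs))
    if se == 0:  # identical differences: no sampling variability to test against
        return abs(mean_d) < margin
    t_lower = (mean_d + margin) / se  # tests H0: mean_d <= -margin
    t_upper = (mean_d - margin) / se  # tests H0: mean_d >= +margin
    # Equivalence requires rejecting both one-sided null hypotheses
    return t_lower > t_crit and t_upper < -t_crit
```

The function returns a yes/no decision at α = 0.05. Reporting a p-value instead requires evaluating the t distribution's CDF, for example with SciPy.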

Signal Quality

| Metric | Value |
|---|---|
| PSNR | 12.58 ± 2.09 dB |
| SSIM | 0.471 ± 0.110 |
| MSE | 0.005 ± 0.004 |
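PSNR and SSIM are image metrics applied here to 1-D signals, so their values depend on choices the table does not state, such as the peak value used in PSNR and the windowing used in SSIM. The sketch below assumes that PSNR uses the reference signal's peak-to-peak range as the peak value. It computes a single-window (global) SSIM with the usual constants K1 = 0.01 and K2 = 0.03. Standard SSIM averages over sliding local windows, so numbers from this sketch are not directly comparable to the reported ones.

```python
import numpy as np

def mse(ref: np.ndarray, gen: np.ndarray) -> float:
    """Mean squared error between a reference and a generated signal."""
    return float(np.mean((ref - gen) ** 2))

def psnr(ref: np.ndarray, gen: np.ndarray) -> float:
    """PSNR in dB, using the reference's peak-to-peak range as the peak value (assumption)."""
    data_range = float(ref.max() - ref.min())
    return float(10.0 * np.log10(data_range**2 / mse(ref, gen)))

def global_ssim(ref: np.ndarray, gen: np.ndarray, k1: float = 0.01, k2: float = 0.03) -> float:
    """SSIM computed over the whole signal as one window (a simplification of windowed SSIM)."""
    data_range = float(ref.max() - ref.min())
    c1, c2 = (k1 * data_range) ** 2, (k2 * data_range) ** 2
    mu_x, mu_y = ref.mean(), gen.mean()
    var_x, var_y = ref.var(), gen.var()
    cov = np.mean((ref - mu_x) * (gen - mu_y))
    num = (2 * mu_x * mu_y + c1) * (2 * cov + c2)
    den = (mu_x**2 + mu_y**2 + c1) * (var_x + var_y + c2)
    return float(num / den)
```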

Files

  • diffusion_model.pth — Trained diffusion model (Stage 2, Epoch 100)
  • classifier_model.pth — ResNet-BiLSTM AFib classifier
  • model_metadata.json — Training configuration and final metrics

Usage

Option 1: Interactive Demo (Easiest)

Try the model directly in your browser — no code needed:

👉 Launch Demo

Upload an ECG (.npy or .csv, 2500 samples at 250 Hz) or browse pre-loaded examples.
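The demo expects a single-lead ECG of 2,500 samples (10 s at 250 Hz). Below is a minimal sketch for converting a recording made at another sampling rate. It assumes a 1-D input array and uses linear interpolation as a simple stand-in for proper anti-aliased resampling. The file names are placeholders, and the exact layout the demo parses (for example, one value per CSV row) is an assumption.

```python
import numpy as np

TARGET_FS = 250    # Hz, as stated for the demo input
TARGET_LEN = 2500  # 10 s at 250 Hz

def to_demo_input(signal: np.ndarray, fs: float) -> np.ndarray:
    """Resample a 1-D ECG to 250 Hz by linear interpolation and crop or pad to 2,500 samples."""
    signal = np.asarray(signal, dtype=np.float32).ravel()
    t_src = np.arange(len(signal)) / fs
    n_dst = int(round(len(signal) / fs * TARGET_FS))
    t_dst = np.arange(n_dst) / TARGET_FS
    resampled = np.interp(t_dst, t_src, signal)
    # Keep the first 10 s, or zero-pad shorter recordings (padding choice is an assumption)
    out = np.zeros(TARGET_LEN, dtype=np.float32)
    n = min(TARGET_LEN, len(resampled))
    out[:n] = resampled[:n]
    return out

# Hypothetical usage with a 500 Hz recording `raw_ecg`:
# ecg = to_demo_input(raw_ecg, fs=500)
# np.save("ecg_demo.npy", ecg)
# np.savetxt("ecg_demo.csv", ecg, delimiter=",")
```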

Option 2: Download & Use in Python

from huggingface_hub import hf_hub_download
import torch

# Download model files from Hugging Face
diffusion_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="diffusion_model.pth"
)
classifier_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="classifier_model.pth"
)

# Load the diffusion model checkpoint
checkpoint = torch.load(diffusion_path, map_location="cpu")

# The checkpoint contains:
# - checkpoint['content_encoder'] → Content Encoder state dict
# - checkpoint['style_encoder']   → Style Encoder state dict
# - checkpoint['unet']            → UNet state dict
# - checkpoint['config']          → Training config with all hyperparameters

# Load the classifier checkpoint
cls_checkpoint = torch.load(classifier_path, map_location="cpu")
# - cls_checkpoint['model_state_dict'] → AFibResLSTM state dict

# To use the full pipeline, clone the repository:
# git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
# See: diffusion_pipeline/final_pipeline/ for model architectures
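To check that a downloaded checkpoint matches the architecture table, you can count the elements in each component's state dict. This sketch assumes that each of the three entries listed in the comments above is a plain mapping from names to tensors. The counts include buffers such as BatchNorm running statistics, so they can be slightly larger than the trainable-parameter counts in the table.

```python
import math
from collections.abc import Mapping

def count_state_dict_params(state_dict: Mapping) -> int:
    """Total element count of all tensor-like values (anything with a .shape)."""
    # math.prod(()) == 1, so scalar buffers such as num_batches_tracked count as one element
    return sum(math.prod(v.shape) for v in state_dict.values() if hasattr(v, "shape"))

def summarize_checkpoint(checkpoint: Mapping,
                         keys=("content_encoder", "style_encoder", "unet")) -> dict:
    """Element counts per component for the keys present in the checkpoint."""
    return {k: count_state_dict_params(checkpoint[k]) for k in keys if k in checkpoint}
```

After running the loading code above, `summarize_checkpoint(checkpoint)` returns one count per component, which can be compared against the 4.3M / 567K / 14.3M figures.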

Option 3: Clone the Full Pipeline

# Clone the full codebase with all model architectures
git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
cd PERA_AF_Detection/diffusion_pipeline/final_pipeline/

# Download weights
pip install huggingface_hub
python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'diffusion_model.pth', local_dir='.')
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'classifier_model.pth', local_dir='.')
"

Note: The model architectures (DiffStyleTS, AFibResLSTM) are defined in the repository code. You need the architecture classes to instantiate the models before loading the state dicts.
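Once an architecture class has been constructed with the hyperparameters from `checkpoint['config']`, loading follows the standard PyTorch pattern. The helper below is a generic sketch. The commented class name and constructor call are placeholders, since the actual signatures are defined in the repository.

```python
import torch
from torch import nn

def load_pretrained(module: nn.Module, state_dict: dict, device: str = "cpu") -> nn.Module:
    """Copy weights into an already-constructed module and switch it to inference mode."""
    # strict=True turns missing or unexpected keys into errors;
    # tensor shape mismatches raise regardless of this flag
    module.load_state_dict(state_dict, strict=True)
    return module.to(device).eval()

# Hypothetical usage (constructor arguments come from the repository code):
# classifier = load_pretrained(AFibResLSTM(...), cls_checkpoint["model_state_dict"])
# with torch.no_grad():
#     logits = classifier(ecg_batch)
```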

Citation

@misc{pera_af_detection_2025,
  title={Diffusion-Based Data Augmentation for Atrial Fibrillation Detection},
  author={Dilshan, D.M.T. and Karunarathne, K.N.P.},
  year={2025},
  institution={University of Peradeniya, Sri Lanka},
  collaboration={SimulaMet, Oslo, Norway}
}
