⚠️ Codebase Update: Input Flexibility & Fine-Tuning Preparation

This repository has been updated to support dynamic input lengths. Positional embeddings are now computed on the fly, so the model is no longer restricted to the fixed audio duration required by previous versions. The core pre-trained weights and architectural logic are unchanged. What is new is infrastructure for downstream fine-tuning, including dynamic classification heads and normalization layers. Additional fine-tuning configurations and hyperparameters will be documented in future updates.
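To illustrate what on-the-fly positional embeddings can look like, here is a minimal NumPy sketch that builds fixed 2D sine-cosine embeddings for whatever patch grid an input produces, instead of reading them from a table sized for one duration. It assumes MAE-style sin-cos embeddings, 16×16 patches, 128 mel bins, a 768-dimensional base model, and time-major flattening. The repository's remote code may compute them differently, and `sincos_1d` / `sincos_2d` are hypothetical names.

```python
import numpy as np


def sincos_1d(dim: int, positions: np.ndarray) -> np.ndarray:
    """Sine-cosine embedding of integer positions -> array of shape [N, dim]."""
    assert dim % 2 == 0, "dim must be even"
    # Geometric frequency ladder, as in Transformer / MAE sin-cos embeddings
    omega = 1.0 / 10000 ** (np.arange(dim // 2, dtype=np.float64) / (dim / 2.0))
    angles = np.outer(positions.astype(np.float64), omega)  # [N, dim/2]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


def sincos_2d(dim: int, n_time: int, n_freq: int) -> np.ndarray:
    """Embeddings for an n_time x n_freq patch grid, flattened time-major -> [n_time*n_freq, dim]."""
    assert dim % 4 == 0, "dim must be divisible by 4"
    t, f = np.meshgrid(np.arange(n_time), np.arange(n_freq), indexing="ij")
    emb_t = sincos_1d(dim // 2, t.reshape(-1))  # half of the channels encode the time index
    emb_f = sincos_1d(dim // 2, f.reshape(-1))  # the other half encode the frequency index
    return np.concatenate([emb_t, emb_f], axis=1)


PATCH = 16  # assumed patch size in both time and frequency
for n_frames in (1024, 498):  # roughly 10.24 s and 5 s of 10 ms filterbank frames
    # Assumption: leftover frames that do not fill a whole patch are dropped
    grid_t, grid_f = n_frames // PATCH, 128 // PATCH
    pos = sincos_2d(768, grid_t, grid_f)
    print(n_frames, pos.shape)
```

For the two example lengths this yields tables of shape (512, 768) and (248, 768). The point of the design is that nothing is stored per position, so any grid size can be embedded at inference time.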

EAT-base (Epoch 30, Pre-trained Checkpoint)

This is the pre-trained EAT-base model at epoch 30, trained on AS-2M (the full AudioSet, about 2M clips) with the EAT framework for audio self-supervised learning. It can be used directly as an efficient feature extractor, or as a strong initialization for fine-tuning on downstream audio understanding tasks such as classification and captioning.
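When fine-tuning on clips of different lengths, frame-level features must be padded to a common length within a batch. A common pattern, not necessarily what this repository's classification heads do, is masked mean pooling followed by a linear layer. The NumPy sketch below shows the idea. `pad_batch`, `masked_mean_pool`, the random stand-in features, and the random head weights are all hypothetical. 527 is the AudioSet label count, used only as an example.

```python
import numpy as np


def pad_batch(features):
    """Stack [T_i, D] arrays into [B, T_max, D] plus a boolean mask of real frames."""
    t_max = max(f.shape[0] for f in features)
    dim = features[0].shape[1]
    batch = np.zeros((len(features), t_max, dim), dtype=np.float32)
    mask = np.zeros((len(features), t_max), dtype=bool)
    for i, f in enumerate(features):
        batch[i, : f.shape[0]] = f
        mask[i, : f.shape[0]] = True
    return batch, mask


def masked_mean_pool(batch, mask):
    """Average only the real (unpadded) frames of each clip -> [B, D]."""
    weights = mask[..., None].astype(batch.dtype)  # [B, T_max, 1]
    return (batch * weights).sum(axis=1) / np.maximum(weights.sum(axis=1), 1.0)


rng = np.random.default_rng(0)
# Stand-ins for frame-level features of three clips with different lengths
clips = [rng.standard_normal((n, 768)).astype(np.float32) for n in (512, 248, 100)]
batch, mask = pad_batch(clips)
pooled = masked_mean_pool(batch, mask)  # [3, 768]

num_classes = 527  # example: AudioSet label count
W = rng.standard_normal((768, num_classes)).astype(np.float32) * 0.01  # hypothetical head weights
b = np.zeros(num_classes, dtype=np.float32)
logits = pooled @ W + b  # [3, 527]
print(batch.shape, pooled.shape, logits.shape)
```

Masking keeps the zero padding from diluting short clips, so a clip's pooled vector does not depend on what else happens to be in the batch.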

For more details on the EAT framework, please refer to the GitHub repository and our paper EAT: Self-Supervised Pre-Training with Efficient Audio Transformer.

🔧 Usage

You can load the model and use it for feature extraction directly via Hugging Face Transformers (trust_remote_code=True is required because the model class is defined in this repository):

import torchaudio
import torch
import soundfile as sf
import numpy as np
from transformers import AutoModel

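model_id = "HTill/flexEAT-base_epoch30_pretrain"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().to(device)

source_file = "/path/to/input.wav"
target_file = "/path/to/output.npy"
norm_mean = -4.268
norm_std = 4.569

# Load audio, downmix multi-channel files to mono, and resample to 16 kHz
wav, sr = sf.read(source_file, dtype="float32")
if wav.ndim > 1:
    wav = wav.mean(axis=1)  # soundfile returns [samples, channels] for multi-channel audio
waveform = torch.tensor(wav, dtype=torch.float32, device=device)
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, sr, 16000)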

# Remove the DC offset, then compute 128-bin Kaldi-compatible log-mel filterbank features
waveform = waveform - waveform.mean()
mel = torchaudio.compliance.kaldi.fbank(
    waveform.unsqueeze(0),
    htk_compat=True,
    sample_frequency=16000,
    use_energy=False,
    window_type='hanning',
    num_mel_bins=128,
    dither=0.0,
    frame_shift=10
).unsqueeze(0)

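# Normalize with the fixed statistics defined above (std is doubled, following the AST/EAT convention)
mel = (mel - norm_mean) / (norm_std * 2)
mel = mel.unsqueeze(0)  # shape: [1, 1, T, F]; already on the same device as waveform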

# Extract features
with torch.no_grad():
    feat = model.extract_features(mel)

feat = feat.squeeze(0).cpu().numpy()
np.save(target_file, feat)
print(f"Feature shape: {feat.shape}")
print(f"Saved to: {target_file}")
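Because inputs are no longer cropped or padded to a fixed duration, the number of spectrogram frames T fed to the model follows directly from the clip length. The helper below is a small sketch of that arithmetic, assuming torchaudio's Kaldi defaults (25 ms window, snip_edges=True) together with the frame_shift=10 used above. `num_fbank_frames` is a hypothetical name.

```python
def num_fbank_frames(num_samples: int, sample_rate: int = 16000,
                     frame_length_ms: float = 25.0, frame_shift_ms: float = 10.0) -> int:
    """Frame count produced by a Kaldi-style fbank with snip_edges=True."""
    window = int(sample_rate * frame_length_ms * 0.001)  # 400 samples at 16 kHz
    shift = int(sample_rate * frame_shift_ms * 0.001)    # 160 samples at 16 kHz
    if num_samples < window:
        return 0  # clip shorter than one analysis window
    return 1 + (num_samples - window) // shift


for seconds in (1.0, 5.0, 10.0):
    n = int(round(seconds * 16000))
    print(f"{seconds:>4} s -> {num_fbank_frames(n)} frames of 128 mel bins")
```

In other words, mel has shape [1, 1, T, 128] with T close to 100 frames per second of 16 kHz audio (98, 498 and 998 frames for the three durations above).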

📌 Notes

The model supports both frame-level (~50Hz) and utterance-level (CLS token) representations. See the feature extraction guide for more instructions.
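The following is a minimal sketch of how a saved feature matrix might be split into the two kinds of representation. It assumes that feat has shape [1 + N, D] with the CLS token at index 0 followed by the N patch tokens. Check the remote code before relying on this layout. `split_features` and `token_rate` are hypothetical helpers, and random data stands in for the saved .npy file. With 512 patch tokens for a 10.24 s clip, the token rate works out to about 50 per second, consistent with the ~50Hz figure above.

```python
import numpy as np


def split_features(feat: np.ndarray):
    """Split a [1 + N, D] feature matrix into (utterance [D], frames [N, D]).

    Assumes the CLS token is stored first, followed by the N patch tokens.
    """
    if feat.ndim != 2 or feat.shape[0] < 2:
        raise ValueError(f"expected [1 + N, D] features, got {feat.shape}")
    return feat[0], feat[1:]


def token_rate(num_tokens: int, duration_s: float) -> float:
    """Average number of frame-level tokens per second of audio."""
    return num_tokens / duration_s


# Stand-in for np.load("/path/to/output.npy"): CLS + 512 patch tokens, 768-dim
feat = np.random.default_rng(0).standard_normal((513, 768)).astype(np.float32)
utt, frames = split_features(feat)
mean_pooled = frames.mean(axis=0)  # an alternative utterance-level summary, not the CLS token
print(utt.shape, frames.shape, mean_pooled.shape)
print(token_rate(frames.shape[0], 10.24))  # about 50 tokens per second
```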

📚 Citation

If you find this model useful, please consider citing our paper:

@article{chen2024eat,
  title={EAT: Self-supervised pre-training with efficient audio transformer},
  author={Chen, Wenxi and Liang, Yuzhe and Ma, Ziyang and Zheng, Zhisheng and Chen, Xie},
  journal={arXiv preprint arXiv:2401.03497},
  year={2024}
}
Model size: 90M params (Safetensors, F32)