"""
Hugging Face Hub-compatible wrapper for CountEM music transcription models.
"""
from pathlib import Path
from typing import Union, Tuple
import numpy as np
import torch
import soundfile as sf
from huggingface_hub import PyTorchModelHubMixin
from onsets_and_frames.transcriber import OnsetsAndFrames
from onsets_and_frames.mel import MelSpectrogram
from onsets_and_frames.midi_utils import frames2midi
from onsets_and_frames.constants import (
N_MELS,
MIN_MIDI,
MAX_MIDI,
HOP_LENGTH,
SAMPLE_RATE,
WINDOW_LENGTH,
MEL_FMIN,
MEL_FMAX,
)
class CountEMModel(
OnsetsAndFrames,
PyTorchModelHubMixin,
    # Optional metadata that gets pushed to the model card
library_name="countem",
tags=["audio", "music-transcription", "automatic-music-transcription", "midi"],
license="cc-by-4.0",
repo_url="https://github.com/Yoni-Yaffe/count-the-notes",
paper_url="https://arxiv.org/abs/2511.14250",
):
"""
Hugging Face Hub-compatible wrapper for CountEM automatic music transcription models.
This model performs automatic music transcription (AMT) from audio to MIDI.
It uses the Onsets & Frames architecture trained with the CountEM framework,
which enables training with weak, unordered note count histograms.
Example usage:
```python
from onsets_and_frames.hf_model import CountEMModel
import soundfile as sf
# Load model from Hub
model = CountEMModel.from_pretrained("Yoni-Yaffe/countem-musicnet")
# Load audio (must be 16kHz)
audio, sr = sf.read("audio.flac")
assert sr == 16000, "Audio must be 16kHz"
# Transcribe to MIDI
model.transcribe_to_midi(audio, "output.mid")
```
Args:
model_complexity: Complexity multiplier for the model (default: 64)
onset_complexity: Complexity multiplier for onset stack (default: 1.5)
n_instruments: Number of instruments to transcribe (default: 1)
"""
def __init__(
self,
model_complexity: int = 64,
onset_complexity: float = 1.5,
n_instruments: int = 1,
**kwargs
):
# Initialize the base OnsetsAndFrames model
n_keys = MAX_MIDI - MIN_MIDI + 1
OnsetsAndFrames.__init__(
self,
input_features=N_MELS,
output_features=n_keys,
model_complexity=model_complexity,
onset_complexity=onset_complexity,
n_instruments=n_instruments,
)
# Store config for HF Hub
self.config = {
"model_complexity": model_complexity,
"onset_complexity": onset_complexity,
"n_instruments": n_instruments,
"n_mels": N_MELS,
"n_keys": n_keys,
"sample_rate": SAMPLE_RATE,
"hop_length": HOP_LENGTH,
}
# Add mel spectrogram as a submodule for proper device management
# This ensures the mel transform moves with the model when calling .to(device)
self.melspectrogram = MelSpectrogram(
n_mels=N_MELS,
sample_rate=SAMPLE_RATE,
filter_length=WINDOW_LENGTH,
hop_length=HOP_LENGTH,
mel_fmin=MEL_FMIN,
mel_fmax=MEL_FMAX,
)
def forward(self, audio: Union[np.ndarray, torch.Tensor]):
"""
Forward pass that accepts raw audio waveforms.
Unlike the parent OnsetsAndFrames which expects mel spectrograms,
this forward method accepts raw audio and converts it internally.
Args:
audio: Raw audio waveform, shape (batch, n_samples) or (n_samples,)
                Should already be in [-1, 1]; int16 or out-of-range input is rescaled automatically
Returns:
Tuple of (onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred)
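        Example (illustrative sketch; assumes `model` is a CountEMModel loaded via `from_pretrained`):
        ```python
        import numpy as np
        # ten seconds of silence at 16 kHz, just to exercise the forward pass
        audio = np.zeros(16000 * 10, dtype=np.float32)
        onset, offset, activation, frame, velocity = model(audio)
        ```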
"""
# Convert to torch tensor if needed
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio).float()
# Ensure audio is in range [-1, 1]
if audio.dtype == torch.int16:
audio = audio.float() / 32768.0
elif audio.max() > 1.0 or audio.min() < -1.0:
audio = audio / max(abs(audio.max()), abs(audio.min()))
# Add batch dimension if needed
if audio.dim() == 1:
audio = audio.unsqueeze(0)
device = next(self.parameters()).device
audio = audio.to(device)
# Remove last sample to fix frame count mismatch
audio = audio[:, :-1]
mel = self.melspectrogram(audio)
# Transpose to (batch, time, features) format expected by parent model
mel = mel.transpose(-1, -2)
return super().forward(mel)
@torch.no_grad()
def transcribe(
self,
audio: Union[np.ndarray, torch.Tensor],
onset_threshold: float = 0.5,
frame_threshold: float = 0.5,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Transcribe audio to note predictions.
Automatically handles long audio by splitting into segments (max 5 minutes each)
to avoid memory issues.
Args:
audio: Audio waveform, shape (n_samples,), normalized to [-1, 1]
onset_threshold: Threshold for onset detection (default: 0.5)
frame_threshold: Threshold for frame detection (default: 0.5)
Returns:
Tuple of (onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred)
All are numpy arrays of shape (n_frames, 88) except velocity which may vary
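        Example (illustrative; assumes `model` is a loaded CountEMModel and "audio.flac" is a 16kHz recording):
        ```python
        import soundfile as sf
        audio, sr = sf.read("audio.flac", dtype="float32")
        onset, offset, activation, frame, velocity = model.transcribe(audio)
        print(onset.shape)  # (n_frames, 88)
        ```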
"""
self.eval()
# Convert to torch tensor if needed
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio).float()
# Ensure audio is 1D (convert stereo to mono if needed)
if audio.dim() > 1:
# If stereo or multi-channel, take mean across channels
            audio = audio.mean(dim=-1 if audio.shape[-1] <= 2 else 0)
# Normalize audio
if audio.dtype == torch.int16:
audio = audio.float() / 32768.0
elif audio.max() > 1.0 or audio.min() < -1.0:
audio = audio / max(abs(audio.max()), abs(audio.min()))
device = next(self.parameters()).device
audio = audio.to(device)
# Handle long audio by segmenting
MAX_TIME = 5 * 60 * SAMPLE_RATE # 5 minutes
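        # Each segment is transcribed independently and the frame-level predictions are stitched
        # back together below, which keeps peak memory bounded for long recordings.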
audio_len = len(audio)
if audio_len > MAX_TIME:
# Split into segments
n_segments = int(np.ceil(audio_len / MAX_TIME))
seg_len = MAX_TIME
onset_preds = []
offset_preds = []
activation_preds = []
frame_preds = []
velocity_preds = []
for i_s in range(n_segments):
start = i_s * seg_len
end = min((i_s + 1) * seg_len, audio_len)
segment = audio[start:end]
# Forward pass on segment
onset_seg, offset_seg, activation_seg, frame_seg, velocity_seg = self(segment)
onset_preds.append(onset_seg)
offset_preds.append(offset_seg)
activation_preds.append(activation_seg)
frame_preds.append(frame_seg)
velocity_preds.append(velocity_seg)
# Concatenate along time dimension (dim=1)
onset_pred = torch.cat(onset_preds, dim=1)
offset_pred = torch.cat(offset_preds, dim=1)
activation_pred = torch.cat(activation_preds, dim=1)
frame_pred = torch.cat(frame_preds, dim=1)
velocity_pred = torch.cat(velocity_preds, dim=1)
else:
# Short audio, process directly
onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred = self(audio)
# Convert to numpy and remove batch dimension
onset_pred = onset_pred.squeeze(0).cpu().numpy()
offset_pred = offset_pred.squeeze(0).cpu().numpy()
activation_pred = activation_pred.squeeze(0).cpu().numpy()
frame_pred = frame_pred.squeeze(0).cpu().numpy()
velocity_pred = velocity_pred.squeeze(0).cpu().numpy()
return onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred
def transcribe_to_midi(
self,
audio: Union[np.ndarray, torch.Tensor, str, Path],
output_path: Union[str, Path],
onset_threshold: float = 0.5,
frame_threshold: float = 0.5,
) -> None:
"""
Transcribe audio to MIDI file.
Args:
audio: Audio waveform, numpy array, torch tensor, or path to audio file
output_path: Path to save MIDI file
onset_threshold: Threshold for onset detection (default: 0.5)
frame_threshold: Threshold for frame detection (default: 0.5)
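        Example (illustrative; the file paths are placeholders and the input must already be at SAMPLE_RATE):
        ```python
        model.transcribe_to_midi("audio.flac", "output.mid")
        ```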
"""
# Load audio from file if path is provided
if isinstance(audio, (str, Path)):
audio, sr = sf.read(audio, dtype="float32")
if sr != SAMPLE_RATE:
raise ValueError(
f"Audio must be {SAMPLE_RATE}Hz, got {sr}Hz. "
f"Please resample to {SAMPLE_RATE}Hz first."
)
# Get predictions
onset_pred, offset_pred, _, frame_pred, velocity_pred = self.transcribe(
audio, onset_threshold, frame_threshold
)
# Default instrument mapping (piano)
inst_mapping = {0: 0} # instrument 0 -> MIDI program 0 (Acoustic Grand Piano)
# Convert predictions to MIDI
frames2midi(
str(output_path),
onset_pred,
frame_pred,
velocity_pred,
onset_threshold=onset_threshold,
frame_threshold=frame_threshold,
scaling=HOP_LENGTH / SAMPLE_RATE,
inst_mapping=inst_mapping,
)
def to_legacy(self) -> OnsetsAndFrames:
"""
Convert this HuggingFace-compatible model to a legacy OnsetsAndFrames instance.
This is useful for:
- Fine-tuning models downloaded from HuggingFace Hub using existing training code
- Using HF models with existing inference scripts that expect OnsetsAndFrames
The legacy model will use the global melspectrogram from mel.py instead of
the instance-specific one in this model.
Returns:
OnsetsAndFrames instance with copied weights
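        Example (illustrative sketch; assumes `hf_model` is a loaded CountEMModel):
        ```python
        legacy = hf_model.to_legacy()
        torch.save(legacy, "legacy_model.pt")  # usable with the existing training/inference scripts
        ```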
"""
# Create legacy model with same architecture
legacy_model = OnsetsAndFrames(
input_features=self.config['n_mels'],
output_features=self.config['n_keys'],
model_complexity=self.config['model_complexity'],
onset_complexity=self.config['onset_complexity'],
n_instruments=self.config['n_instruments']
)
# Get the state dict and filter out melspectrogram keys
state_dict = self.state_dict()
legacy_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('melspectrogram.')}
# Copy state dict (only model weights, not mel spectrogram)
# The legacy model will use the global melspectrogram
legacy_model.load_state_dict(legacy_state_dict)
return legacy_model
@classmethod
def from_legacy_checkpoint(
cls,
checkpoint_path: Union[str, Path],
**kwargs
) -> "CountEMModel":
"""
Load a model from a legacy checkpoint (saved with torch.save(model)).
This is useful for converting old checkpoints to the new HF-compatible format.
Args:
checkpoint_path: Path to the legacy .pt checkpoint file
**kwargs: Additional arguments for model initialization
Returns:
CountEMModel instance with loaded weights
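        Example (illustrative; the checkpoint path and repo id are placeholders):
        ```python
        model = CountEMModel.from_legacy_checkpoint("path/to/legacy_checkpoint.pt")
        model.push_to_hub("your-username/countem-model")  # provided by PyTorchModelHubMixin
        ```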
"""
        # Load the legacy checkpoint (a fully pickled model object, so weights_only=False
        # is required on PyTorch >= 2.6, where weights_only defaults to True)
        legacy_model = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
# Extract configuration from the loaded model
# Infer model_complexity from the model structure
# ConvStack.cnn[0] is the first Conv2d layer with out_channels = model_size // 16
first_conv_channels = legacy_model.offset_stack[0].cnn[0].out_channels
model_size = first_conv_channels * 16
model_complexity = model_size // 16
# Infer onset_complexity
onset_first_conv_channels = legacy_model.onset_stack[0].cnn[0].out_channels
onset_model_size = onset_first_conv_channels * 16
onset_complexity = onset_model_size / model_size
# Infer n_instruments from output layer
# onset_stack[2] is the Linear layer
onset_out_features = legacy_model.onset_stack[2].out_features
n_keys = MAX_MIDI - MIN_MIDI + 1
n_instruments = onset_out_features // n_keys
# Create new model with the same configuration
model = cls(
model_complexity=model_complexity,
onset_complexity=onset_complexity,
n_instruments=n_instruments,
**kwargs
)
# Copy the state dict (strict=False because new model has melspectrogram submodule)
model.load_state_dict(legacy_model.state_dict(), strict=False)
return model