| """ | |
| Hugging Face Hub-compatible wrapper for CountEM music transcription models. | |
| """ | |
| from pathlib import Path | |
| from typing import Union, Tuple | |
| import numpy as np | |
| import torch | |
| import soundfile as sf | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from onsets_and_frames.transcriber import OnsetsAndFrames | |
| from onsets_and_frames.mel import MelSpectrogram | |
| from onsets_and_frames.midi_utils import frames2midi | |
| from onsets_and_frames.constants import ( | |
| N_MELS, | |
| MIN_MIDI, | |
| MAX_MIDI, | |
| HOP_LENGTH, | |
| SAMPLE_RATE, | |
| WINDOW_LENGTH, | |
| MEL_FMIN, | |
| MEL_FMAX, | |
| ) | |


class CountEMModel(
    OnsetsAndFrames,
    PyTorchModelHubMixin,
    # Optional metadata that gets pushed to the model card
    library_name="countem",
    tags=["audio", "music-transcription", "automatic-music-transcription", "midi"],
    license="cc-by-4.0",
    repo_url="https://github.com/Yoni-Yaffe/count-the-notes",
    paper_url="https://arxiv.org/abs/2511.14250",
):
    """
    Hugging Face Hub-compatible wrapper for CountEM automatic music transcription models.

    This model performs automatic music transcription (AMT) from audio to MIDI.
    It uses the Onsets & Frames architecture trained with the CountEM framework,
    which enables training with weak, unordered note count histograms.

    Example usage:
    ```python
    from onsets_and_frames.hf_model import CountEMModel
    import soundfile as sf

    # Load model from the Hub
    model = CountEMModel.from_pretrained("Yoni-Yaffe/countem-musicnet")

    # Load audio (must be 16kHz)
    audio, sr = sf.read("audio.flac")
    assert sr == 16000, "Audio must be 16kHz"

    # Transcribe to MIDI
    model.transcribe_to_midi(audio, "output.mid")
    ```

    Args:
        model_complexity: Complexity multiplier for the model (default: 64)
        onset_complexity: Complexity multiplier for the onset stack (default: 1.5)
        n_instruments: Number of instruments to transcribe (default: 1)
    """

    def __init__(
        self,
        model_complexity: int = 64,
        onset_complexity: float = 1.5,
        n_instruments: int = 1,
        **kwargs
    ):
        # Initialize the base OnsetsAndFrames model
        n_keys = MAX_MIDI - MIN_MIDI + 1
        OnsetsAndFrames.__init__(
            self,
            input_features=N_MELS,
            output_features=n_keys,
            model_complexity=model_complexity,
            onset_complexity=onset_complexity,
            n_instruments=n_instruments,
        )

        # Store config for the HF Hub
        self.config = {
            "model_complexity": model_complexity,
            "onset_complexity": onset_complexity,
            "n_instruments": n_instruments,
            "n_mels": N_MELS,
            "n_keys": n_keys,
            "sample_rate": SAMPLE_RATE,
            "hop_length": HOP_LENGTH,
        }

        # Add the mel spectrogram as a submodule for proper device management.
        # This ensures the mel transform moves with the model when calling .to(device).
        self.melspectrogram = MelSpectrogram(
            n_mels=N_MELS,
            sample_rate=SAMPLE_RATE,
            filter_length=WINDOW_LENGTH,
            hop_length=HOP_LENGTH,
            mel_fmin=MEL_FMIN,
            mel_fmax=MEL_FMAX,
        )

    def forward(self, audio: Union[np.ndarray, torch.Tensor]):
        """
        Forward pass that accepts raw audio waveforms.

        Unlike the parent OnsetsAndFrames, which expects mel spectrograms,
        this forward method accepts raw audio and converts it internally.

        Args:
            audio: Raw audio waveform of shape (batch, n_samples) or (n_samples,).
                Should be normalized to [-1, 1]; otherwise it is normalized automatically.

        Returns:
            Tuple of (onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred)
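
        Example (a minimal sketch; assumes `model` is a loaded CountEMModel):
        ```python
        audio = np.zeros(16000 * 10, dtype=np.float32)  # 10 seconds of silence at 16 kHz
        with torch.no_grad():
            onset, offset, activation, frame, velocity = model(audio)
        ```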
| """ | |
| # Convert to torch tensor if needed | |
| if isinstance(audio, np.ndarray): | |
| audio = torch.from_numpy(audio).float() | |
| # Ensure audio is in range [-1, 1] | |
| if audio.dtype == torch.int16: | |
| audio = audio.float() / 32768.0 | |
| elif audio.max() > 1.0 or audio.min() < -1.0: | |
| audio = audio / max(abs(audio.max()), abs(audio.min())) | |
| # Add batch dimension if needed | |
| if audio.dim() == 1: | |
| audio = audio.unsqueeze(0) | |
| device = next(self.parameters()).device | |
| audio = audio.to(device) | |
| # Remove last sample to fix frame count mismatch | |
| audio = audio[:, :-1] | |
| mel = self.melspectrogram(audio) | |
| # Transpose to (batch, time, features) format expected by parent model | |
| mel = mel.transpose(-1, -2) | |
| return super().forward(mel) | |

    def transcribe(
        self,
        audio: Union[np.ndarray, torch.Tensor],
        onset_threshold: float = 0.5,
        frame_threshold: float = 0.5,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Transcribe audio to note predictions.

        Automatically handles long audio by splitting it into segments (max 5 minutes each)
        to avoid memory issues.

        Args:
            audio: Audio waveform of shape (n_samples,), normalized to [-1, 1]
            onset_threshold: Threshold for onset detection (default: 0.5)
            frame_threshold: Threshold for frame detection (default: 0.5)

        Returns:
            Tuple of (onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred).
            All are numpy arrays of shape (n_frames, 88), except velocity, which may vary.
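
        Example (a minimal sketch; assumes `model` is a loaded CountEMModel and the
        audio file is 16 kHz):
        ```python
        audio, sr = sf.read("audio.flac", dtype="float32")
        onsets, offsets, activations, frames, velocities = model.transcribe(audio)
        print(onsets.shape)  # (n_frames, 88)
        ```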
| """ | |
| self.eval() | |
| # Convert to torch tensor if needed | |
| if isinstance(audio, np.ndarray): | |
| audio = torch.from_numpy(audio).float() | |
| # Ensure audio is 1D (convert stereo to mono if needed) | |
| if audio.dim() > 1: | |
| # If stereo or multi-channel, take mean across channels | |
| audio = audio.mean(dim=-1 if audio.shape[-1] <=2 else 0) | |
| # Normalize audio | |
| if audio.dtype == torch.int16: | |
| audio = audio.float() / 32768.0 | |
| elif audio.max() > 1.0 or audio.min() < -1.0: | |
| audio = audio / max(abs(audio.max()), abs(audio.min())) | |
| device = next(self.parameters()).device | |
| audio = audio.to(device) | |
| # Handle long audio by segmenting | |
| MAX_TIME = 5 * 60 * SAMPLE_RATE # 5 minutes | |
| audio_len = len(audio) | |
| if audio_len > MAX_TIME: | |
| # Split into segments | |
| n_segments = int(np.ceil(audio_len / MAX_TIME)) | |
| seg_len = MAX_TIME | |
| onset_preds = [] | |
| offset_preds = [] | |
| activation_preds = [] | |
| frame_preds = [] | |
| velocity_preds = [] | |
| for i_s in range(n_segments): | |
| start = i_s * seg_len | |
| end = min((i_s + 1) * seg_len, audio_len) | |
| segment = audio[start:end] | |
| # Forward pass on segment | |
| onset_seg, offset_seg, activation_seg, frame_seg, velocity_seg = self(segment) | |
| onset_preds.append(onset_seg) | |
| offset_preds.append(offset_seg) | |
| activation_preds.append(activation_seg) | |
| frame_preds.append(frame_seg) | |
| velocity_preds.append(velocity_seg) | |
| # Concatenate along time dimension (dim=1) | |
| onset_pred = torch.cat(onset_preds, dim=1) | |
| offset_pred = torch.cat(offset_preds, dim=1) | |
| activation_pred = torch.cat(activation_preds, dim=1) | |
| frame_pred = torch.cat(frame_preds, dim=1) | |
| velocity_pred = torch.cat(velocity_preds, dim=1) | |
| else: | |
| # Short audio, process directly | |
| onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred = self(audio) | |
| # Convert to numpy and remove batch dimension | |
| onset_pred = onset_pred.squeeze(0).cpu().numpy() | |
| offset_pred = offset_pred.squeeze(0).cpu().numpy() | |
| activation_pred = activation_pred.squeeze(0).cpu().numpy() | |
| frame_pred = frame_pred.squeeze(0).cpu().numpy() | |
| velocity_pred = velocity_pred.squeeze(0).cpu().numpy() | |
| return onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred | |

    def transcribe_to_midi(
        self,
        audio: Union[np.ndarray, torch.Tensor, str, Path],
        output_path: Union[str, Path],
        onset_threshold: float = 0.5,
        frame_threshold: float = 0.5,
    ) -> None:
        """
        Transcribe audio to a MIDI file.

        Args:
            audio: Audio waveform (numpy array or torch tensor), or a path to an audio file
            output_path: Path to save the MIDI file
            onset_threshold: Threshold for onset detection (default: 0.5)
            frame_threshold: Threshold for frame detection (default: 0.5)
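
        Example (a minimal sketch; the repo id follows the class docstring and the
        file paths are illustrative):
        ```python
        model = CountEMModel.from_pretrained("Yoni-Yaffe/countem-musicnet")
        model.transcribe_to_midi("audio.flac", "output.mid")
        ```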
| """ | |
| # Load audio from file if path is provided | |
| if isinstance(audio, (str, Path)): | |
| audio, sr = sf.read(audio, dtype="float32") | |
| if sr != SAMPLE_RATE: | |
| raise ValueError( | |
| f"Audio must be {SAMPLE_RATE}Hz, got {sr}Hz. " | |
| f"Please resample to {SAMPLE_RATE}Hz first." | |
| ) | |
| # Get predictions | |
| onset_pred, offset_pred, _, frame_pred, velocity_pred = self.transcribe( | |
| audio, onset_threshold, frame_threshold | |
| ) | |
| # Default instrument mapping (piano) | |
| inst_mapping = {0: 0} # instrument 0 -> MIDI program 0 (Acoustic Grand Piano) | |
| # Convert predictions to MIDI | |
| frames2midi( | |
| str(output_path), | |
| onset_pred, | |
| frame_pred, | |
| velocity_pred, | |
| onset_threshold=onset_threshold, | |
| frame_threshold=frame_threshold, | |
| scaling=HOP_LENGTH / SAMPLE_RATE, | |
| inst_mapping=inst_mapping, | |
| ) | |

    def to_legacy(self) -> OnsetsAndFrames:
        """
        Convert this HuggingFace-compatible model to a legacy OnsetsAndFrames instance.

        This is useful for:
        - Fine-tuning models downloaded from HuggingFace Hub using existing training code
        - Using HF models with existing inference scripts that expect OnsetsAndFrames

        The legacy model will use the global melspectrogram from mel.py instead of
        the instance-specific one in this model.

        Returns:
            OnsetsAndFrames instance with copied weights
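
        Example (a minimal sketch; the output checkpoint path is illustrative):
        ```python
        model = CountEMModel.from_pretrained("Yoni-Yaffe/countem-musicnet")
        legacy = model.to_legacy()
        torch.save(legacy, "countem_legacy.pt")
        ```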
| """ | |
| # Create legacy model with same architecture | |
| legacy_model = OnsetsAndFrames( | |
| input_features=self.config['n_mels'], | |
| output_features=self.config['n_keys'], | |
| model_complexity=self.config['model_complexity'], | |
| onset_complexity=self.config['onset_complexity'], | |
| n_instruments=self.config['n_instruments'] | |
| ) | |
| # Get the state dict and filter out melspectrogram keys | |
| state_dict = self.state_dict() | |
| legacy_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('melspectrogram.')} | |
| # Copy state dict (only model weights, not mel spectrogram) | |
| # The legacy model will use the global melspectrogram | |
| legacy_model.load_state_dict(legacy_state_dict) | |
| return legacy_model | |

    @classmethod
    def from_legacy_checkpoint(
        cls,
        checkpoint_path: Union[str, Path],
        **kwargs
    ) -> "CountEMModel":
        """
        Load a model from a legacy checkpoint (saved with torch.save(model)).

        This is useful for converting old checkpoints to the new HF-compatible format.

        Args:
            checkpoint_path: Path to the legacy .pt checkpoint file
            **kwargs: Additional arguments for model initialization

        Returns:
            CountEMModel instance with loaded weights
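
        Example (a minimal sketch; the checkpoint path and Hub repo id are illustrative):
        ```python
        model = CountEMModel.from_legacy_checkpoint("checkpoints/model.pt")
        model.push_to_hub("your-username/countem-model")  # provided by PyTorchModelHubMixin
        ```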
| """ | |
| # Load the legacy checkpoint | |
| legacy_model = torch.load(checkpoint_path, map_location="cpu") | |

        # Extract configuration from the loaded model.
        # Infer model_complexity from the model structure:
        # ConvStack.cnn[0] is the first Conv2d layer with out_channels = model_size // 16
        first_conv_channels = legacy_model.offset_stack[0].cnn[0].out_channels
        model_size = first_conv_channels * 16
        model_complexity = model_size // 16

        # Infer onset_complexity
        onset_first_conv_channels = legacy_model.onset_stack[0].cnn[0].out_channels
        onset_model_size = onset_first_conv_channels * 16
        onset_complexity = onset_model_size / model_size

        # Infer n_instruments from the output layer
        # (onset_stack[2] is the Linear layer)
        onset_out_features = legacy_model.onset_stack[2].out_features
        n_keys = MAX_MIDI - MIN_MIDI + 1
        n_instruments = onset_out_features // n_keys

        # Create a new model with the same configuration
        model = cls(
            model_complexity=model_complexity,
            onset_complexity=onset_complexity,
            n_instruments=n_instruments,
            **kwargs
        )

        # Copy the state dict (strict=False because the new model has a melspectrogram submodule)
        model.load_state_dict(legacy_model.state_dict(), strict=False)

        return model
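

# A minimal command-line sketch (not part of the original API): transcribe a 16 kHz audio
# file to MIDI with the pretrained checkpoint named in the class docstring. The argument
# handling and device selection below are illustrative.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 3:
        print("usage: python -m onsets_and_frames.hf_model <input_audio_16khz> <output_midi>")
        sys.exit(1)

    input_audio, output_midi = sys.argv[1], sys.argv[2]

    # Download weights from the Hub and move the model (including the mel transform) to GPU if available
    model = CountEMModel.from_pretrained("Yoni-Yaffe/countem-musicnet")
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Transcribe and write the MIDI file
    model.transcribe_to_midi(input_audio, output_midi)
    print(f"Wrote {output_midi}")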