"""
Hugging Face Hub-compatible wrapper for CountEM music transcription models.
"""
from pathlib import Path
from typing import Union, Tuple
import numpy as np
import torch
import soundfile as sf
from huggingface_hub import PyTorchModelHubMixin

from onsets_and_frames.transcriber import OnsetsAndFrames
from onsets_and_frames.mel import MelSpectrogram
from onsets_and_frames.midi_utils import frames2midi
from onsets_and_frames.constants import (
    N_MELS,
    MIN_MIDI,
    MAX_MIDI,
    HOP_LENGTH,
    SAMPLE_RATE,
    WINDOW_LENGTH,
    MEL_FMIN,
    MEL_FMAX,
)


class CountEMModel(
    OnsetsAndFrames,
    PyTorchModelHubMixin,
    # Optional metadata that gets pushed to the model card
    library_name="countem",
    tags=["audio", "music-transcription", "automatic-music-transcription", "midi"],
    license="cc-by-4.0",
    repo_url="https://github.com/Yoni-Yaffe/count-the-notes",
    paper_url="https://arxiv.org/abs/2511.14250",
):
    """
    Hugging Face Hub-compatible wrapper for CountEM automatic music transcription models.

    This model performs automatic music transcription (AMT) from audio to MIDI.
    It uses the Onsets & Frames architecture trained with the CountEM framework,
    which enables training with weak, unordered note count histograms.

    Example usage:
        ```python
        from onsets_and_frames.hf_model import CountEMModel
        import soundfile as sf

        # Load model from Hub
        model = CountEMModel.from_pretrained("Yoni-Yaffe/countem-musicnet")

        # Load audio (must be 16kHz)
        audio, sr = sf.read("audio.flac")
        assert sr == 16000, "Audio must be 16kHz"

        # Transcribe to MIDI
        model.transcribe_to_midi(audio, "output.mid")
        ```
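
    Saving locally or pushing to the Hub (illustrative; the directory and repo id are placeholders):
        ```python
        model.save_pretrained("./countem-checkpoint")
        model.push_to_hub("your-username/countem-model")
        ```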

    Args:
        model_complexity: Complexity multiplier for the model (default: 64)
        onset_complexity: Complexity multiplier for onset stack (default: 1.5)
        n_instruments: Number of instruments to transcribe (default: 1)
    """

    def __init__(
        self,
        model_complexity: int = 64,
        onset_complexity: float = 1.5,
        n_instruments: int = 1,
        **kwargs
    ):
        # Initialize the base OnsetsAndFrames model
        n_keys = MAX_MIDI - MIN_MIDI + 1
        OnsetsAndFrames.__init__(
            self,
            input_features=N_MELS,
            output_features=n_keys,
            model_complexity=model_complexity,
            onset_complexity=onset_complexity,
            n_instruments=n_instruments,
        )

        # Store config for HF Hub
        self.config = {
            "model_complexity": model_complexity,
            "onset_complexity": onset_complexity,
            "n_instruments": n_instruments,
            "n_mels": N_MELS,
            "n_keys": n_keys,
            "sample_rate": SAMPLE_RATE,
            "hop_length": HOP_LENGTH,
        }

        # Add mel spectrogram as a submodule for proper device management
        # This ensures the mel transform moves with the model when calling .to(device)
        self.melspectrogram = MelSpectrogram(
            n_mels=N_MELS,
            sample_rate=SAMPLE_RATE,
            filter_length=WINDOW_LENGTH,
            hop_length=HOP_LENGTH,
            mel_fmin=MEL_FMIN,
            mel_fmax=MEL_FMAX,
        )

    def forward(self, audio: Union[np.ndarray, torch.Tensor]):
        """
        Forward pass that accepts raw audio waveforms.

        Unlike the parent OnsetsAndFrames which expects mel spectrograms,
        this forward method accepts raw audio and converts it internally.

        Args:
            audio: Raw audio waveform, shape (batch, n_samples) or (n_samples,).
                   Values should be in [-1, 1]; out-of-range input is peak-normalized automatically.

        Returns:
            Tuple of (onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred)
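
        Example (a minimal sketch using random dummy audio in place of a real 16 kHz recording):
            ```python
            import numpy as np

            dummy_audio = np.random.randn(16000 * 10).astype(np.float32)  # 10 seconds at 16 kHz
            onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred = model(dummy_audio)
            ```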
        """
        # Convert to torch tensor if needed
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()

        # Ensure audio is in range [-1, 1]
        if audio.dtype == torch.int16:
            audio = audio.float() / 32768.0
        elif audio.abs().max() > 1.0:
            audio = audio / audio.abs().max()

        # Add batch dimension if needed
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        device = next(self.parameters()).device
        audio = audio.to(device)

        # Remove last sample to fix frame count mismatch
        audio = audio[:, :-1]

        mel = self.melspectrogram(audio)

        # Transpose to (batch, time, features) format expected by parent model
        mel = mel.transpose(-1, -2)

        return super().forward(mel)

    @torch.no_grad()
    def transcribe(
        self,
        audio: Union[np.ndarray, torch.Tensor],
        onset_threshold: float = 0.5,
        frame_threshold: float = 0.5,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Transcribe audio to note predictions.

        Automatically handles long audio by splitting into segments (max 5 minutes each)
        to avoid memory issues.

        Args:
            audio: Audio waveform, shape (n_samples,), normalized to [-1, 1]
            onset_threshold: Threshold for onset detection (default: 0.5). Not applied here;
                the raw predictions are returned and thresholding happens during MIDI
                conversion (see transcribe_to_midi).
            frame_threshold: Threshold for frame detection (default: 0.5). Not applied here;
                see onset_threshold.

        Returns:
            Tuple of (onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred)
            All are numpy arrays of shape (n_frames, 88) except velocity which may vary
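
        Example (sketch; audio is assumed to be a 16 kHz mono numpy array already in memory):
            ```python
            onsets, offsets, activations, frames, velocities = model.transcribe(audio)
            print(onsets.shape)  # (n_frames, 88) for the default single-instrument model
            ```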
        """
        self.eval()

        # Convert to torch tensor if needed
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()

        # Ensure audio is 1D (convert stereo to mono if needed)
        if audio.dim() > 1:
            # Average across channels; assume channels-last layout when the last dim is a channel count
            audio = audio.mean(dim=-1 if audio.shape[-1] <= 2 else 0)

        # Normalize audio
        if audio.dtype == torch.int16:
            audio = audio.float() / 32768.0
        elif audio.abs().max() > 1.0:
            audio = audio / audio.abs().max()

        device = next(self.parameters()).device
        audio = audio.to(device)

        # Handle long audio by segmenting
        MAX_TIME = 5 * 60 * SAMPLE_RATE  # 5 minutes
        audio_len = len(audio)

        if audio_len > MAX_TIME:
            # Split into segments
            n_segments = int(np.ceil(audio_len / MAX_TIME))
            seg_len = MAX_TIME

            onset_preds = []
            offset_preds = []
            activation_preds = []
            frame_preds = []
            velocity_preds = []

            for i_s in range(n_segments):
                start = i_s * seg_len
                end = min((i_s + 1) * seg_len, audio_len)
                segment = audio[start:end]

                # Forward pass on segment
                onset_seg, offset_seg, activation_seg, frame_seg, velocity_seg = self(segment)

                onset_preds.append(onset_seg)
                offset_preds.append(offset_seg)
                activation_preds.append(activation_seg)
                frame_preds.append(frame_seg)
                velocity_preds.append(velocity_seg)

            # Concatenate along time dimension (dim=1)
            onset_pred = torch.cat(onset_preds, dim=1)
            offset_pred = torch.cat(offset_preds, dim=1)
            activation_pred = torch.cat(activation_preds, dim=1)
            frame_pred = torch.cat(frame_preds, dim=1)
            velocity_pred = torch.cat(velocity_preds, dim=1)
        else:
            # Short audio, process directly
            onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred = self(audio)

        # Convert to numpy and remove batch dimension
        onset_pred = onset_pred.squeeze(0).cpu().numpy()
        offset_pred = offset_pred.squeeze(0).cpu().numpy()
        activation_pred = activation_pred.squeeze(0).cpu().numpy()
        frame_pred = frame_pred.squeeze(0).cpu().numpy()
        velocity_pred = velocity_pred.squeeze(0).cpu().numpy()

        return onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred

    def transcribe_to_midi(
        self,
        audio: Union[np.ndarray, torch.Tensor, str, Path],
        output_path: Union[str, Path],
        onset_threshold: float = 0.5,
        frame_threshold: float = 0.5,
    ) -> None:
        """
        Transcribe audio to MIDI file.

        Args:
            audio: Audio waveform, numpy array, torch tensor, or path to audio file
            output_path: Path to save MIDI file
            onset_threshold: Threshold for onset detection (default: 0.5)
            frame_threshold: Threshold for frame detection (default: 0.5)
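
        Example (illustrative; "audio.flac" and "output.mid" are placeholder paths):
            ```python
            model.transcribe_to_midi("audio.flac", "output.mid")
            ```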
        """
        # Load audio from file if path is provided
        if isinstance(audio, (str, Path)):
            audio, sr = sf.read(audio, dtype="float32")
            if sr != SAMPLE_RATE:
                raise ValueError(
                    f"Audio must be {SAMPLE_RATE}Hz, got {sr}Hz. "
                    f"Please resample to {SAMPLE_RATE}Hz first."
                )

        # Get predictions
        onset_pred, offset_pred, _, frame_pred, velocity_pred = self.transcribe(
            audio, onset_threshold, frame_threshold
        )

        # Default instrument mapping (piano)
        inst_mapping = {0: 0}  # instrument 0 -> MIDI program 0 (Acoustic Grand Piano)

        # Convert predictions to MIDI
        frames2midi(
            str(output_path),
            onset_pred,
            frame_pred,
            velocity_pred,
            onset_threshold=onset_threshold,
            frame_threshold=frame_threshold,
            scaling=HOP_LENGTH / SAMPLE_RATE,
            inst_mapping=inst_mapping,
        )

    def to_legacy(self) -> OnsetsAndFrames:
        """
        Convert this HuggingFace-compatible model to a legacy OnsetsAndFrames instance.

        This is useful for:
        - Fine-tuning models downloaded from HuggingFace Hub using existing training code
        - Using HF models with existing inference scripts that expect OnsetsAndFrames

        The legacy model will use the global melspectrogram from mel.py instead of
        the instance-specific one in this model.

        Returns:
            OnsetsAndFrames instance with copied weights
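
        Example (sketch; the output path is a placeholder):
            ```python
            legacy = model.to_legacy()
            torch.save(legacy, "legacy_checkpoint.pt")
            ```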
        """
        # Create legacy model with same architecture
        legacy_model = OnsetsAndFrames(
            input_features=self.config['n_mels'],
            output_features=self.config['n_keys'],
            model_complexity=self.config['model_complexity'],
            onset_complexity=self.config['onset_complexity'],
            n_instruments=self.config['n_instruments']
        )

        # Get the state dict and filter out melspectrogram keys
        state_dict = self.state_dict()
        legacy_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('melspectrogram.')}

        # Copy state dict (only model weights, not mel spectrogram)
        # The legacy model will use the global melspectrogram
        legacy_model.load_state_dict(legacy_state_dict)

        return legacy_model

    @classmethod
    def from_legacy_checkpoint(
        cls,
        checkpoint_path: Union[str, Path],
        **kwargs
    ) -> "CountEMModel":
        """
        Load a model from a legacy checkpoint (saved with torch.save(model)).

        This is useful for converting old checkpoints to the new HF-compatible format.

        Args:
            checkpoint_path: Path to the legacy .pt checkpoint file
            **kwargs: Additional arguments for model initialization

        Returns:
            CountEMModel instance with loaded weights
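
        Example (sketch; the checkpoint path and save directory are placeholders):
            ```python
            model = CountEMModel.from_legacy_checkpoint("old_checkpoint.pt")
            model.save_pretrained("./countem-hf")
            ```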
        """
        # Load the legacy checkpoint (a fully pickled model object, so weights_only=False is required)
        legacy_model = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        # Extract configuration from the loaded model
        # Infer model_complexity from the model structure
        # ConvStack.cnn[0] is the first Conv2d layer with out_channels = model_size // 16
        first_conv_channels = legacy_model.offset_stack[0].cnn[0].out_channels
        model_size = first_conv_channels * 16
        model_complexity = model_size // 16

        # Infer onset_complexity
        onset_first_conv_channels = legacy_model.onset_stack[0].cnn[0].out_channels
        onset_model_size = onset_first_conv_channels * 16
        onset_complexity = onset_model_size / model_size

        # Infer n_instruments from output layer
        # onset_stack[2] is the Linear layer
        onset_out_features = legacy_model.onset_stack[2].out_features
        n_keys = MAX_MIDI - MIN_MIDI + 1
        n_instruments = onset_out_features // n_keys

        # Create new model with the same configuration
        model = cls(
            model_complexity=model_complexity,
            onset_complexity=onset_complexity,
            n_instruments=n_instruments,
            **kwargs
        )

        # Copy the state dict (strict=False because new model has melspectrogram submodule)
        model.load_state_dict(legacy_model.state_dict(), strict=False)

        return model