Yoni232 committed on
Commit 05d6e12
1 Parent(s): 80e5ec8

added source code of model and transcription scripts

onsets_and_frames/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .constants import *
2
+ from .dataset import EMDATASET
3
+ from .mel import melspectrogram
4
+ from .transcriber import OnsetsAndFrames, OnsetsNoFrames
5
+ from .utils import *
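These re-exports make the main entry points importable from the package root. A one-line illustrative sketch (not part of the commit), assuming the repository root is on `PYTHONPATH`:

```python
# Illustrative only: import the re-exported names from the package root.
from onsets_and_frames import EMDATASET, OnsetsAndFrames, melspectrogram
```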
onsets_and_frames/constants.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+
3
+
4
+ SAMPLE_RATE = 16000
5
+ HOP_LENGTH = 512
6
+ ONSET_LENGTH = HOP_LENGTH
7
+ OFFSET_LENGTH = HOP_LENGTH
8
+
9
+ HOPS_IN_ONSET = ONSET_LENGTH // HOP_LENGTH
10
+ HOPS_IN_OFFSET = OFFSET_LENGTH // HOP_LENGTH
11
+ MIN_MIDI = 21
12
+ MAX_MIDI = 108
13
+ N_KEYS = MAX_MIDI - MIN_MIDI + 1
14
+
15
+ DTW_FACTOR = 3
16
+
17
+ N_MELS = 229
18
+ MEL_FMIN = 30
19
+ MEL_FMAX = SAMPLE_RATE // 2
20
+ WINDOW_LENGTH = 2048
21
+
22
+ SEQ_LEN = 327680  # 640 frames, about 20.5 seconds at 16 kHz
23
+
24
+ DRUM_CHANNEL = 9
25
+
26
+ DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
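For orientation, a small sketch (not part of the commit) of the time resolution these constants imply:

```python
# Illustrative arithmetic derived from the constants above.
SAMPLE_RATE = 16000
HOP_LENGTH = 512
SEQ_LEN = 327680

frames_per_second = SAMPLE_RATE / HOP_LENGTH   # 31.25 frames per second
frames_per_sequence = SEQ_LEN // HOP_LENGTH    # 640 frames per training sequence
seconds_per_sequence = SEQ_LEN / SAMPLE_RATE   # 20.48 seconds of audio
```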
onsets_and_frames/dataset.py ADDED
@@ -0,0 +1,719 @@
1
+ import os
2
+ import random
3
+ import sys
4
+ import time
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import soundfile
9
+ import torch
10
+ from torch.utils.data import Dataset
11
+ from tqdm import tqdm
12
+
13
+ from onsets_and_frames import constants
14
+ from onsets_and_frames.constants import DEFAULT_DEVICE, N_KEYS, SAMPLE_RATE
15
+ from onsets_and_frames.mel import melspectrogram
16
+ from onsets_and_frames.midi_utils import (
17
+ midi_to_frames,
18
+ save_midi_alignments_and_predictions,
19
+ )
20
+ from onsets_and_frames.utils import (
21
+ get_diff,
22
+ get_logger,
23
+ get_peaks,
24
+ shift_label,
25
+ smooth_labels,
26
+ )
27
+
28
+
29
+ class EMDATASET(Dataset):
30
+ def __init__(
31
+ self,
32
+ audio_path="NoteEM_audio",
33
+ tsv_path="NoteEM_tsv",
34
+ labels_path="NoteEm_labels",
35
+ groups=None,
36
+ sequence_length=None,
37
+ seed=42,
38
+ device=DEFAULT_DEVICE,
39
+ instrument_map=None,
40
+ update_instruments=False,
41
+ transcriber=None,
42
+ conversion_map=None,
43
+ pitch_shift=True,
44
+ pitch_shift_limit=5,
45
+ keep_eval_files=False,
46
+ n_eval=1,
47
+ evaluation_list=None,
48
+ only_eval=False,
49
+ save_to_memory=False,
50
+ smooth_labels=False,
51
+ use_onset_mask=False,
52
+ ):
53
+ # Get the dataset logger (logging system should already be initialized by train.py)
54
+ self.logger = get_logger("dataset")
55
+
56
+ self.audio_path = audio_path
57
+ self.tsv_path = tsv_path
58
+ self.labels_path = labels_path
59
+ self.sequence_length = sequence_length
60
+ self.device = device
61
+ self.random = np.random.RandomState(seed)
62
+ self.groups = groups
63
+ self.conversion_map = conversion_map
64
+ self.eval_file_list = []
65
+ self.file_list = self.files(
66
+ self.groups,
67
+ pitch_shift=pitch_shift,
68
+ keep_eval_files=keep_eval_files,
69
+ n_eval=n_eval,
70
+ evaluation_list=evaluation_list,
71
+ pitch_shift_limit=pitch_shift_limit,
72
+ )
73
+ self.save_to_memory = save_to_memory
74
+ self.smooth_labels = smooth_labels
75
+ self.use_onset_mask = use_onset_mask
76
+ self.pitch_shift_limit = pitch_shift_limit
77
+
78
+ self.logger.debug("Save to memory is %s", self.save_to_memory)
79
+ self.logger.info("len file list %d", len(self.file_list))
80
+ self.logger.info("\n\n")
81
+
82
+ if instrument_map is None:
83
+ self.get_instruments(conversion_map=conversion_map)
84
+ else:
85
+ self.instruments = instrument_map
86
+ if update_instruments:
87
+ self.add_instruments()
88
+ self.transcriber = transcriber
89
+ if only_eval:
90
+ return
91
+ self.load_pts(self.file_list)
92
+ self.data = []
93
+ self.logger.info("Reading files...")
94
+ for input_files in tqdm(self.file_list, desc="creating data list"):
95
+ flac, _ = input_files
96
+ audio_len = librosa.get_duration(path=flac)
97
+ minutes = int(np.ceil(audio_len / 60))
98
+ copies = minutes
99
+ for _ in range(copies):
100
+ self.data.append(input_files)
101
+ random.shuffle(self.data)
102
+
103
+ def flac_to_pt_path(self, flac):
104
+ pt_fname = os.path.basename(flac).replace(".flac", ".pt")
105
+ pt_path = os.path.join(self.labels_path, pt_fname)
106
+ return pt_path
107
+
108
+ def __len__(self):
109
+ return len(self.data)
110
+
111
+ def files(
112
+ self,
113
+ groups,
114
+ pitch_shift=True,
115
+ keep_eval_files=False,
116
+ n_eval=1,
117
+ evaluation_list=None,
118
+ pitch_shift_limit=5,
119
+ ):
120
+ self.path = self.audio_path
121
+ tsvs_path = self.tsv_path
122
+ self.logger.info("tsv path: %s", tsvs_path)
123
+ self.logger.info("Evaluation list: %s", evaluation_list)
124
+ res = []
125
+ self.logger.info("keep eval files: %s", keep_eval_files)
126
+ self.logger.info("n eval: %d", n_eval)
127
+ for group in groups:
128
+ tsvs = os.listdir(tsvs_path + os.sep + group)
129
+ tsvs = sorted(tsvs)
130
+ if keep_eval_files and evaluation_list is None:
131
+ eval_tsvs = tsvs[:n_eval]
132
+ tsvs = tsvs[n_eval:]
133
+ elif keep_eval_files and evaluation_list is not None:
134
+ eval_tsvs_names = [
135
+ i.split("#")[0].split(".flac")[0].split(".tsv")[0]
136
+ for i in evaluation_list
137
+ ]
138
+ eval_tsvs = [
139
+ i
140
+ for i in tsvs
141
+ if i.split("#")[0].split(".tsv")[0] in eval_tsvs_names
142
+ ]
143
+ tsvs = [i for i in tsvs if i not in eval_tsvs]
144
+ else:
145
+ eval_tsvs = []
146
+ self.logger.info("len tsvs: %d", len(tsvs))
147
+
148
+ tsvs_names = [t.split(".tsv")[0].split("#")[0] for t in tsvs]
149
+ eval_tsvs_names = [t.split(".tsv")[0].split("#")[0] for t in eval_tsvs]
150
+ for shft in range(-5, 6):
151
+ if (shft != 0 and not pitch_shift) or abs(shft) > pitch_shift_limit:
152
+ continue
153
+ curr_fls_pth = self.path + os.sep + group + "#{}".format(shft)
154
+
155
+ fls = os.listdir(curr_fls_pth)
156
+ orig_files = fls
157
+ # print(f"files names before\n {fls}")
158
+ fls = [
159
+ i for i in fls if i.split("#")[0] in tsvs_names
160
+ ] # in case we don't have the corresponding midi
161
+ missing_fls = [i for i in orig_files if i not in fls]
162
+ if len(missing_fls) > 0:
163
+ self.logger.warning("missing files: %s", missing_fls)
164
+ fls_names = [i.split("#")[0].split(".flac")[0] for i in fls]
165
+ tsvs = [
166
+ i for i in tsvs if i.split(".tsv")[0].split("#")[0] in fls_names
167
+ ]
168
+ assert len(tsvs) == len(fls)
169
+ # print(f"files names after\n {fls}")
170
+ fls = sorted(fls)
171
+
172
+ if shft == 0:
173
+ eval_fls = os.listdir(curr_fls_pth)
174
+ # print(f"files names\n {eval_fls}")
175
+ eval_fls = [
176
+ i for i in eval_fls if i.split("#")[0] in eval_tsvs_names
177
+ ] # in case we don't have the corresponding midi
178
+ eval_fls_names = [i.split("#")[0] for i in eval_fls]
179
+ eval_tsvs = [
180
+ i
181
+ for i in eval_tsvs
182
+ if i.split(".tsv")[0].split("#")[0] in eval_fls_names
183
+ ]
184
+ assert len(eval_fls_names) == len(eval_tsvs_names)
185
+ # print(f"files names\n {eval_fls}")
186
+ eval_fls = sorted(eval_fls)
187
+ for f, t in zip(eval_fls, eval_tsvs):
188
+ self.eval_file_list.append(
189
+ (
190
+ curr_fls_pth + os.sep + f,
191
+ tsvs_path + os.sep + group + os.sep + t,
192
+ )
193
+ )
194
+
195
+ for f, t in zip(fls, tsvs):
196
+ res.append(
197
+ (
198
+ curr_fls_pth + os.sep + f,
199
+ tsvs_path + os.sep + group + os.sep + t,
200
+ )
201
+ )
202
+
203
+ for flac, tsv in res:
204
+ if (
205
+ os.path.basename(flac).split("#")[0].split(".flac")[0]
206
+ != os.path.basename(tsv).split("#")[0].split(".tsv")[0]
207
+ ):
208
+ self.logger.warning("found mismatch in the files: ")
209
+ self.logger.warning("flac: %s", os.path.basename(flac).split("#")[0])
210
+ self.logger.warning("tsv: %s", os.path.basename(tsv).split("#")[0])
211
+ self.logger.warning("please check the input files")
212
+ sys.exit(1)
213
+ return res
214
+
215
+ def get_instruments(self, conversion_map=None):
216
+ instruments = set()
217
+ for _, f in self.file_list:
218
+ events = np.loadtxt(f, delimiter="\t", skiprows=1)
219
+ curr_instruments = set(events[:, -1])
220
+ if conversion_map is not None:
221
+ curr_instruments = {
222
+ conversion_map[c] if c in conversion_map else c
223
+ for c in curr_instruments
224
+ }
225
+ instruments = instruments.union(curr_instruments)
226
+ instruments = [int(elem) for elem in instruments if elem < 115]
227
+ if conversion_map is not None:
228
+ instruments = [i for i in instruments if i in conversion_map]
229
+ instruments = list(set(instruments))
230
+ if 0 in instruments:
231
+ piano_ind = instruments.index(0)
232
+ instruments.pop(piano_ind)
233
+ instruments.insert(0, 0)
234
+ self.instruments = instruments
235
+ self.instruments = list(
236
+ set(self.instruments) - set(range(88, 104)) - set(range(112, 150))
237
+ )
238
+ self.logger.info("Dataset instruments: %s", self.instruments)
239
+ self.logger.info("Total: %d instruments", len(self.instruments))
240
+
241
+ def add_instruments(self):
242
+ for _, f in self.file_list:
243
+ events = np.loadtxt(f, delimiter="\t", skiprows=1)
244
+ curr_instruments = set(events[:, -1])
245
+ new_instruments = curr_instruments - set(self.instruments)
246
+ self.instruments += list(new_instruments)
247
+ instruments = [int(elem) for elem in self.instruments if (elem < 115)]
248
+ self.instruments = instruments
249
+
250
+ def __getitem__(self, index):
251
+ data = self.load(*self.data[index])
252
+ # result = dict(path=data['path'])
253
+ midi_length = len(data["label"])
254
+ n_steps = self.sequence_length // constants.HOP_LENGTH
255
+ if midi_length < n_steps:
256
+ step_begin = 0
257
+ step_end = midi_length
258
+ else:
259
+ step_begin = self.random.randint(max(midi_length - n_steps, 1))
260
+ step_end = step_begin + n_steps
261
+ begin = step_begin * constants.HOP_LENGTH
262
+ end = begin + self.sequence_length
263
+
264
+ audio = (
265
+ data["audio"][begin:end].float().div_(32768.0)
266
+ ) # torch.ShortTensor → float
267
+ label = data["label"][step_begin:step_end].clone() # torch.Tensor
268
+
269
+ if audio.shape[0] < self.sequence_length:
270
+ pad_amt = self.sequence_length - audio.shape[0]
271
+ audio = torch.cat([audio, torch.zeros(pad_amt, dtype=audio.dtype)], dim=0)
272
+
273
+ if label.shape[0] < n_steps:
274
+ pad_amt = n_steps - label.shape[0]
275
+ label = torch.cat(
276
+ [label, torch.zeros((pad_amt, *label.shape[1:]), dtype=label.dtype)],
277
+ dim=0,
278
+ )
279
+
280
+ audio = torch.clamp(audio, -1.0, 1.0)
281
+ result = {"path": data["path"], "audio": audio, "label": label}
282
+ if "velocity" in data:
283
+ result["velocity"] = data["velocity"][step_begin:step_end, ...]
284
+ result["velocity"] = result["velocity"].float() / 128.0
285
+
286
+ if result["label"].max() < 3:
287
+ result["onset"] = result["label"].float()
288
+ else:
289
+ result["onset"] = (result["label"] == 3).float()
290
+
291
+ result["offset"] = (result["label"] == 1).float()
292
+ result["frame"] = (result["label"] > 1).float()
293
+
294
+ if self.smooth_labels:
295
+ result["onset"] = smooth_labels(result["onset"])
296
+ if self.use_onset_mask:
297
+ if "onset_mask" in data:
298
+ result["onset_mask"] = data["onset_mask"][
299
+ step_begin:step_end, ...
300
+ ].float()
301
+ else:
302
+ result["onset_mask"] = torch.ones_like(result["onset"]).float()
303
+ if "frame_mask" in data:
304
+ result["frame_mask"] = data["frame_mask"][
305
+ step_begin:step_end, ...
306
+ ].float()
307
+ else:
308
+ result["frame_mask"] = torch.ones_like(result["frame"]).float()
309
+
310
+ shape = result["frame"].shape
311
+ keys = N_KEYS
312
+ new_shape = shape[:-1] + (shape[-1] // keys, keys)
313
+ result["big_frame"] = result["frame"]
314
+ result["frame"], _ = result["frame"].reshape(new_shape).max(axis=-2)
315
+
316
+ # if 'frame_mask' not in data:
317
+ # result['frame_mask'] = torch.ones_like(result['frame']).to(self.device).float()
318
+
319
+ result["big_offset"] = result["offset"]
320
+ result["offset"], _ = result["offset"].reshape(new_shape).max(axis=-2)
321
+ result["group"] = self.data[index][0].split(os.sep)[-2].split("#")[0]
322
+
323
+ return result
324
+
325
+ def load(self, audio_path, tsv_path):
326
+ if self.save_to_memory:
327
+ data = self.pts[audio_path]
328
+ else:
329
+ data = torch.load(self.flac_to_pt_path(audio_path))
330
+ if len(data["audio"].shape) > 1:
331
+ data["audio"] = (data["audio"].float().mean(dim=-1)).short()
332
+ if "label" in data:
333
+ return data
334
+ else:
335
+ piece, part = audio_path.split(os.sep)[-2:]
336
+ piece_split = piece.split("#")
337
+ if len(piece_split) == 2:
338
+ piece, shift1 = piece_split
339
+ else:
340
+ piece, shift1 = "#".join(piece_split[:2]), piece_split[-1]
341
+ part_split = part.split("#")
342
+ if len(part_split) == 2:
343
+ part, shift2 = part_split
344
+ else:
345
+ part, shift2 = "#".join(part_split[:2]), part_split[-1]
346
+ shift2, _ = shift2.split(".")
347
+ assert shift1 == shift2
348
+ shift = shift1
349
+ assert int(shift) != 0
350
+ orig = audio_path.replace("#{}".format(shift), "#0")
351
+ if self.save_to_memory:
352
+ orig_data = self.pts[orig]
353
+ else:
354
+ orig_data = torch.load(self.flac_to_pt_path(orig))
355
+ res = {}
356
+ res["label"] = shift_label(orig_data["label"], int(shift))
357
+ res["path"] = audio_path
358
+ res["audio"] = data["audio"]
359
+ if "velocity" in orig_data:
360
+ res["velocity"] = shift_label(orig_data["velocity"], int(shift))
361
+ if "onset_mask" in orig_data:
362
+ res["onset_mask"] = shift_label(orig_data["onset_mask"], int(shift))
363
+ if "frame_mask" in orig_data:
364
+ res["frame_mask"] = shift_label(orig_data["frame_mask"], int(shift))
365
+ return res
366
+
367
+ def load_pts(self, files):
368
+ self.pts = {}
369
+ self.logger.info("loading pts...")
370
+ for flac, tsv in tqdm(files, desc="loading pts"):
371
+ # print('flac, tsv', flac, tsv)
372
+ if os.path.isfile(
373
+ self.labels_path
374
+ + os.sep
375
+ + flac.split(os.sep)[-1].replace(".flac", ".pt")
376
+ ):
377
+ if self.save_to_memory:
378
+ self.pts[flac] = torch.load(
379
+ self.labels_path
380
+ + os.sep
381
+ + flac.split(os.sep)[-1].replace(".flac", ".pt")
382
+ )
383
+ else:
384
+ if flac.count("#") != 2:
385
+ self.logger.debug("file path does not contain exactly two '#': %s", flac)
386
+ audio, sr = soundfile.read(flac, dtype="int16")
387
+ if len(audio.shape) == 2:
388
+ audio = audio.astype(float).mean(axis=1)
389
+ else:
390
+ audio = audio.astype(float)
391
+ audio = audio.astype(np.int16)
392
+ self.logger.debug("audio len: %d", len(audio))
393
+ assert sr == SAMPLE_RATE
394
+ audio = torch.ShortTensor(audio)
395
+ if "#0" not in flac:
396
+ assert "#" in flac
397
+ data = {"audio": audio}
398
+ if self.save_to_memory:
399
+ self.pts[flac] = data
400
+ torch.save(data, self.flac_to_pt_path(flac))
401
+ continue
402
+ midi = np.loadtxt(tsv, delimiter="\t", skiprows=1)
403
+ unaligned_label = midi_to_frames(
404
+ midi, self.instruments, conversion_map=self.conversion_map
405
+ )
406
+ if len(self.instruments) == 1:
407
+ unaligned_label = unaligned_label[:, -N_KEYS:]
408
+ if len(unaligned_label) < self.sequence_length // constants.HOP_LENGTH:
409
+ diff = self.sequence_length // constants.HOP_LENGTH - len(
410
+ unaligned_label
411
+ )
412
+ pad = torch.zeros(
413
+ (diff, unaligned_label.shape[1]), dtype=unaligned_label.dtype
414
+ )
415
+ unaligned_label = torch.cat((unaligned_label, pad), dim=0)
416
+
417
+ group = flac.split(os.sep)[-2].split("#")[0]
418
+ data = dict(
419
+ path=self.labels_path + os.sep + flac.split(os.sep)[-1],
420
+ audio=audio,
421
+ unaligned_label=unaligned_label,
422
+ group=group,
423
+ BON=float("inf"),
424
+ BON_VEC=np.full(unaligned_label.shape[1], float("inf")),
425
+ )
426
+
427
+ torch.save(data, self.flac_to_pt_path(flac))
428
+ if self.save_to_memory:
429
+ self.pts[flac] = data
430
+
431
+ def update_pts_counting(
432
+ self,
433
+ transcriber,
434
+ counting_window_length,
435
+ POS=1.1,
436
+ NEG=-0.001,
437
+ FRAME_POS=0.5,
438
+ to_save=None,
439
+ first=False,
440
+ update=True,
441
+ BEST_DIST=False,
442
+ peak_size=3,
443
+ BEST_DIST_VEC=False,
444
+ counting_window_hop=0,
445
+ ):
446
+ self.logger.info("Updating pts...")
447
+ self.logger.info("First %s", first)
448
+ total_counting_time = 0.0 # Initialize total time for counting-based alignment
449
+
450
+ self.logger.info("POS, NEG: %s, %s", POS, NEG)
451
+ if to_save is not None:
452
+ os.makedirs(to_save, exist_ok=True)
453
+ self.logger.info("There are %d pts", len(self.pts))
454
+ update_count = 0
455
+ sys.stdout.flush()
456
+ only_pitch_0_files = [f for f in self.file_list if "#0" in f[0]]
457
+ for input_files in tqdm(only_pitch_0_files, desc="updating pts"):
458
+ flac, tsv = input_files
459
+ data = torch.load(self.flac_to_pt_path(flac))
460
+ if "unaligned_label" not in data:
461
+ self.logger.warning("No unaligned labels for %s", flac)
462
+ continue
463
+ audio_inp = data["audio"].float() / 32768.0
464
+ MAX_TIME = 5 * 60 * SAMPLE_RATE
465
+ audio_inp_len = len(audio_inp)
466
+ if audio_inp_len > MAX_TIME:
467
+ n_segments = int(np.ceil(audio_inp_len / MAX_TIME))
468
+ self.logger.debug("Long audio, splitting to %d segments", n_segments)
469
+ seg_len = MAX_TIME
470
+ onsets_preds = []
471
+ offset_preds = []
472
+ frame_preds = []
473
+ for i_s in range(n_segments):
474
+ curr = (
475
+ audio_inp[i_s * seg_len : (i_s + 1) * seg_len]
476
+ .unsqueeze(0)
477
+ .cuda()
478
+ )
479
+ curr_mel = melspectrogram(
480
+ curr.reshape(-1, curr.shape[-1])[:, :-1]
481
+ ).transpose(-1, -2)
482
+ (
483
+ curr_onset_pred,
484
+ curr_offset_pred,
485
+ _,
486
+ curr_frame_pred,
487
+ curr_velocity_pred,
488
+ ) = transcriber(curr_mel)
489
+ onsets_preds.append(curr_onset_pred)
490
+ offset_preds.append(curr_offset_pred)
491
+ frame_preds.append(curr_frame_pred)
492
+ onset_pred = torch.cat(onsets_preds, dim=1)
493
+ offset_pred = torch.cat(offset_preds, dim=1)
494
+ frame_pred = torch.cat(frame_preds, dim=1)
495
+ else:
496
+ audio_inp = audio_inp.unsqueeze(0).cuda()
497
+ mel = melspectrogram(
498
+ audio_inp.reshape(-1, audio_inp.shape[-1])[:, :-1]
499
+ ).transpose(-1, -2)
500
+ onset_pred, offset_pred, _, frame_pred, _ = transcriber(mel)
501
+ self.logger.debug("Done predicting.")
502
+
503
+ # We assume onset predictions are of length N_KEYS * (len(instruments) + 1),
504
+ # first N_KEYS classes are the first instrument, next N_KEYS classes are the next instrument, etc.,
505
+ # and last N_KEYS classes are for pitch regardless of instrument
506
+ # Currently, frame and offset predictions are only N_KEYS classes.
507
+ onset_pred = onset_pred.detach().squeeze().cpu()
508
+ frame_pred = frame_pred.detach().squeeze().cpu()
509
+
510
+ PEAK_SIZE = peak_size
511
+ self.logger.debug("PEAK_SIZE: %d", PEAK_SIZE)
512
+ # we peak-pick the onset predictions to keep only local-maximum onsets
513
+ if peak_size > 0:
514
+ peaks = get_peaks(
515
+ onset_pred, PEAK_SIZE
516
+ ) # we only want local peaks, in a 7-frame neighborhood, 3 to each side.
517
+ onset_pred[~peaks] = 0
518
+
519
+ unaligned_onsets = (data["unaligned_label"] == 3).float().numpy()
520
+
521
+ onset_pred_np = onset_pred.numpy()
522
+ frame_pred_np = frame_pred.numpy()
523
+
524
+ pred_bag_of_notes = (onset_pred_np[:, -N_KEYS:] >= 0.5).sum(axis=0)
525
+ gt_bag_of_notes = unaligned_onsets[:, -N_KEYS:].astype(bool).sum(axis=0)
526
+ bon_dist = (((pred_bag_of_notes - gt_bag_of_notes) ** 2).sum()) ** 0.5
527
+
528
+ pred_bag_of_notes_with_inst = (onset_pred_np >= 0.5).sum(axis=0)
529
+ gt_bag_of_notes_with_inst = unaligned_onsets.astype(bool).sum(axis=0)
530
+ bon_dist_vec = np.abs(
531
+ pred_bag_of_notes_with_inst - gt_bag_of_notes_with_inst
532
+ )
533
+
534
+ bon_dist /= gt_bag_of_notes.sum()
535
+ self.logger.debug("bag of notes dist: %f", bon_dist)
536
+ ####
537
+
538
+ aligned_onsets = np.zeros(onset_pred_np.shape, dtype=bool)
539
+ aligned_frames = np.zeros(onset_pred_np.shape, dtype=bool)
540
+
541
+ # This block is the main difference between the counting approach and the DTW approach.
542
+ # In the counting approach we label the audio by counting note onsets: For each onset pitch class,
543
+ # denote by K the number of times it occurs in the unaligned label. We simply take the K highest local
544
+ # peaks predicted by the current model.
545
+ # Split unaligned onsets into chunks of size counting_window_length
546
+ self.logger.debug(
547
+ "unaligned onsets shape: %s, counting window length: %d, counting window hop: %d",
548
+ unaligned_onsets.shape,
549
+ counting_window_length,
550
+ counting_window_hop,
551
+ )
552
+ assert counting_window_hop <= counting_window_length
553
+ if counting_window_hop == 0:
554
+ counting_window_hop = counting_window_length
555
+
556
+ num_chunks = (
557
+ 1
558
+ if counting_window_length == 0
559
+ else int(np.ceil(len(unaligned_onsets) / counting_window_hop))
560
+ )
561
+
562
+ self.logger.debug("number of chunks: %d", num_chunks)
563
+ start_time = time.time()
564
+ for chunk_idx in range(num_chunks):
565
+ start_idx = chunk_idx * counting_window_hop
566
+ if counting_window_length == 0:
567
+ end_idx = max(len(unaligned_onsets), len(onset_pred_np))
568
+ else:
569
+ end_idx = min(
570
+ start_idx + counting_window_length, len(unaligned_onsets)
571
+ )
572
+ chunk_onsets = unaligned_onsets[start_idx:end_idx]
573
+ chunk_onsets_count = (
574
+ (data["unaligned_label"][start_idx:end_idx, :] == 3)
575
+ .sum(dim=0)
576
+ .numpy()
577
+ )
578
+
579
+ for f, f_count in enumerate(chunk_onsets_count):
580
+ if f_count == 0:
581
+ continue
582
+ f_most_likely = np.sort(
583
+ onset_pred_np[start_idx:end_idx, f].argsort()[::-1][:f_count]
584
+ )
585
+ f_most_likely += start_idx # Adjust indices to the original size
586
+ aligned_onsets[f_most_likely, f] = 1
587
+
588
+ f_unaligned = chunk_onsets[:, f].nonzero()
589
+ assert len(f_unaligned) == 1
590
+ f_unaligned = f_unaligned[0]
591
+
592
+ counting_duration = time.time() - start_time
593
+ total_counting_time += counting_duration
594
+ self.logger.debug(
595
+ "Counting alignment for file '%s' took %.2f seconds.",
596
+ flac,
597
+ counting_duration,
598
+ )
599
+
600
+ # Pseudo labels: a POS threshold greater than 1 is equivalent to not using pseudo labels
601
+ pseudo_onsets = (onset_pred_np >= POS) & (~aligned_onsets)
602
+
603
+ onset_label = np.maximum(pseudo_onsets, aligned_onsets)
604
+
605
+ # in this project we do not train the frame stack, but we compute the labels anyway
606
+ pseudo_frames = np.zeros(pseudo_onsets.shape, dtype=pseudo_onsets.dtype)
607
+ pseudo_offsets = np.zeros(pseudo_onsets.shape, dtype=pseudo_onsets.dtype)
608
+ for t, f in zip(*onset_label.nonzero()):
609
+ t_off = t
610
+ while (
611
+ t_off < len(pseudo_frames)
612
+ and frame_pred[t_off, f % N_KEYS] >= FRAME_POS
613
+ ):
614
+ t_off += 1
615
+ pseudo_frames[t:t_off, f] = 1
616
+ if t_off < len(pseudo_offsets):
617
+ pseudo_offsets[t_off, f] = 1
618
+ frame_label = np.maximum(pseudo_frames, aligned_frames)
619
+ offset_label = get_diff(frame_label, offset=True)
620
+
621
+ label = np.maximum(2 * frame_label, offset_label)
622
+ label = np.maximum(3 * onset_label, label).astype(np.uint8)
623
+
624
+ if to_save is not None:
625
+ save_midi_alignments_and_predictions(
626
+ to_save,
627
+ data["path"],
628
+ self.instruments,
629
+ aligned_onsets,
630
+ aligned_frames,
631
+ onset_pred_np,
632
+ frame_pred_np,
633
+ prefix="",
634
+ group=data["group"],
635
+ )
636
+ prev_bon_dist = data.get("BON", float("inf"))
637
+ prev_bon_dist_vec = data.get("BON_VEC", None)
638
+ if update:
639
+ if BEST_DIST_VEC:
640
+ self.logger.debug("Updated Labels")
641
+ if prev_bon_dist_vec is None:
642
+ raise ValueError(
643
+ "BEST_DIST_VEC is True but no previous BON_VEC found"
644
+ )
645
+ prev_label = data["label"]
646
+ new_label = torch.from_numpy(label).byte()
647
+ if first:
648
+ prev_label = new_label
649
+ update_count += 1
650
+ else:
651
+ updated_flag = False
652
+ num_pitches_updated = 0
653
+ for k in range(prev_label.shape[1]):
654
+ if prev_bon_dist_vec[k] > bon_dist_vec[k]:
655
+ prev_label[:, k] = new_label[:, k]
656
+ prev_bon_dist_vec[k] = bon_dist_vec[k]
657
+ num_pitches_updated += 1
658
+ updated_flag = True
659
+ if updated_flag:
660
+ update_count += 1
661
+ self.logger.debug("Updated %d pitches", num_pitches_updated)
662
+ data["label"] = prev_label
663
+ data["BON_VEC"] = prev_bon_dist_vec
664
+ self.logger.debug("saved updated pt")
665
+ torch.save(
666
+ data,
667
+ self.labels_path
668
+ + os.sep
669
+ + flac.split(os.sep)[-1]
670
+ .replace(".flac", ".pt")
671
+ .replace(".mp3", ".pt"),
672
+ )
673
+
674
+ elif not BEST_DIST or bon_dist < prev_bon_dist:
675
+ update_count += 1
676
+ self.logger.debug("Updated Labels")
677
+
678
+ data["label"] = torch.from_numpy(label).byte()
679
+
680
+ data["BON"] = bon_dist
681
+ self.logger.debug("saved updated pt")
682
+ torch.save(
683
+ data,
684
+ self.labels_path
685
+ + os.sep
686
+ + flac.split(os.sep)[-1]
687
+ .replace(".flac", ".pt")
688
+ .replace(".mp3", ".pt"),
689
+ )
690
+
691
+ if bon_dist < prev_bon_dist:
692
+ self.logger.debug(
693
+ "Bag of notes distance improved from %f to %f",
694
+ prev_bon_dist,
695
+ bon_dist,
696
+ )
697
+ data["BON"] = bon_dist
698
+
699
+ if to_save is not None and BEST_DIST:
700
+ os.makedirs(to_save + "/BEST_BON", exist_ok=True)
701
+ save_midi_alignments_and_predictions(
702
+ to_save + "/BEST_BON",
703
+ data["path"],
704
+ self.instruments,
705
+ aligned_onsets,
706
+ aligned_frames,
707
+ onset_pred_np,
708
+ frame_pred_np,
709
+ prefix="BEST_BON",
710
+ group=data["group"],
711
+ use_time=False,
712
+ )
713
+
714
+ self.logger.info(
715
+ "Updated %d pts out of %d", update_count, len(only_pitch_0_files)
716
+ )
717
+ self.logger.info(
718
+ "Total counting alignment time for all files: %.2f seconds.", total_counting_time
719
+ )
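The counting-based alignment in `update_pts_counting` keeps, for each onset pitch class, the K highest local peaks predicted by the current model, where K is that pitch's onset count in the unaligned label. A standalone sketch of that core idea (illustrative, not the committed implementation):

```python
# Illustrative sketch of counting-based onset alignment.
import numpy as np

def counting_align(onset_pred: np.ndarray, unaligned_onsets: np.ndarray) -> np.ndarray:
    """For each pitch class, keep the K highest predicted onset peaks,
    where K is that pitch's onset count in the unaligned label."""
    aligned = np.zeros(onset_pred.shape, dtype=bool)
    counts = unaligned_onsets.astype(bool).sum(axis=0)  # K per pitch class
    for f, k in enumerate(counts):
        if k == 0:
            continue
        # frame indices of the K largest predicted peaks, kept in time order
        top_k = np.sort(onset_pred[:, f].argsort()[::-1][:k])
        aligned[top_k, f] = True
    return aligned
```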
onsets_and_frames/decoding.py ADDED
@@ -0,0 +1,102 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ def extract_notes(onsets, frames, velocity, onset_threshold=0.5, frame_threshold=0.5):
6
+ """
7
+ Finds the note timings based on the onsets and frames information
8
+
9
+ Parameters
10
+ ----------
11
+ onsets: torch.FloatTensor, shape = [frames, bins]
12
+ frames: torch.FloatTensor, shape = [frames, bins]
13
+ velocity: torch.FloatTensor, shape = [frames, bins]
14
+ onset_threshold: float
15
+ frame_threshold: float
16
+
17
+ Returns
18
+ -------
19
+ pitches: np.ndarray of bin_indices
20
+ intervals: np.ndarray of rows containing (onset_index, offset_index)
21
+ velocities: np.ndarray of velocity values
22
+ """
23
+ # onsets_forward = torch.roll(onsets, shifts=(1, 0), dims=(0, 1))
24
+ # onsets_forward[0, :] = 0
25
+ # onsets_backward = torch.roll(onsets, shifts=(-1, 0), dims=(0, 1))
26
+ # onsets_backward[-1, :] = 0
27
+ # onsets_peak = torch.logical_and(onsets >= onsets_forward, onsets >= onsets_backward)
28
+ # onsets_peak = torch.logical_and(onsets >= 0.25, onsets_peak)
29
+
30
+ onsets = (onsets > onset_threshold).cpu().to(torch.uint8)
31
+ frames = (frames > frame_threshold).cpu().to(torch.uint8)
32
+ onset_diff = torch.cat([onsets[:1, :], onsets[1:, :] - onsets[:-1, :]], dim=0) == 1
33
+ # onset_diff = torch.cat([frames[:1, :], frames[1:, :] - frames[:-1, :]], dim=0) == 1
34
+
35
+ pitches = []
36
+ intervals = []
37
+ velocities = []
38
+
39
+ # for nonzero in onsets_peak.nonzero(as_tuple=False):
40
+ for nonzero in onset_diff.nonzero(as_tuple=False):
41
+ frame = nonzero[0].item()
42
+ pitch = nonzero[1].item()
43
+
44
+ onset = frame
45
+ offset = frame
46
+ velocity_samples = []
47
+
48
+ while onsets[offset, pitch].item() or frames[offset, pitch].item():
49
+ if onsets[offset, pitch].item():
50
+ # if frames[offset, pitch].item():
51
+ velocity_samples.append(velocity[offset, pitch].item())
52
+ offset += 1
53
+ if offset == onsets.shape[0]:
54
+ break
55
+
56
+ if offset > onset:
57
+ pitches.append(pitch)
58
+ intervals.append([onset, offset])
59
+ velocities.append(
60
+ np.mean(velocity_samples) if len(velocity_samples) > 0 else 0
61
+ )
62
+
63
+ return np.array(pitches), np.array(intervals), np.array(velocities)
64
+
65
+
66
+ def notes_to_frames(pitches, intervals, shape, mask=None):
67
+ """
68
+ Takes lists specifying note sequences and returns the corresponding frame-wise piano-roll representation
69
+
70
+ Parameters
71
+ ----------
72
+ pitches: list of pitch bin indices
73
+ intervals: list of [onset, offset] ranges of bin indices
74
+ shape: the shape of the original piano roll, [n_frames, n_bins]
75
+
76
+ Returns
77
+ -------
78
+ time: np.ndarray containing the frame indices
79
+ freqs: list of np.ndarray, each containing the frequency bin indices
80
+ """
81
+ roll = np.zeros(tuple(shape))
82
+ for pitch, (onset, offset) in zip(pitches, intervals):
83
+ # print('pitch', pitch, onset, offset)
84
+ # print('onset offset', onset, offset, pitch)
85
+ roll[onset:offset, pitch] = 1
86
+ if mask is not None:
87
+ roll *= mask
88
+ time = np.arange(roll.shape[0])
89
+ freqs = [roll[t, :].nonzero()[0] for t in time]
90
+ # if mask_size is not None:
91
+ # mask = np.zeros(tuple(shape))
92
+ # notes = roll.shape[1]
93
+ # for n in range(notes):
94
+ # onset_d = roll[1:, n] - roll[: -1, n]
95
+ # print('unique', np.unique(onset_d))
96
+ # onset_d[onset_d < 0] = 0
97
+ # print('n', n, onset_d.sum())
98
+ # onset_d = np.concatenate((np.zeros((1, 1)), roll[1:, n] - roll[: -1, n]))
99
+ # onset_d[onset_d < 0] = 0
100
+ # for r in range(mask_size):
101
+ # mask[:, n] += np.roll(onset_d, r)
102
+ return time, freqs
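A minimal usage sketch of the two helpers above (illustrative; `onset_pred`, `frame_pred`, and `velocity_pred` are assumed to be `(n_frames, n_bins)` tensors produced by the transcriber):

```python
# Illustrative usage of extract_notes / notes_to_frames (not part of the commit).
from onsets_and_frames.decoding import extract_notes, notes_to_frames

pitches, intervals, velocities = extract_notes(
    onset_pred, frame_pred, velocity_pred,
    onset_threshold=0.5, frame_threshold=0.5,
)
time_idx, freq_bins = notes_to_frames(pitches, intervals, frame_pred.shape)
```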
onsets_and_frames/hf_model.py ADDED
@@ -0,0 +1,364 @@
1
+ """
2
+ Hugging Face Hub-compatible wrapper for CountEM music transcription models.
3
+ """
4
+ from pathlib import Path
5
+ from typing import Union, Tuple
6
+ import numpy as np
7
+ import torch
8
+ import soundfile as sf
9
+ from huggingface_hub import PyTorchModelHubMixin
10
+
11
+ from onsets_and_frames.transcriber import OnsetsAndFrames
12
+ from onsets_and_frames.mel import MelSpectrogram
13
+ from onsets_and_frames.midi_utils import frames2midi
14
+ from onsets_and_frames.constants import (
15
+ N_MELS,
16
+ MIN_MIDI,
17
+ MAX_MIDI,
18
+ HOP_LENGTH,
19
+ SAMPLE_RATE,
20
+ WINDOW_LENGTH,
21
+ MEL_FMIN,
22
+ MEL_FMAX,
23
+ )
24
+
25
+
26
+ class CountEMModel(
27
+ OnsetsAndFrames,
28
+ PyTorchModelHubMixin,
29
+ # Optional metadata that gets pushed to model card
30
+ library_name="countem",
31
+ tags=["audio", "music-transcription", "automatic-music-transcription", "midi"],
32
+ license="cc-by-4.0",
33
+ repo_url="https://github.com/Yoni-Yaffe/count-the-notes",
34
+ paper_url="https://arxiv.org/abs/2511.14250",
35
+ ):
36
+ """
37
+ Hugging Face Hub-compatible wrapper for CountEM automatic music transcription models.
38
+
39
+ This model performs automatic music transcription (AMT) from audio to MIDI.
40
+ It uses the Onsets & Frames architecture trained with the CountEM framework,
41
+ which enables training with weak, unordered note count histograms.
42
+
43
+ Example usage:
44
+ ```python
45
+ from onsets_and_frames.hf_model import CountEMModel
46
+ import soundfile as sf
47
+
48
+ # Load model from Hub
49
+ model = CountEMModel.from_pretrained("Yoni-Yaffe/countem-musicnet")
50
+
51
+ # Load audio (must be 16kHz)
52
+ audio, sr = sf.read("audio.flac")
53
+ assert sr == 16000, "Audio must be 16kHz"
54
+
55
+ # Transcribe to MIDI
56
+ model.transcribe_to_midi(audio, "output.mid")
57
+ ```
58
+
59
+ Args:
60
+ model_complexity: Complexity multiplier for the model (default: 64)
61
+ onset_complexity: Complexity multiplier for onset stack (default: 1.5)
62
+ n_instruments: Number of instruments to transcribe (default: 1)
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ model_complexity: int = 64,
68
+ onset_complexity: float = 1.5,
69
+ n_instruments: int = 1,
70
+ **kwargs
71
+ ):
72
+ # Initialize the base OnsetsAndFrames model
73
+ n_keys = MAX_MIDI - MIN_MIDI + 1
74
+ OnsetsAndFrames.__init__(
75
+ self,
76
+ input_features=N_MELS,
77
+ output_features=n_keys,
78
+ model_complexity=model_complexity,
79
+ onset_complexity=onset_complexity,
80
+ n_instruments=n_instruments,
81
+ )
82
+
83
+ # Store config for HF Hub
84
+ self.config = {
85
+ "model_complexity": model_complexity,
86
+ "onset_complexity": onset_complexity,
87
+ "n_instruments": n_instruments,
88
+ "n_mels": N_MELS,
89
+ "n_keys": n_keys,
90
+ "sample_rate": SAMPLE_RATE,
91
+ "hop_length": HOP_LENGTH,
92
+ }
93
+
94
+ # Add mel spectrogram as a submodule for proper device management
95
+ # This ensures the mel transform moves with the model when calling .to(device)
96
+ self.melspectrogram = MelSpectrogram(
97
+ n_mels=N_MELS,
98
+ sample_rate=SAMPLE_RATE,
99
+ filter_length=WINDOW_LENGTH,
100
+ hop_length=HOP_LENGTH,
101
+ mel_fmin=MEL_FMIN,
102
+ mel_fmax=MEL_FMAX,
103
+ )
104
+
105
+ def forward(self, audio: Union[np.ndarray, torch.Tensor]):
106
+ """
107
+ Forward pass that accepts raw audio waveforms.
108
+
109
+ Unlike the parent OnsetsAndFrames which expects mel spectrograms,
110
+ this forward method accepts raw audio and converts it internally.
111
+
112
+ Args:
113
+ audio: Raw audio waveform, shape (batch, n_samples) or (n_samples,)
114
+ Should be normalized to [-1, 1] or will be normalized automatically
115
+
116
+ Returns:
117
+ Tuple of (onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred)
118
+ """
119
+ # Convert to torch tensor if needed
120
+ if isinstance(audio, np.ndarray):
121
+ audio = torch.from_numpy(audio).float()
122
+
123
+ # Ensure audio is in range [-1, 1]
124
+ if audio.dtype == torch.int16:
125
+ audio = audio.float() / 32768.0
126
+ elif audio.max() > 1.0 or audio.min() < -1.0:
127
+ audio = audio / max(abs(audio.max()), abs(audio.min()))
128
+
129
+ # Add batch dimension if needed
130
+ if audio.dim() == 1:
131
+ audio = audio.unsqueeze(0)
132
+
133
+ device = next(self.parameters()).device
134
+ audio = audio.to(device)
135
+
136
+ # Remove last sample to fix frame count mismatch
137
+ audio = audio[:, :-1]
138
+
139
+ mel = self.melspectrogram(audio)
140
+
141
+ # Transpose to (batch, time, features) format expected by parent model
142
+ mel = mel.transpose(-1, -2)
143
+
144
+ return super().forward(mel)
145
+
146
+ @torch.no_grad()
147
+ def transcribe(
148
+ self,
149
+ audio: Union[np.ndarray, torch.Tensor],
150
+ onset_threshold: float = 0.5,
151
+ frame_threshold: float = 0.5,
152
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
153
+ """
154
+ Transcribe audio to note predictions.
155
+
156
+ Automatically handles long audio by splitting into segments (max 5 minutes each)
157
+ to avoid memory issues.
158
+
159
+ Args:
160
+ audio: Audio waveform, shape (n_samples,), normalized to [-1, 1]
161
+ onset_threshold: Threshold for onset detection (default: 0.5)
162
+ frame_threshold: Threshold for frame detection (default: 0.5)
163
+
164
+ Returns:
165
+ Tuple of (onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred)
166
+ All are numpy arrays of shape (n_frames, 88) except velocity which may vary
167
+ """
168
+ self.eval()
169
+
170
+ # Convert to torch tensor if needed
171
+ if isinstance(audio, np.ndarray):
172
+ audio = torch.from_numpy(audio).float()
173
+
174
+ # Ensure audio is 1D (convert stereo to mono if needed)
175
+ if audio.dim() > 1:
176
+ # If stereo or multi-channel, take mean across channels
177
+ audio = audio.mean(dim=-1 if audio.shape[-1] <= 2 else 0)
178
+
179
+ # Normalize audio
180
+ if audio.dtype == torch.int16:
181
+ audio = audio.float() / 32768.0
182
+ elif audio.max() > 1.0 or audio.min() < -1.0:
183
+ audio = audio / max(abs(audio.max()), abs(audio.min()))
184
+
185
+ device = next(self.parameters()).device
186
+ audio = audio.to(device)
187
+
188
+ # Handle long audio by segmenting
189
+ MAX_TIME = 5 * 60 * SAMPLE_RATE # 5 minutes
190
+ audio_len = len(audio)
191
+
192
+ if audio_len > MAX_TIME:
193
+ # Split into segments
194
+ n_segments = int(np.ceil(audio_len / MAX_TIME))
195
+ seg_len = MAX_TIME
196
+
197
+ onset_preds = []
198
+ offset_preds = []
199
+ activation_preds = []
200
+ frame_preds = []
201
+ velocity_preds = []
202
+
203
+ for i_s in range(n_segments):
204
+ start = i_s * seg_len
205
+ end = min((i_s + 1) * seg_len, audio_len)
206
+ segment = audio[start:end]
207
+
208
+ # Forward pass on segment
209
+ onset_seg, offset_seg, activation_seg, frame_seg, velocity_seg = self(segment)
210
+
211
+ onset_preds.append(onset_seg)
212
+ offset_preds.append(offset_seg)
213
+ activation_preds.append(activation_seg)
214
+ frame_preds.append(frame_seg)
215
+ velocity_preds.append(velocity_seg)
216
+
217
+ # Concatenate along time dimension (dim=1)
218
+ onset_pred = torch.cat(onset_preds, dim=1)
219
+ offset_pred = torch.cat(offset_preds, dim=1)
220
+ activation_pred = torch.cat(activation_preds, dim=1)
221
+ frame_pred = torch.cat(frame_preds, dim=1)
222
+ velocity_pred = torch.cat(velocity_preds, dim=1)
223
+ else:
224
+ # Short audio, process directly
225
+ onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred = self(audio)
226
+
227
+ # Convert to numpy and remove batch dimension
228
+ onset_pred = onset_pred.squeeze(0).cpu().numpy()
229
+ offset_pred = offset_pred.squeeze(0).cpu().numpy()
230
+ activation_pred = activation_pred.squeeze(0).cpu().numpy()
231
+ frame_pred = frame_pred.squeeze(0).cpu().numpy()
232
+ velocity_pred = velocity_pred.squeeze(0).cpu().numpy()
233
+
234
+ return onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred
235
+
236
+ def transcribe_to_midi(
237
+ self,
238
+ audio: Union[np.ndarray, torch.Tensor, str, Path],
239
+ output_path: Union[str, Path],
240
+ onset_threshold: float = 0.5,
241
+ frame_threshold: float = 0.5,
242
+ ) -> None:
243
+ """
244
+ Transcribe audio to MIDI file.
245
+
246
+ Args:
247
+ audio: Audio waveform, numpy array, torch tensor, or path to audio file
248
+ output_path: Path to save MIDI file
249
+ onset_threshold: Threshold for onset detection (default: 0.5)
250
+ frame_threshold: Threshold for frame detection (default: 0.5)
251
+ """
252
+ # Load audio from file if path is provided
253
+ if isinstance(audio, (str, Path)):
254
+ audio, sr = sf.read(audio, dtype="float32")
255
+ if sr != SAMPLE_RATE:
256
+ raise ValueError(
257
+ f"Audio must be {SAMPLE_RATE}Hz, got {sr}Hz. "
258
+ f"Please resample to {SAMPLE_RATE}Hz first."
259
+ )
260
+
261
+ # Get predictions
262
+ onset_pred, offset_pred, _, frame_pred, velocity_pred = self.transcribe(
263
+ audio, onset_threshold, frame_threshold
264
+ )
265
+
266
+ # Default instrument mapping (piano)
267
+ inst_mapping = {0: 0} # instrument 0 -> MIDI program 0 (Acoustic Grand Piano)
268
+
269
+ # Convert predictions to MIDI
270
+ frames2midi(
271
+ str(output_path),
272
+ onset_pred,
273
+ frame_pred,
274
+ velocity_pred,
275
+ onset_threshold=onset_threshold,
276
+ frame_threshold=frame_threshold,
277
+ scaling=HOP_LENGTH / SAMPLE_RATE,
278
+ inst_mapping=inst_mapping,
279
+ )
280
+
281
+ def to_legacy(self) -> OnsetsAndFrames:
282
+ """
283
+ Convert this HuggingFace-compatible model to a legacy OnsetsAndFrames instance.
284
+
285
+ This is useful for:
286
+ - Fine-tuning models downloaded from HuggingFace Hub using existing training code
287
+ - Using HF models with existing inference scripts that expect OnsetsAndFrames
288
+
289
+ The legacy model will use the global melspectrogram from mel.py instead of
290
+ the instance-specific one in this model.
291
+
292
+ Returns:
293
+ OnsetsAndFrames instance with copied weights
294
+ """
295
+ # Create legacy model with same architecture
296
+ legacy_model = OnsetsAndFrames(
297
+ input_features=self.config['n_mels'],
298
+ output_features=self.config['n_keys'],
299
+ model_complexity=self.config['model_complexity'],
300
+ onset_complexity=self.config['onset_complexity'],
301
+ n_instruments=self.config['n_instruments']
302
+ )
303
+
304
+ # Get the state dict and filter out melspectrogram keys
305
+ state_dict = self.state_dict()
306
+ legacy_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('melspectrogram.')}
307
+
308
+ # Copy state dict (only model weights, not mel spectrogram)
309
+ # The legacy model will use the global melspectrogram
310
+ legacy_model.load_state_dict(legacy_state_dict)
311
+
312
+ return legacy_model
313
+
314
+ @classmethod
315
+ def from_legacy_checkpoint(
316
+ cls,
317
+ checkpoint_path: Union[str, Path],
318
+ **kwargs
319
+ ) -> "CountEMModel":
320
+ """
321
+ Load a model from a legacy checkpoint (saved with torch.save(model)).
322
+
323
+ This is useful for converting old checkpoints to the new HF-compatible format.
324
+
325
+ Args:
326
+ checkpoint_path: Path to the legacy .pt checkpoint file
327
+ **kwargs: Additional arguments for model initialization
328
+
329
+ Returns:
330
+ CountEMModel instance with loaded weights
331
+ """
332
+ # Load the legacy checkpoint
333
+ legacy_model = torch.load(checkpoint_path, map_location="cpu")
334
+
335
+ # Extract configuration from the loaded model
336
+ # Infer model_complexity from the model structure
337
+ # ConvStack.cnn[0] is the first Conv2d layer with out_channels = model_size // 16
338
+ first_conv_channels = legacy_model.offset_stack[0].cnn[0].out_channels
339
+ model_size = first_conv_channels * 16
340
+ model_complexity = model_size // 16
341
+
342
+ # Infer onset_complexity
343
+ onset_first_conv_channels = legacy_model.onset_stack[0].cnn[0].out_channels
344
+ onset_model_size = onset_first_conv_channels * 16
345
+ onset_complexity = onset_model_size / model_size
346
+
347
+ # Infer n_instruments from output layer
348
+ # onset_stack[2] is the Linear layer
349
+ onset_out_features = legacy_model.onset_stack[2].out_features
350
+ n_keys = MAX_MIDI - MIN_MIDI + 1
351
+ n_instruments = onset_out_features // n_keys
352
+
353
+ # Create new model with the same configuration
354
+ model = cls(
355
+ model_complexity=model_complexity,
356
+ onset_complexity=onset_complexity,
357
+ n_instruments=n_instruments,
358
+ **kwargs
359
+ )
360
+
361
+ # Copy the state dict (strict=False because new model has melspectrogram submodule)
362
+ model.load_state_dict(legacy_model.state_dict(), strict=False)
363
+
364
+ return model
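A short illustrative sketch of converting a legacy checkpoint with `from_legacy_checkpoint` and saving it in the Hub format (the paths are placeholders; `save_pretrained` is provided by `PyTorchModelHubMixin`):

```python
# Illustrative conversion flow; paths are placeholders, not from the commit.
from onsets_and_frames.hf_model import CountEMModel

model = CountEMModel.from_legacy_checkpoint("checkpoints/legacy_model.pt")
model.save_pretrained("countem_converted")   # HF-format directory (PyTorchModelHubMixin)
legacy = model.to_legacy()                   # back to a plain OnsetsAndFrames instance
```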
onsets_and_frames/lstm.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class BiLSTM(nn.Module):
6
+ inference_chunk_length = 512
7
+
8
+ def __init__(self, input_features, recurrent_features, use_gru=False, dropout=0.0):
9
+ super().__init__()
10
+ self.rnn = (nn.LSTM if not use_gru else nn.GRU)(
11
+ input_features,
12
+ recurrent_features,
13
+ batch_first=True,
14
+ bidirectional=True,
15
+ dropout=dropout,
16
+ )
17
+
18
+ def forward(self, x):
19
+ if self.training:
20
+ return self.rnn(x)[0]
21
+ else:
22
+ # evaluation mode: support for longer sequences that do not fit in memory
23
+ batch_size, sequence_length, input_features = x.shape
24
+ hidden_size = self.rnn.hidden_size
25
+ num_directions = 2 if self.rnn.bidirectional else 1
26
+
27
+ h = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
28
+ c = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
29
+ output = torch.zeros(
30
+ batch_size,
31
+ sequence_length,
32
+ num_directions * hidden_size,
33
+ device=x.device,
34
+ )
35
+
36
+ # forward direction
37
+ slices = range(0, sequence_length, self.inference_chunk_length)
38
+ for start in slices:
39
+ end = start + self.inference_chunk_length
40
+ output[:, start:end, :], (h, c) = self.rnn(x[:, start:end, :], (h, c))
41
+
42
+ # reverse direction
43
+ if self.rnn.bidirectional:
44
+ h.zero_()
45
+ c.zero_()
46
+
47
+ for start in reversed(slices):
48
+ end = start + self.inference_chunk_length
49
+ result, (h, c) = self.rnn(x[:, start:end, :], (h, c))
50
+ output[:, start:end, hidden_size:] = result[:, :, hidden_size:]
51
+
52
+ return output
53
+
54
+
55
+ class UniLSTM(nn.Module):
56
+ inference_chunk_length = 512
57
+
58
+ def __init__(self, input_features, recurrent_features):
59
+ super().__init__()
60
+ self.rnn = nn.LSTM(input_features, recurrent_features, batch_first=True)
61
+
62
+ def forward(self, x):
63
+ if self.training:
64
+ return self.rnn(x)[0]
65
+ else:
66
+ # evaluation mode: support for longer sequences that do not fit in memory
67
+ batch_size, sequence_length, input_features = x.shape
68
+ hidden_size = self.rnn.hidden_size
69
+ num_directions = 2 if self.rnn.bidirectional else 1
70
+
71
+ h = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
72
+ c = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
73
+ output = torch.zeros(
74
+ batch_size,
75
+ sequence_length,
76
+ num_directions * hidden_size,
77
+ device=x.device,
78
+ )
79
+
80
+ # forward direction
81
+ slices = range(0, sequence_length, self.inference_chunk_length)
82
+ for start in slices:
83
+ end = start + self.inference_chunk_length
84
+ output[:, start:end, :], (h, c) = self.rnn(x[:, start:end, :], (h, c))
85
+
86
+ # reverse direction
87
+ if self.rnn.bidirectional:
88
+ h.zero_()
89
+ c.zero_()
90
+
91
+ for start in reversed(slices):
92
+ end = start + self.inference_chunk_length
93
+ result, (h, c) = self.rnn(x[:, start:end, :], (h, c))
94
+ output[:, start:end, hidden_size:] = result[:, :, hidden_size:]
95
+
96
+ return output
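A brief sketch of exercising the chunked evaluation path above (shapes are illustrative, not part of the commit):

```python
# Illustrative: the chunked code path only runs in eval mode.
import torch
from onsets_and_frames.lstm import BiLSTM

rnn = BiLSTM(input_features=229, recurrent_features=128)
rnn.eval()
with torch.no_grad():
    x = torch.randn(1, 2048, 229)  # longer than inference_chunk_length (512)
    out = rnn(x)                   # shape (1, 2048, 256): 2 * recurrent_features
```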
onsets_and_frames/mel.py ADDED
@@ -0,0 +1,136 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from librosa.filters import mel
5
+ from librosa.util import pad_center
6
+ from scipy.signal import get_window
7
+ from torch.autograd import Variable
8
+
9
+ from onsets_and_frames.constants import (
10
+ DEFAULT_DEVICE,
11
+ HOP_LENGTH,
12
+ MEL_FMAX,
13
+ MEL_FMIN,
14
+ N_MELS,
15
+ SAMPLE_RATE,
16
+ WINDOW_LENGTH,
17
+ )
18
+
19
+
20
+ class STFT(torch.nn.Module):
21
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
22
+
23
+ def __init__(self, filter_length, hop_length, win_length=None, window="hann"):
24
+ super(STFT, self).__init__()
25
+ if win_length is None:
26
+ win_length = filter_length
27
+
28
+ self.filter_length = filter_length
29
+ self.hop_length = hop_length
30
+ self.win_length = win_length
31
+ self.window = window
32
+ self.forward_transform = None
33
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
34
+
35
+ cutoff = int((self.filter_length / 2 + 1))
36
+ fourier_basis = np.vstack(
37
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
38
+ )
39
+
40
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
41
+
42
+ if window is not None:
43
+ assert filter_length >= win_length
44
+ # get window and zero center pad it to filter_length
45
+ fft_window = get_window(window, win_length, fftbins=True)
46
+ fft_window = pad_center(fft_window, size=filter_length)
47
+ fft_window = torch.from_numpy(fft_window).float()
48
+
49
+ # window the bases
50
+ forward_basis *= fft_window
51
+
52
+ self.register_buffer("forward_basis", forward_basis.float())
53
+
54
+ def forward(self, input_data):
55
+ num_batches = input_data.size(0)
56
+ num_samples = input_data.size(1)
57
+
58
+ # similar to librosa, reflect-pad the input
59
+ input_data = input_data.view(num_batches, 1, num_samples)
60
+ # print('inp before', input_data.shape)
61
+ input_data = F.pad(
62
+ input_data.unsqueeze(1),
63
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
64
+ mode="reflect",
65
+ )
66
+ input_data = input_data.squeeze(1)
67
+ # print('inp after', input_data.shape)
68
+
69
+ forward_transform = F.conv1d(
70
+ input_data,
71
+ Variable(self.forward_basis, requires_grad=False),
72
+ stride=self.hop_length,
73
+ padding=0,
74
+ )
75
+ # print('fwd', forward_transform.shape)
76
+
77
+ cutoff = int((self.filter_length / 2) + 1)
78
+ real_part = forward_transform[:, :cutoff, :]
79
+ imag_part = forward_transform[:, cutoff:, :]
80
+
81
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
82
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
83
+
84
+ return magnitude, phase
85
+
86
+
87
+ class MelSpectrogram(torch.nn.Module):
88
+ def __init__(
89
+ self,
90
+ n_mels,
91
+ sample_rate,
92
+ filter_length,
93
+ hop_length,
94
+ win_length=None,
95
+ mel_fmin=0.0,
96
+ mel_fmax=None,
97
+ ):
98
+ super(MelSpectrogram, self).__init__()
99
+ self.stft = STFT(filter_length, hop_length, win_length)
100
+
101
+ mel_basis = mel(
102
+ sr=sample_rate,
103
+ n_fft=filter_length,
104
+ n_mels=n_mels,
105
+ fmin=mel_fmin,
106
+ fmax=mel_fmax,
107
+ htk=True,
108
+ )
109
+ mel_basis = torch.from_numpy(mel_basis).float()
110
+ self.register_buffer("mel_basis", mel_basis)
111
+
112
+ def forward(self, y):
113
+ """Computes mel-spectrograms from a batch of waves
114
+ PARAMS
115
+ ------
116
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
117
+ RETURNS
118
+ -------
119
+ mel_output: torch.FloatTensor of shape (B, T, n_mels)
120
+ """
121
+ assert torch.min(y.data) >= -1
122
+ assert torch.max(y.data) <= 1
123
+
124
+ magnitudes, phases = self.stft(y)
125
+ magnitudes = magnitudes.data
126
+
127
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
128
+ mel_output = torch.log(torch.clamp(mel_output, min=1e-5))
129
+ return mel_output
130
+
131
+
132
+ # the default melspectrogram converter across the project
133
+ melspectrogram = MelSpectrogram(
134
+ N_MELS, SAMPLE_RATE, WINDOW_LENGTH, HOP_LENGTH, mel_fmin=MEL_FMIN, mel_fmax=MEL_FMAX
135
+ )
136
+ melspectrogram.to(DEFAULT_DEVICE)
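A minimal sketch of calling the module-level `melspectrogram` above on a batch of waveforms scaled to [-1, 1] (illustrative, not part of the commit):

```python
# Illustrative: audio must already be at SAMPLE_RATE and scaled to [-1, 1].
import torch
from onsets_and_frames.constants import DEFAULT_DEVICE, SAMPLE_RATE
from onsets_and_frames.mel import melspectrogram

audio = torch.zeros(1, SAMPLE_RATE * 2, device=DEFAULT_DEVICE)  # two seconds of silence
mel = melspectrogram(audio)       # (1, N_MELS, n_frames)
mel = mel.transpose(-1, -2)       # (1, n_frames, N_MELS), the layout the transcriber consumes
```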
onsets_and_frames/midi_utils.py ADDED
@@ -0,0 +1,655 @@
1
+ import os
2
+ from datetime import datetime
3
+
4
+ import mido
5
+ import numpy as np
6
+ import torch
7
+ from mido import Message, MidiFile, MidiTrack
8
+
9
+ from onsets_and_frames.constants import (
10
+ DRUM_CHANNEL,
11
+ HOP_LENGTH,
12
+ HOPS_IN_OFFSET,
13
+ HOPS_IN_ONSET,
14
+ MAX_MIDI,
15
+ MIN_MIDI,
16
+ N_KEYS,
17
+ SAMPLE_RATE,
18
+ )
19
+
20
+ from .utils import max_inst
21
+
22
+
23
+ def midi_to_hz(m):
24
+ return 440.0 * (2.0 ** ((m - 69.0) / 12.0))
25
+
26
+
27
+ def hz_to_midi(h):
28
+ return 12.0 * np.log2(h / (440.0)) + 69.0
29
+
30
+
31
+ def midi_to_frames(midi, instruments, conversion_map=None):
32
+ n_keys = MAX_MIDI - MIN_MIDI + 1
33
+ midi_length = int((max(midi[:, 1]) + 1) * SAMPLE_RATE)
34
+ n_steps = (midi_length - 1) // HOP_LENGTH + 1
35
+ n_channels = len(instruments) + 1
36
+ label = torch.zeros(n_steps, n_keys * n_channels, dtype=torch.uint8)
37
+ for onset, offset, note, vel, instrument in midi:
38
+ f = int(note) - MIN_MIDI
39
+ if 104 > instrument > 87 or instrument > 111:
40
+ continue
41
+ if f >= n_keys or f < 0:
42
+ continue
43
+ assert 0 < vel < 128
44
+ instrument = int(instrument)
45
+ if conversion_map is not None:
46
+ if instrument not in conversion_map:
47
+ continue
48
+ instrument = conversion_map[instrument]
49
+ left = int(round(onset * SAMPLE_RATE / HOP_LENGTH))
50
+ onset_right = min(n_steps, left + HOPS_IN_ONSET)
51
+ frame_right = int(round(offset * SAMPLE_RATE / HOP_LENGTH))
52
+ frame_right = min(n_steps, frame_right)
53
+ offset_right = min(n_steps, frame_right + HOPS_IN_OFFSET)
54
+ if int(instrument) not in instruments:
55
+ continue
56
+ chan = instruments.index(int(instrument))
57
+ label[left:onset_right, n_keys * chan + f] = 3
58
+ label[onset_right:frame_right, n_keys * chan + f] = 2
59
+ label[frame_right:offset_right, n_keys * chan + f] = 1
60
+
61
+ inv_chan = len(instruments)
62
+ label[left:onset_right, n_keys * inv_chan + f] = 3
63
+ label[onset_right:frame_right, n_keys * inv_chan + f] = 2
64
+ label[frame_right:offset_right, n_keys * inv_chan + f] = 1
65
+
66
+ return label
67
+
68
+
69
+ """
70
+ Convert piano roll to list of notes, pitch only.
71
+ """
72
+
73
+
74
+ def extract_notes_np_pitch(
75
+ onsets, frames, velocity, onset_threshold=0.5, frame_threshold=0.5
76
+ ):
77
+ onsets = (onsets > onset_threshold).astype(np.uint8)
78
+ frames = (frames > frame_threshold).astype(np.uint8)
79
+ onset_diff = (
80
+ np.concatenate([onsets[:1, :], onsets[1:, :] - onsets[:-1, :]], axis=0) == 1
81
+ )
82
+
83
+ pitches = []
84
+ intervals = []
85
+ velocities = []
86
+
87
+ for nonzero in np.transpose(np.nonzero(onset_diff)):
88
+ frame = nonzero[0].item()
89
+ pitch = nonzero[1].item()
90
+
91
+ onset = frame
92
+ offset = frame
93
+ velocity_samples = []
94
+
95
+ while onsets[offset, pitch] or frames[offset, pitch]:
96
+ if onsets[offset, pitch]:
97
+ velocity_samples.append(velocity[offset, pitch])
98
+ offset += 1
99
+ if offset == onsets.shape[0]:
100
+ break
101
+
102
+ if offset > onset:
103
+ pitches.append(pitch)
104
+ intervals.append([onset, offset])
105
+ velocities.append(
106
+ np.mean(velocity_samples) if len(velocity_samples) > 0 else 0
107
+ )
108
+ return np.array(pitches), np.array(intervals), np.array(velocities)
109
+
110
+
111
+ def extract_notes_np_rescaled(
112
+ onsets, frames, velocity, onset_threshold=0.5, frame_threshold=0.5
113
+ ):
114
+ pitches, intervals, velocities, instruments = extract_notes_np(
115
+ onsets, frames, velocity, onset_threshold, frame_threshold
116
+ )
117
+ pitches += MIN_MIDI
118
+ scaling = HOP_LENGTH / SAMPLE_RATE
119
+ intervals = (intervals * scaling).reshape(-1, 2)
120
+ return pitches, intervals, velocities, instruments
121
+
122
+
123
+ """
124
+ Convert piano roll to list of notes, pitch and instrument.
125
+ """
126
+
127
+
128
+ def extract_notes_np(
129
+ onsets,
130
+ frames,
131
+ velocity,
132
+ onset_threshold=0.5,
133
+ frame_threshold=0.5,
134
+ onset_threshold_vec=None,
135
+ ):
136
+ if onset_threshold_vec is not None:
137
+ onsets = (onsets > np.array(onset_threshold_vec)).astype(np.uint8)
138
+ else:
139
+ onsets = (onsets > onset_threshold).astype(np.uint8)
140
+
141
+ frames = (frames > frame_threshold).astype(np.uint8)
142
+ onset_diff = (
143
+ np.concatenate([onsets[:1, :], onsets[1:, :] - onsets[:-1, :]], axis=0) == 1
144
+ )
145
+
146
+ if onsets.shape[-1] != frames.shape[-1]:
147
+ num_instruments = onsets.shape[1] / frames.shape[1]
148
+ assert num_instruments.is_integer()
149
+ num_instruments = int(num_instruments)
150
+ frames = np.tile(frames, (1, num_instruments))
151
+
152
+ pitches = []
153
+ intervals = []
154
+ velocities = []
155
+ instruments = []
156
+
157
+ for nonzero in np.transpose(np.nonzero(onset_diff)):
158
+ frame = nonzero[0].item()
159
+ pitch = nonzero[1].item()
160
+
161
+ onset = frame
162
+ offset = frame
163
+ velocity_samples = []
164
+
165
+ while onsets[offset, pitch] or frames[offset, pitch]:
166
+ if onsets[offset, pitch]:
167
+ velocity_samples.append(velocity[offset, pitch])
168
+ offset += 1
169
+ if offset == onsets.shape[0]:
170
+ break
171
+
172
+ if offset > onset:
173
+ pitch, instrument = pitch % N_KEYS, pitch // N_KEYS
174
+
175
+ pitches.append(pitch)
176
+ intervals.append([onset, offset])
177
+ velocities.append(
178
+ np.mean(velocity_samples) if len(velocity_samples) > 0 else 0
179
+ )
180
+ instruments.append(instrument)
181
+ return (
182
+ np.array(pitches),
183
+ np.array(intervals),
184
+ np.array(velocities),
185
+ np.array(instruments),
186
+ )
187
+
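The instrument-aware variant decodes each onset column c as (pitch = c % N_KEYS, instrument = c // N_KEYS); a toy sketch follows (frames may be pitch-only, as they are tiled across instruments inside the function):

import numpy as np

from onsets_and_frames.constants import N_KEYS
from onsets_and_frames.midi_utils import extract_notes_np

T = 50
onsets = np.zeros((T, 2 * N_KEYS), dtype=np.float32)
frames = np.zeros((T, N_KEYS), dtype=np.float32)
velocity = np.full((T, 2 * N_KEYS), 0.6, dtype=np.float32)
onsets[5, N_KEYS + 30] = 1.0   # key 30 on the second instrument channel
frames[5:15, 30] = 1.0
p, i, v, inst = extract_notes_np(onsets, frames, velocity)
print(p, i.tolist(), inst)     # [30] [[5, 15]] [1]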
188
+
189
+ def append_track_multi(file, pitches, intervals, velocities, ins, single_ins=False):
190
+ track = MidiTrack()
191
+ file.tracks.append(track)
192
+ chan = len(file.tracks) - 1
193
+ if chan >= DRUM_CHANNEL:
194
+ chan += 1
195
+ if chan > 15:
196
+ print(f"invalid chan {chan}")
197
+ chan = 15
198
+ track.append(
199
+ Message(
200
+ "program_change", channel=chan, program=ins if not single_ins else 0, time=0
201
+ )
202
+ )
203
+
204
+ ticks_per_second = file.ticks_per_beat * 2.0
205
+
206
+ events = []
207
+ for i in range(len(pitches)):
208
+ events.append(
209
+ dict(
210
+ type="on",
211
+ pitch=pitches[i],
212
+ time=intervals[i][0],
213
+ velocity=velocities[i],
214
+ )
215
+ )
216
+ events.append(
217
+ dict(
218
+ type="off",
219
+ pitch=pitches[i],
220
+ time=intervals[i][1],
221
+ velocity=velocities[i],
222
+ )
223
+ )
224
+ events.sort(key=lambda row: row["time"])
225
+
226
+ last_tick = 0
227
+ for event in events:
228
+ current_tick = int(event["time"] * ticks_per_second)
229
+ velocity = int(event["velocity"] * 127)
230
+ if velocity > 127:
231
+ velocity = 127
232
+ pitch = int(round(hz_to_midi(event["pitch"])))
233
+ track.append(
234
+ Message(
235
+ "note_" + event["type"],
236
+ channel=chan,
237
+ note=pitch,
238
+ velocity=velocity,
239
+ time=current_tick - last_tick,
240
+ )
241
+ )
242
+ # try:
243
+ # track.append(Message('note_' + event['type'], channel=chan, note=pitch, velocity=velocity, time=current_tick - last_tick))
244
+ # except Exception as e:
245
+ # print('Err Message', 'note_' + event['type'], pitch, velocity, current_tick - last_tick)
246
+ # track.append(Message('note_' + event['type'], channel=chan, note=pitch, velocity=max(0, velocity), time=current_tick - last_tick))
247
+ # if velocity >= 0:
248
+ # raise e
249
+ last_tick = current_tick
250
+
251
+
252
+ def append_track(file, pitches, intervals, velocities):
253
+ track = MidiTrack()
254
+ file.tracks.append(track)
255
+ ticks_per_second = file.ticks_per_beat * 2.0
256
+
257
+ events = []
258
+ for i in range(len(pitches)):
259
+ events.append(
260
+ dict(
261
+ type="on",
262
+ pitch=pitches[i],
263
+ time=intervals[i][0],
264
+ velocity=velocities[i],
265
+ )
266
+ )
267
+ events.append(
268
+ dict(
269
+ type="off",
270
+ pitch=pitches[i],
271
+ time=intervals[i][1],
272
+ velocity=velocities[i],
273
+ )
274
+ )
275
+ events.sort(key=lambda row: row["time"])
276
+
277
+ last_tick = 0
278
+ for event in events:
279
+ current_tick = int(event["time"] * ticks_per_second)
280
+ velocity = int(event["velocity"] * 127)
281
+ if velocity > 127:
282
+ velocity = 127
283
+ pitch = int(round(hz_to_midi(event["pitch"])))
284
+ try:
285
+ track.append(
286
+ Message(
287
+ "note_" + event["type"],
288
+ note=pitch,
289
+ velocity=velocity,
290
+ time=current_tick - last_tick,
291
+ )
292
+ )
293
+ except Exception as e:
294
+ print(
295
+ "Err Message",
296
+ "note_" + event["type"],
297
+ pitch,
298
+ velocity,
299
+ current_tick - last_tick,
300
+ )
301
+ track.append(
302
+ Message(
303
+ "note_" + event["type"],
304
+ note=pitch,
305
+ velocity=max(0, velocity),
306
+ time=current_tick - last_tick,
307
+ )
308
+ )
309
+ if velocity >= 0:
310
+ raise e
311
+ last_tick = current_tick
312
+
313
+
314
+ def save_midi(path, pitches, intervals, velocities, insts=None):
315
+ """
316
+ Save extracted notes as a MIDI file
317
+ Parameters
318
+ ----------
319
+ path: the path to save the MIDI file
320
+ pitches: np.ndarray of bin_indices
321
+ intervals: list of (onset_index, offset_index)
322
+ velocities: list of velocity values
323
+ """
324
+ file = MidiFile()
325
+ if isinstance(pitches, list):
326
+ for p, i, v, ins in zip(pitches, intervals, velocities, insts):
327
+ append_track_multi(file, p, i, v, ins)
328
+ else:
329
+ append_track(file, pitches, intervals, velocities)
330
+ file.save(path)
331
+
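A minimal sketch of writing a single-track file with save_midi; pitches are given in Hz and velocities in [0, 1], since append_track converts with hz_to_midi and scales by 127. The output path is a hypothetical example.

import numpy as np

from onsets_and_frames.midi_utils import midi_to_hz, save_midi

pitches = np.array([midi_to_hz(60), midi_to_hz(64)])    # C4, E4 as frequencies
intervals = np.array([[0.0, 0.5], [0.5, 1.0]])          # onset/offset times in seconds
velocities = np.array([0.5, 0.5])
save_midi("toy.mid", pitches, intervals, velocities)    # "toy.mid" is a hypothetical output path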
332
+
333
+ def frames2midi(
334
+ save_path,
335
+ onsets,
336
+ frames,
337
+ vels,
338
+ onset_threshold=0.5,
339
+ frame_threshold=0.5,
340
+ scaling=HOP_LENGTH / SAMPLE_RATE,
341
+ inst_mapping=None,
342
+ onset_threshold_vec=None,
343
+ ):
344
+ p_est, i_est, v_est, inst_est = extract_notes_np(
345
+ onsets,
346
+ frames,
347
+ vels,
348
+ onset_threshold,
349
+ frame_threshold,
350
+ onset_threshold_vec=onset_threshold_vec,
351
+ )
352
+ i_est = (i_est * scaling).reshape(-1, 2)
353
+
354
+ p_est = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_est])
355
+
356
+ inst_set = set(inst_est)
357
+ inst_set = sorted(list(inst_set))
358
+
359
+ p_est_lst = {}
360
+ i_est_lst = {}
361
+ v_est_lst = {}
362
+ assert len(p_est) == len(i_est) == len(v_est) == len(inst_est)
363
+ for p, i, v, ins in zip(p_est, i_est, v_est, inst_est):
364
+ if ins in p_est_lst:
365
+ p_est_lst[ins].append(p)
366
+ else:
367
+ p_est_lst[ins] = [p]
368
+ if ins in i_est_lst:
369
+ i_est_lst[ins].append(i)
370
+ else:
371
+ i_est_lst[ins] = [i]
372
+ if ins in v_est_lst:
373
+ v_est_lst[ins].append(v)
374
+ else:
375
+ v_est_lst[ins] = [v]
376
+ for elem in [p_est_lst, i_est_lst, v_est_lst]:
377
+ for k, v in elem.items():
378
+ elem[k] = np.array(v)
379
+ inst_set = [e for e in inst_set if e in p_est_lst]
380
+ # inst_set = [INSTRUMENT_MAPPING[e] for e in inst_set if e in p_est_lst]
381
+ p_est_lst = [p_est_lst[ins] for ins in inst_set if ins in p_est_lst]
382
+ i_est_lst = [i_est_lst[ins] for ins in inst_set if ins in i_est_lst]
383
+ v_est_lst = [v_est_lst[ins] for ins in inst_set if ins in v_est_lst]
384
+ assert len(p_est_lst) == len(i_est_lst) == len(v_est_lst) == len(inst_set)
385
+ inst_set = [inst_mapping[e] for e in inst_set]
386
+ save_midi(save_path, p_est_lst, i_est_lst, v_est_lst, inst_set)
387
+
388
+
389
+ def frames2midi_pitch(
390
+ save_path,
391
+ onsets,
392
+ frames,
393
+ vels,
394
+ onset_threshold=0.5,
395
+ frame_threshold=0.5,
396
+ scaling=HOP_LENGTH / SAMPLE_RATE,
397
+ ):
398
+ p_est, i_est, v_est = extract_notes_np_pitch(
399
+ onsets, frames, vels, onset_threshold, frame_threshold
400
+ )
401
+ i_est = (i_est * scaling).reshape(-1, 2)
402
+ p_est = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_est])
403
+ print("Saving midi in", save_path)
404
+ save_midi(save_path, p_est, i_est, v_est)
405
+
406
+
407
+ def parse_midi_multi(path, force_instrument=None):
408
+ """open midi file and return np.array of (onset, offset, note, velocity, instrument) rows"""
409
+ try:
410
+ midi = mido.MidiFile(path)
411
+ except Exception:
412
+ print("could not open midi", path)
413
+ return
414
+
415
+ time = 0
416
+
417
+ events = []
418
+
419
+ control_changes = []
420
+ program_changes = []
421
+
422
+ sustain = {}
423
+
424
+ all_channels = set()
425
+
426
+ instruments = {} # mapping of channel: instrument
427
+
428
+ for message in midi:
429
+ time += message.time
430
+ if hasattr(message, "channel"):
431
+ if message.channel == DRUM_CHANNEL:
432
+ continue
433
+
434
+ if (
435
+ message.type == "control_change"
436
+ and message.control == 64
437
+ and (message.value >= 64) != sustain.get(message.channel, False)
438
+ ):
439
+ sustain[message.channel] = message.value >= 64
440
+ event_type = "sustain_on" if sustain[message.channel] else "sustain_off"
441
+ event = dict(
442
+ index=len(events), time=time, type=event_type, note=None, velocity=0
443
+ )
444
+ event["channel"] = message.channel
445
+ event["sustain"] = sustain[message.channel]
446
+ events.append(event)
447
+
448
+ if message.type == "control_change" and message.control != 64:
449
+ control_changes.append(
450
+ (time, message.control, message.value, message.channel)
451
+ )
452
+
453
+ if message.type == "program_change":
454
+ program_changes.append((time, message.program, message.channel))
455
+ instruments[message.channel] = instruments.get(message.channel, []) + [
456
+ (message.program, time)
457
+ ]
458
+
459
+ if "note" in message.type:
460
+ # MIDI offsets can be either 'note_off' events or 'note_on' with zero velocity
461
+ velocity = message.velocity if message.type == "note_on" else 0
462
+ event = dict(
463
+ index=len(events),
464
+ time=time,
465
+ type="note",
466
+ note=message.note,
467
+ velocity=velocity,
468
+ sustain=sustain.get(message.channel, False),
469
+ )
470
+ event["channel"] = message.channel
471
+ events.append(event)
472
+
473
+ if hasattr(message, "channel"):
474
+ all_channels.add(message.channel)
475
+
476
+ if len(instruments) == 0:
477
+ instruments = {c: [(0, 0)] for c in all_channels}
478
+ if len(all_channels) > len(instruments):
479
+ for e in all_channels - set(instruments.keys()):
480
+ instruments[e] = [(0, 0)]
481
+
482
+ if force_instrument is not None:
483
+ instruments = {c: [(force_instrument, 0)] for c in all_channels}
484
+
485
+ this_instruments = set()
486
+ for v in instruments.values():
487
+ this_instruments = this_instruments.union(set(x[0] for x in v))
488
+
489
+ notes = []
490
+ for i, onset in enumerate(events):
491
+ if onset["velocity"] == 0:
492
+ continue
493
+ offset = next(
494
+ n
495
+ for n in events[i + 1 :]
496
+ if (n["note"] == onset["note"] and n["channel"] == onset["channel"])
497
+ or n is events[-1]
498
+ )
499
+ if "sustain" not in offset:
500
+ print("offset without sustain", offset)
501
+ if offset["sustain"] and offset is not events[-1]:
502
+ # if the sustain pedal is active at offset, find when the sustain ends
503
+ offset = next(
504
+ n
505
+ for n in events[offset["index"] + 1 :]
506
+ if (n["type"] == "sustain_off" and n["channel"] == onset["channel"])
507
+ or n is events[-1]
508
+ )
509
+ for k, v in instruments.items():
510
+ if len(set(v)) == 1 and len(v) > 1:
511
+ instruments[k] = list(set(v))
512
+ for k, v in instruments.items():
513
+ instruments[k] = sorted(v, key=lambda x: x[1])
514
+ if len(instruments[onset["channel"]]) == 1:
515
+ instrument = instruments[onset["channel"]][0][0]
516
+ else:
517
+ ind = 0
518
+ while (
519
+ ind < len(instruments[onset["channel"]])
520
+ and onset["time"] >= instruments[onset["channel"]][ind][1]
521
+ ):
522
+ ind += 1
523
+ if ind > 0:
524
+ ind -= 1
525
+ instrument = instruments[onset["channel"]][ind][0]
526
+ if onset["channel"] == DRUM_CHANNEL:
527
+ print("skipping drum note")
528
+ continue
529
+ note = (
530
+ onset["time"],
531
+ offset["time"],
532
+ onset["note"],
533
+ onset["velocity"],
534
+ instrument,
535
+ )
536
+ notes.append(note)
537
+
538
+ res = np.array(notes)
539
+ return res
540
+
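Usage is straightforward; a sketch with a hypothetical path (the function returns None if the file cannot be opened):

from onsets_and_frames.midi_utils import parse_midi_multi

notes = parse_midi_multi("toy.mid")  # hypothetical path
if notes is not None and len(notes):
    print(notes[0])  # [onset_sec, offset_sec, pitch, velocity, program]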
541
+
542
+ def save_midi_alignments_and_predictions(
543
+ save_path,
544
+ data_path,
545
+ inst_mapping,
546
+ aligned_onsets,
547
+ aligned_frames,
548
+ onset_pred_np,
549
+ frame_pred_np,
550
+ prefix="",
551
+ use_time=True,
552
+ group=None,
553
+ ):
554
+ inst_only = len(inst_mapping) * N_KEYS
555
+ time_now = datetime.now().strftime("%y%m%d-%H%M%S") if use_time else ""
556
+ if len(prefix) > 0:
557
+ prefix = "_{}".format(prefix)
558
+
559
+ # Save the aligned label. If training on a small dataset, or on a single performance in order to label it and
560
+ # later add it to a larger dataset, it is recommended to use this MIDI as the label.
561
+ frames2midi(
562
+ save_path
563
+ + os.sep
564
+ + data_path.replace(".flac", "").split(os.sep)[-1]
565
+ + prefix
566
+ + "_alignment_"
567
+ + time_now
568
+ + ".mid",
569
+ aligned_onsets[:, :inst_only],
570
+ aligned_frames[:, :inst_only],
571
+ 64.0 * aligned_onsets[:, :inst_only],
572
+ inst_mapping=inst_mapping,
573
+ )
574
+ return  # NOTE: early return; only the aligned-label MIDI above is written, the prediction MIDIs below are skipped
575
+
576
+ # # Aligned label, pitch-only, on the piano.
577
+ # frames2midi_pitch(save_path + os.sep + data_path.replace('.flac', '').split(os.sep)[-1] + prefix + '_alignment_pitch_' + time_now + '.mid',
578
+ # aligned_onsets[:, -N_KEYS:], aligned_frames[:, -N_KEYS:],
579
+ # 64. * aligned_onsets[:, -N_KEYS:])
580
+
581
+ predicted_onsets = onset_pred_np >= 0.5
582
+ predicted_frames = frame_pred_np >= 0.5
583
+
584
+ # # Raw pitch with instrument prediction - will probably have lower recall, depending on the model's strength.
585
+ # frames2midi(save_path + os.sep + data_path.replace('.flac', '').split(os.sep)[-1] + prefix + '_pred_' + time_now + '.mid',
586
+ # predicted_onsets[:, : inst_only], predicted_frames[:, : inst_only],
587
+ # 64. * predicted_onsets[:, : inst_only],
588
+ # inst_mapping=inst_mapping)
589
+
590
+ # Pitch prediction played on the piano - will have high recall, since it does not differentiate between instruments.
591
+ frames2midi_pitch(
592
+ save_path
593
+ + os.sep
594
+ + data_path.replace(".flac", "").split(os.sep)[-1]
595
+ + prefix
596
+ + "_pred_pitch_"
597
+ + time_now
598
+ + ".mid",
599
+ predicted_onsets[:, -N_KEYS:],
600
+ predicted_frames[:, -N_KEYS:],
601
+ 64.0 * predicted_onsets[:, -N_KEYS:],
602
+ )
603
+
604
+ # Pitch prediction, with choice of most likely instrument for each detected note.
605
+ if len(inst_mapping) > 1:
606
+ max_pred_onsets = max_inst(onset_pred_np)
607
+ frames2midi(
608
+ save_path
609
+ + os.sep
610
+ + data_path.replace(".flac", "").split(os.sep)[-1]
611
+ + prefix
612
+ + "_pred_inst_"
613
+ + time_now
614
+ + ".mid",
615
+ max_pred_onsets[:, :inst_only],
616
+ predicted_frames[:, :inst_only],
617
+ 64.0 * max_pred_onsets[:, :inst_only],
618
+ inst_mapping=inst_mapping,
619
+ )
620
+
621
+ pseudo_onsets = (onset_pred_np >= 0.5) & (~aligned_onsets)
622
+ onset_label = np.maximum(pseudo_onsets, aligned_onsets)
623
+
624
+ pseudo_frames = np.zeros(pseudo_onsets.shape, dtype=pseudo_onsets.dtype)
625
+ for t, f in zip(*onset_label.nonzero()):
626
+ t_off = t
627
+ while t_off < len(pseudo_frames) and frame_pred_np[t_off, f % N_KEYS] >= 0.5:
628
+ t_off += 1
629
+ pseudo_frames[t:t_off, f] = 1
630
+ frame_label = np.maximum(pseudo_frames, aligned_frames)
631
+
632
+ # pseudo_frames = (frame_pred_np >= 0.5) & (~aligned_frames)
633
+ # frame_label = np.maximum(pseudo_frames, aligned_frames)
634
+
635
+ frames2midi(
636
+ save_path
637
+ + os.sep
638
+ + data_path.replace(".flac", "").split(os.sep)[-1]
639
+ + prefix
640
+ + "_pred_align_max_"
641
+ + time_now
642
+ + ".mid",
643
+ onset_label[:, :inst_only],
644
+ frame_label[:, :inst_only],
645
+ 64.0 * onset_label[:, :inst_only],
646
+ inst_mapping=inst_mapping,
647
+ )
648
+ # if group is not None:
649
+ # gorup_path = os.path.join(save_path, 'pred_alignment_max', group)
650
+ # file_name = os.path.basename(data_path).replace('.flac', '_pred_align_max.mid')
651
+ # os.makedirs(gorup_path, exist_ok=True)
652
+ # frames2midi(os.path.join(gorup_path, file_name),
653
+ # onset_label[:, : inst_only], frame_label[:, : inst_only],
654
+ # 64. * onset_label[:, : inst_only],
655
+ # inst_mapping=inst_mapping)
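A sketch of driving the writer above with toy arrays in place of real aligned labels and network output; because of the early return in the function, only the aligned-label MIDI is written. The paths and inst_mapping here are hypothetical.

import numpy as np

from onsets_and_frames.constants import N_KEYS
from onsets_and_frames.midi_utils import save_midi_alignments_and_predictions

inst_mapping = [0]                            # a single instrument (MIDI program 0)
T = 200
cols = (len(inst_mapping) + 1) * N_KEYS       # instrument channels plus the pitch-only channel
aligned_onsets = np.zeros((T, cols), dtype=bool)
aligned_frames = np.zeros((T, cols), dtype=bool)
aligned_onsets[10, 40] = True                 # one toy note
aligned_frames[10:30, 40] = True
onset_pred = np.zeros((T, cols), dtype=np.float32)
frame_pred = np.zeros((T, N_KEYS), dtype=np.float32)
save_midi_alignments_and_predictions(
    ".", "performance.flac", inst_mapping,    # "." and "performance.flac" are hypothetical paths
    aligned_onsets, aligned_frames, onset_pred, frame_pred, prefix="demo",
)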
onsets_and_frames/transcriber.py ADDED
@@ -0,0 +1,276 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ from onsets_and_frames.constants import MAX_MIDI, MIN_MIDI, N_KEYS
6
+
7
+ from .lstm import BiLSTM
8
+ from .mel import melspectrogram
9
+
10
+
11
+ class ConvStack(nn.Module):
12
+ def __init__(self, input_features, output_features):
13
+ super().__init__()
14
+
15
+ # input is batch_size * 1 channel * frames * input_features
16
+ self.cnn = nn.Sequential(
17
+ # layer 0
18
+ nn.Conv2d(1, output_features // 16, (3, 3), padding=1),
19
+ nn.BatchNorm2d(output_features // 16),
20
+ nn.ReLU(),
21
+ # layer 1
22
+ nn.Conv2d(output_features // 16, output_features // 16, (3, 3), padding=1),
23
+ nn.BatchNorm2d(output_features // 16),
24
+ nn.ReLU(),
25
+ # layer 2
26
+ nn.MaxPool2d((1, 2)),
27
+ nn.Dropout(0.25),
28
+ nn.Conv2d(output_features // 16, output_features // 8, (3, 3), padding=1),
29
+ nn.BatchNorm2d(output_features // 8),
30
+ nn.ReLU(),
31
+ # layer 3
32
+ nn.MaxPool2d((1, 2)),
33
+ nn.Dropout(0.25),
34
+ )
35
+ self.fc = nn.Sequential(
36
+ nn.Linear((output_features // 8) * (input_features // 4), output_features),
37
+ nn.Dropout(0.5),
38
+ )
39
+
40
+ def forward(self, mel):
41
+ x = mel.view(mel.size(0), 1, mel.size(1), mel.size(2))
42
+ x = self.cnn(x)
43
+ x = x.transpose(1, 2).flatten(-2)
44
+ x = self.fc(x)
45
+ return x
46
+
47
+
48
+ class OnsetsAndFrames(nn.Module):
49
+ def __init__(
50
+ self,
51
+ input_features,
52
+ output_features,
53
+ model_complexity=48,
54
+ onset_complexity=1,
55
+ n_instruments=13,
56
+ ):
57
+ nn.Module.__init__(self)
58
+ model_size = model_complexity * 16
59
+ sequence_model = lambda input_size, output_size: BiLSTM(
60
+ input_size, output_size // 2
61
+ )
62
+
63
+ onset_model_size = int(onset_complexity * model_size)
64
+ self.onset_stack = nn.Sequential(
65
+ ConvStack(input_features, onset_model_size),
66
+ sequence_model(onset_model_size, onset_model_size),
67
+ nn.Linear(onset_model_size, output_features * n_instruments),
68
+ nn.Sigmoid(),
69
+ )
70
+ self.offset_stack = nn.Sequential(
71
+ ConvStack(input_features, model_size),
72
+ sequence_model(model_size, model_size),
73
+ nn.Linear(model_size, output_features),
74
+ nn.Sigmoid(),
75
+ )
76
+ self.frame_stack = nn.Sequential(
77
+ ConvStack(input_features, model_size),
78
+ nn.Linear(model_size, output_features),
79
+ nn.Sigmoid(),
80
+ )
81
+ self.combined_stack = nn.Sequential(
82
+ sequence_model(output_features * 3, model_size),
83
+ nn.Linear(model_size, output_features),
84
+ nn.Sigmoid(),
85
+ )
86
+ self.velocity_stack = nn.Sequential(
87
+ ConvStack(input_features, model_size),
88
+ nn.Linear(model_size, output_features * n_instruments),
89
+ )
90
+
91
+ def forward(self, mel):
92
+ onset_pred = self.onset_stack(mel)
93
+ offset_pred = self.offset_stack(mel)
94
+ activation_pred = self.frame_stack(mel)
95
+
96
+ onset_detached = onset_pred.detach()
97
+ shape = onset_detached.shape
98
+ keys = MAX_MIDI - MIN_MIDI + 1
99
+ new_shape = shape[:-1] + (shape[-1] // keys, keys)
100
+ onset_detached = onset_detached.reshape(new_shape)
101
+ onset_detached, _ = onset_detached.max(axis=-2)
102
+
103
+ offset_detached = offset_pred.detach()
104
+
105
+ combined_pred = torch.cat(
106
+ [onset_detached, offset_detached, activation_pred], dim=-1
107
+ )
108
+ frame_pred = self.combined_stack(combined_pred)
109
+ velocity_pred = self.velocity_stack(mel)
110
+ return onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred
111
+
112
+ def run_on_batch(
113
+ self,
114
+ batch,
115
+ parallel_model=None,
116
+ positive_weight=2.0,
117
+ inv_positive_weight=2.0,
118
+ with_onset_mask=False,
119
+ ):
120
+ audio_label = batch["audio"]
121
+
122
+ onset_label = batch["onset"]
123
+ offset_label = batch["offset"]
124
+ frame_label = batch["frame"]
125
+ if "velocity" in batch:
126
+ velocity_label = batch["velocity"]
127
+ mel = melspectrogram(
128
+ audio_label.reshape(-1, audio_label.shape[-1])[:, :-1]
129
+ ).transpose(-1, -2)
130
+
131
+ if not parallel_model:
132
+ onset_pred, offset_pred, _, frame_pred, velocity_pred = self(mel)
133
+ else:
134
+ onset_pred, offset_pred, _, frame_pred, velocity_pred = parallel_model(mel)
135
+
136
+ predictions = {
137
+ "onset": onset_pred.reshape(*onset_label.shape),
138
+ "offset": offset_pred.reshape(*offset_label.shape),
139
+ "frame": frame_pred.reshape(*frame_label.shape),
140
+ # 'velocity': velocity_pred.reshape(*velocity_label.shape)
141
+ }
142
+
143
+ if "velocity" in batch:
144
+ predictions["velocity"] = velocity_pred.reshape(*velocity_label.shape)
145
+
146
+ losses = {
147
+ "loss/onset": F.binary_cross_entropy(
148
+ predictions["onset"], onset_label, reduction="none"
149
+ ),
150
+ "loss/offset": F.binary_cross_entropy(
151
+ predictions["offset"], offset_label, reduction="none"
152
+ ),
153
+ "loss/frame": F.binary_cross_entropy(
154
+ predictions["frame"], frame_label, reduction="none"
155
+ ),
156
+ # 'loss/velocity': self.velocity_loss(predictions['velocity'], velocity_label, onset_label)
157
+ }
158
+ if "velocity" in batch:
159
+ losses["loss/velocity"] = self.velocity_loss(
160
+ predictions["velocity"], velocity_label, onset_label
161
+ )
162
+
163
+ onset_mask = 1.0 * onset_label
164
+ onset_mask[..., :-N_KEYS] *= positive_weight - 1
165
+ onset_mask[..., -N_KEYS:] *= inv_positive_weight - 1
166
+ onset_mask += 1
167
+ if with_onset_mask:
168
+ if "onset_mask" in batch:
169
+ onset_mask = onset_mask * batch["onset_mask"]
170
+ # if 'onset_mask' in batch:
171
+ # onset_mask += batch['onset_mask']
172
+
173
+ offset_mask = 1.0 * offset_label
174
+ offset_positive_weight = 2.0
175
+ offset_mask *= offset_positive_weight - 1
176
+ offset_mask += 1.0
177
+
178
+ frame_mask = 1.0 * frame_label
179
+ frame_positive_weight = 2.0
180
+ frame_mask *= frame_positive_weight - 1
181
+ frame_mask += 1.0
182
+
183
+ for loss_key, mask in zip(
184
+ ["onset", "offset", "frame"], [onset_mask, offset_mask, frame_mask]
185
+ ):
186
+ losses["loss/" + loss_key] = (mask * losses["loss/" + loss_key]).mean()
187
+
188
+ return predictions, losses
189
+
190
+ def velocity_loss(self, velocity_pred, velocity_label, onset_label):
191
+ denominator = onset_label.sum()
192
+ if denominator.item() == 0:
193
+ return denominator
194
+ else:
195
+ return (
196
+ onset_label * (velocity_label - velocity_pred) ** 2
197
+ ).sum() / denominator
198
+
199
+
200
+ # same implementation as OnsetsAndFrames, but with only onset stack
201
+ class OnsetsNoFrames(nn.Module):
202
+ def __init__(
203
+ self,
204
+ input_features,
205
+ output_features,
206
+ model_complexity=48,
207
+ onset_complexity=1,
208
+ n_instruments=13,
209
+ ):
210
+ nn.Module.__init__(self)
211
+ model_size = model_complexity * 16
212
+ sequence_model = lambda input_size, output_size: BiLSTM(
213
+ input_size, output_size // 2
214
+ )
215
+
216
+ onset_model_size = int(onset_complexity * model_size)
217
+ self.onset_stack = nn.Sequential(
218
+ ConvStack(input_features, onset_model_size),
219
+ sequence_model(onset_model_size, onset_model_size),
220
+ nn.Linear(onset_model_size, output_features * n_instruments),
221
+ nn.Sigmoid(),
222
+ )
223
+
224
+ def forward(self, mel):
225
+ onset_pred = self.onset_stack(mel)
226
+
227
+ onset_detached = onset_pred.detach()
228
+ shape = onset_detached.shape
229
+ keys = MAX_MIDI - MIN_MIDI + 1
230
+ new_shape = shape[:-1] + (shape[-1] // keys, keys)
231
+ onset_detached = onset_detached.reshape(new_shape)
232
+ onset_detached, _ = onset_detached.max(axis=-2)
233
+
234
+ return onset_pred
235
+
236
+ def run_on_batch(
237
+ self,
238
+ batch,
239
+ parallel_model=None,
240
+ positive_weight=2.0,
241
+ inv_positive_weight=2.0,
242
+ with_onset_mask=False,
243
+ ):
244
+ audio_label = batch["audio"]
245
+
246
+ onset_label = batch["onset"]
247
+ mel = melspectrogram(
248
+ audio_label.reshape(-1, audio_label.shape[-1])[:, :-1]
249
+ ).transpose(-1, -2)
250
+
251
+ if not parallel_model:
252
+ onset_pred = self(mel)
253
+ else:
254
+ onset_pred = parallel_model(mel)
255
+
256
+ predictions = {
257
+ "onset": onset_pred,
258
+ }
259
+
260
+ losses = {
261
+ "loss/onset": F.binary_cross_entropy(
262
+ predictions["onset"], onset_label, reduction="none"
263
+ ),
264
+ }
265
+
266
+ onset_mask = 1.0 * onset_label
267
+ onset_mask[..., :-N_KEYS] *= positive_weight - 1
268
+ onset_mask[..., -N_KEYS:] *= inv_positive_weight - 1
269
+ onset_mask += 1
270
+ if with_onset_mask:
271
+ if "onset_mask" in batch:
272
+ onset_mask = onset_mask * batch["onset_mask"]
273
+
274
+ losses["loss/onset"] = (onset_mask * losses["loss/onset"]).mean()
275
+
276
+ return predictions, losses
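A shape sketch for the transcriber stacks above (toy batch and frame count; N_MELS and N_KEYS come from the repo's constants):

import torch

from onsets_and_frames.constants import N_KEYS, N_MELS
from onsets_and_frames.transcriber import OnsetsAndFrames

model = OnsetsAndFrames(N_MELS, N_KEYS, n_instruments=2)
mel = torch.randn(1, 100, N_MELS)   # (batch, frames, mel bins)
onset, offset, activation, frame, velocity = model(mel)
print(onset.shape)  # torch.Size([1, 100, 176]): N_KEYS * n_instruments onset columns
print(frame.shape)  # torch.Size([1, 100, 88]): pitch-only frame columns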
onsets_and_frames/utils.py ADDED
@@ -0,0 +1,245 @@
1
+ import logging
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from onsets_and_frames.constants import (
9
+ DTW_FACTOR,
10
+ HOP_LENGTH,
11
+ MAX_MIDI,
12
+ MIN_MIDI,
13
+ N_KEYS,
14
+ )
15
+
16
+
17
+ def cycle(iterable):
18
+ while True:
19
+ for item in iterable:
20
+ yield item
21
+
22
+
23
+ def shift_label(label, shift):
24
+ if shift == 0:
25
+ return label
26
+ assert len(label.shape) == 2
27
+ t, p = label.shape
28
+ keys, instruments = N_KEYS, p // N_KEYS
29
+ label_zero_pad = torch.zeros(t, instruments, abs(shift), dtype=label.dtype, device=label.device)  # match label's device
30
+ label = label.reshape(t, instruments, keys)
31
+ to_cat = (
32
+ (label_zero_pad, label[:, :, :-shift])
33
+ if shift > 0
34
+ else (label[:, :, -shift:], label_zero_pad)
35
+ )
36
+ label = torch.cat(to_cat, dim=-1)
37
+ return label.reshape(t, p)
38
+
39
+
40
+ def get_peaks(notes, win_size, gpu=False):
41
+ constraints = []
42
+ notes = notes.cpu()
43
+ for i in range(1, win_size + 1):
44
+ forward = torch.roll(notes, i, 0)
45
+ forward[:i, ...] = 0 # assume time axis is 0
46
+ backward = torch.roll(notes, -i, 0)
47
+ backward[-i:, ...] = 0
48
+ constraints.extend([forward, backward])
49
+ res = torch.ones(notes.shape, dtype=bool)
50
+ for elem in constraints:
51
+ res = res & (notes >= elem)
52
+ return res if not gpu else res.cuda()
53
+
54
+
55
+ def get_peaks_numpy(notes, win_size):
56
+ """
57
+ Detect peaks in a NumPy array based on a window size.
58
+
59
+ Args:
60
+ notes (np.ndarray): Input array, shape (frames, ...).
61
+ win_size (int): Window size for detecting peaks.
62
+
63
+ Returns:
64
+ np.ndarray: Boolean array indicating peaks, same shape as `notes`.
65
+ """
66
+ # Initialize constraints
67
+ constraints = []
68
+ notes = np.array(notes) # Ensure input is a NumPy array
69
+
70
+ for i in range(1, win_size + 1):
71
+ # Roll array forward and backward
72
+ forward = np.roll(notes, i, axis=0)
73
+ backward = np.roll(notes, -i, axis=0)
74
+
75
+ # Zero out invalid regions
76
+ forward[:i, ...] = 0
77
+ backward[-i:, ...] = 0
78
+
79
+ constraints.extend([forward, backward])
80
+
81
+ # Initialize result with all True
82
+ res = np.ones_like(notes, dtype=bool)
83
+
84
+ # Apply constraints
85
+ for elem in constraints:
86
+ res &= notes >= elem
87
+
88
+ return res
89
+
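A toy peak-detection example: a frame survives only if it is >= every neighbour within win_size frames on both sides.

import numpy as np

from onsets_and_frames.utils import get_peaks_numpy

x = np.array([[0.1], [0.9], [0.2], [0.3], [0.8], [0.1]])
print(get_peaks_numpy(x, win_size=1).ravel())
# [False  True False False  True False]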
90
+
91
+ def get_diff(notes, offset=True):
92
+ rolled = np.roll(notes, 1, axis=0)
93
+ rolled[0, ...] = 0
94
+ return (rolled & (~notes)) if offset else (notes & (~rolled))
95
+
96
+
97
+ def compress_across_octave(notes):
98
+ keys = MAX_MIDI - MIN_MIDI + 1
99
+ time, instruments = notes.shape[0], notes.shape[1] // keys
100
+ notes_reshaped = notes.reshape((time, instruments, keys))
101
+ notes_reshaped = notes_reshaped.max(axis=1)
102
+ octaves = keys // 12
103
+ res = np.zeros((time, 12), dtype=np.uint8)
104
+ for i in range(octaves):
105
+ curr_octave = notes_reshaped[:, i * 12 : (i + 1) * 12]
106
+ res = np.maximum(res, curr_octave)
107
+ return res
108
+
109
+
110
+ def compress_time(notes, factor):
111
+ t, p = notes.shape
112
+ res = np.zeros((t // factor, p), dtype=notes.dtype)
113
+ for i in range(t // factor):
114
+ res[i, :] = notes[i * factor : (i + 1) * factor, :].max(axis=0)
115
+ return res
116
+
117
+
118
+ def get_matches(index1, index2):
119
+ matches = {}
120
+ for i1, i2 in zip(index1, index2):
121
+ # matches[i1] = matches.get(i1, []) + [i2]
122
+ if i1 not in matches:
123
+ matches[i1] = []
124
+ matches[i1].append(i2)
125
+ return matches
126
+
127
+
128
+ """
129
+ Extend a temporal range to WINDOW_SIZE_SRC if it is shorter than that.
130
+ WINDOW_SIZE_SRC defaults to 28 frames for 256 hop length (assuming DTW_FACTOR=3), which is ~0.5 second.
131
+ """
132
+
133
+
134
+ def get_margin(
135
+ t_sources, max_len, WINDOW_SIZE_SRC=11 * (512 // HOP_LENGTH) + 2 * DTW_FACTOR
136
+ ):
137
+ margin = max(0, (WINDOW_SIZE_SRC - len(t_sources)) // 2)
138
+ t_sources_left = list(range(max(t_sources[0] - margin, 0), t_sources[0]))
139
+ t_sources_right = list(
140
+ range(t_sources[-1], min(t_sources[-1] + margin, max_len - 1))
141
+ )
142
+ t_sources_extended = t_sources_left + t_sources + t_sources_right
143
+ return t_sources_extended
144
+
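A numeric sketch of get_margin padding a 3-frame span up to a 9-frame window (note that the right padding starts at the last source frame, so that index is repeated):

from onsets_and_frames.utils import get_margin

print(get_margin([50, 51, 52], max_len=1000, WINDOW_SIZE_SRC=9))
# [47, 48, 49, 50, 51, 52, 52, 53, 54]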
145
+
146
+ def get_inactive_instruments(target_onsets, T):
147
+ keys = MAX_MIDI - MIN_MIDI + 1
148
+ time, instruments = target_onsets.shape[0], target_onsets.shape[1] // keys
149
+ notes_reshaped = target_onsets.reshape((time, instruments, keys))
150
+ active_instruments = notes_reshaped.max(axis=(0, 2))
151
+ res = np.zeros((T, instruments, keys), dtype=bool)
152
+ for ins in range(instruments):
153
+ if active_instruments[ins] == 0:
154
+ res[:, ins, :] = 1
155
+ return res.reshape((T, instruments * keys)), active_instruments
156
+
157
+
158
+ def max_inst(probs, threshold_vec=None):
159
+ if threshold_vec is None:
160
+ threshold_vec = 0.5
161
+ if probs.shape[-1] == N_KEYS or probs.shape[-1] == N_KEYS * 2:
162
+ # there is only pitch
163
+ return probs
164
+ keys = MAX_MIDI - MIN_MIDI + 1
165
+ instruments = probs.shape[1] // keys
166
+ time = len(probs)
167
+ probs = probs.reshape((time, instruments, keys))
168
+ notes = probs.max(axis=1) >= threshold_vec
169
+ max_instruments = np.argmax(probs[:, :-1, :], axis=1)
170
+ res = np.zeros(probs.shape, dtype=np.uint8)
171
+ for t, p in zip(*(notes.nonzero())):
172
+ res[t, max_instruments[t, p], p] = 1
173
+ res[t, -1, p] = 1
174
+ return res.reshape((time, instruments * keys))
175
+
176
+
177
+ # Define the smoothing function (operates on CPU)
178
+ def smooth_labels(onset_tensor):
179
+ """
180
+ Smooths onset labels using a triangular kernel with 1D convolution along the time axis.
181
+
182
+ Args:
183
+ onset_tensor (torch.Tensor): A (T, F) tensor where T = time steps and F = pitches.
184
+
185
+ Returns:
186
+ torch.Tensor: Smoothed onset tensor with the same shape (T, F).
187
+ """
188
+ # Define the triangular smoothing kernel
189
+ # kernel = torch.tensor([0.2, 0.4, 0.6, 0.8, 1, 0.8, 0.6, 0.4, 0.2],
190
+ # dtype=onset_tensor.dtype).view(1, 1, -1)
191
+ # kernel = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
192
+ # dtype=onset_tensor.dtype).view(1, 1, -1)
193
+ kernel = torch.tensor([0.33, 0.67, 1, 0.67, 0.33], dtype=onset_tensor.dtype).view(
194
+ 1, 1, -1
195
+ )
196
+
197
+ onset_tensor = onset_tensor.T.unsqueeze(1) # Now shape is (F, 1, T)
198
+
199
+ # Use 'same' padding so that the output has the same time dimension as the input.
200
+ padding = kernel.shape[-1] // 2
201
+ smoothed = F.conv1d(onset_tensor, kernel, padding=padding)
202
+
203
+ # Reshape back to original shape (T, F)
204
+ return smoothed.squeeze(1).T
205
+
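Smoothing a single impulse onset with the triangular kernel above spreads it over two frames on each side:

import torch

from onsets_and_frames.utils import smooth_labels

onsets = torch.zeros(9, 1)
onsets[4, 0] = 1.0
print(smooth_labels(onsets).squeeze())
# ≈ [0, 0, 0.33, 0.67, 1.00, 0.67, 0.33, 0, 0]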
206
+
207
+ def initialize_logging_system(logdir):
208
+ """Initialize the logging system once with named loggers for train and dataset."""
209
+ log_file = os.path.join(logdir, "training.log")
210
+
211
+ # Create formatter
212
+ formatter = logging.Formatter(
213
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
214
+ )
215
+
216
+ # File handler (shared by all loggers)
217
+ file_handler = logging.FileHandler(log_file)
218
+ file_handler.setLevel(logging.INFO)
219
+ file_handler.setFormatter(formatter)
220
+
221
+ # Console handler (shared by all loggers)
222
+ console_handler = logging.StreamHandler()
223
+ console_handler.setLevel(logging.INFO)
224
+ console_handler.setFormatter(formatter)
225
+
226
+ # Create train logger
227
+ train_logger = logging.getLogger("train")
228
+ train_logger.setLevel(logging.INFO)
229
+ train_logger.handlers.clear()
230
+ train_logger.addHandler(file_handler)
231
+ train_logger.addHandler(console_handler)
232
+
233
+ # Create dataset logger
234
+ dataset_logger = logging.getLogger("dataset")
235
+ dataset_logger.setLevel(logging.INFO)
236
+ dataset_logger.handlers.clear()
237
+ dataset_logger.addHandler(file_handler)
238
+ dataset_logger.addHandler(console_handler)
239
+
240
+ return train_logger, dataset_logger
241
+
242
+
243
+ def get_logger(name):
244
+ """Get a named logger. Call initialize_logging_system first."""
245
+ return logging.getLogger(name)
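Typical usage of the logging helpers; the log directory below is a hypothetical example and must already exist, since initialize_logging_system opens <logdir>/training.log:

from onsets_and_frames.utils import get_logger, initialize_logging_system

train_logger, dataset_logger = initialize_logging_system("runs/example")  # hypothetical, pre-created dir
get_logger("train").info("training started")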