added source code of model and transcription scripts
Files changed:
- onsets_and_frames/__init__.py +5 -0
- onsets_and_frames/constants.py +26 -0
- onsets_and_frames/dataset.py +719 -0
- onsets_and_frames/decoding.py +102 -0
- onsets_and_frames/hf_model.py +364 -0
- onsets_and_frames/lstm.py +96 -0
- onsets_and_frames/mel.py +136 -0
- onsets_and_frames/midi_utils.py +655 -0
- onsets_and_frames/transcriber.py +276 -0
- onsets_and_frames/utils.py +245 -0
onsets_and_frames/__init__.py
ADDED
from .constants import *
from .dataset import EMDATASET
from .mel import melspectrogram
from .transcriber import OnsetsAndFrames, OnsetsNoFrames
from .utils import *
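These re-exports make the core API importable from the package root. As a quick orientation, a minimal sketch (not part of the commit; the constructor arguments mirror the call made in hf_model.py below):

```python
from onsets_and_frames import OnsetsAndFrames
from onsets_and_frames.constants import N_KEYS, N_MELS

# Build the transcriber with the defaults hf_model.py uses.
model = OnsetsAndFrames(
    input_features=N_MELS,    # 229 mel bins
    output_features=N_KEYS,   # 88 piano keys
    model_complexity=64,
    onset_complexity=1.5,
    n_instruments=1,
)
```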
onsets_and_frames/constants.py
ADDED
import torch


SAMPLE_RATE = 16000
HOP_LENGTH = 512
ONSET_LENGTH = HOP_LENGTH
OFFSET_LENGTH = HOP_LENGTH

HOPS_IN_ONSET = ONSET_LENGTH // HOP_LENGTH
HOPS_IN_OFFSET = OFFSET_LENGTH // HOP_LENGTH
MIN_MIDI = 21
MAX_MIDI = 108
N_KEYS = MAX_MIDI - MIN_MIDI + 1

DTW_FACTOR = 3

N_MELS = 229
MEL_FMIN = 30
MEL_FMAX = SAMPLE_RATE // 2
WINDOW_LENGTH = 2048

SEQ_LEN = 327680  # 640 frames * 512 hop = 20.48 seconds at 16 kHz

DRUM_CHANNEL = 9

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
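The timing relationships implied by these constants can be sanity-checked directly. An illustrative sketch, not part of the commit:

```python
SAMPLE_RATE = 16000
HOP_LENGTH = 512
SEQ_LEN = 327680

frames_per_second = SAMPLE_RATE / HOP_LENGTH  # 31.25 label frames per second
n_frames = SEQ_LEN // HOP_LENGTH              # 640 frames per training sequence
seconds = SEQ_LEN / SAMPLE_RATE               # 20.48 seconds of audio
assert SEQ_LEN % HOP_LENGTH == 0              # sequences align to hop boundaries
print(frames_per_second, n_frames, seconds)
```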
onsets_and_frames/dataset.py
ADDED
import os
import random
import sys
import time

import librosa
import numpy as np
import soundfile
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from onsets_and_frames import constants
from onsets_and_frames.constants import DEFAULT_DEVICE, N_KEYS, SAMPLE_RATE
from onsets_and_frames.mel import melspectrogram
from onsets_and_frames.midi_utils import (
    midi_to_frames,
    save_midi_alignments_and_predictions,
)
from onsets_and_frames.utils import (
    get_diff,
    get_logger,
    get_peaks,
    shift_label,
    smooth_labels,
)


class EMDATASET(Dataset):
    def __init__(
        self,
        audio_path="NoteEM_audio",
        tsv_path="NoteEM_tsv",
        labels_path="NoteEm_labels",
        groups=None,
        sequence_length=None,
        seed=42,
        device=DEFAULT_DEVICE,
        instrument_map=None,
        update_instruments=False,
        transcriber=None,
        conversion_map=None,
        pitch_shift=True,
        pitch_shift_limit=5,
        keep_eval_files=False,
        n_eval=1,
        evaluation_list=None,
        only_eval=False,
        save_to_memory=False,
        smooth_labels=False,
        use_onset_mask=False,
    ):
        # Get the dataset logger (the logging system should already be initialized by train.py)
        self.logger = get_logger("dataset")

        self.audio_path = audio_path
        self.tsv_path = tsv_path
        self.labels_path = labels_path
        self.sequence_length = sequence_length
        self.device = device
        self.random = np.random.RandomState(seed)
        self.groups = groups
        self.conversion_map = conversion_map
        self.eval_file_list = []
        self.file_list = self.files(
            self.groups,
            pitch_shift=pitch_shift,
            keep_eval_files=keep_eval_files,
            n_eval=n_eval,
            evaluation_list=evaluation_list,
            pitch_shift_limit=pitch_shift_limit,
        )
        self.save_to_memory = save_to_memory
        self.smooth_labels = smooth_labels
        self.use_onset_mask = use_onset_mask
        self.pitch_shift_limit = pitch_shift_limit

        self.logger.debug("Save to memory is %s", self.save_to_memory)
        self.logger.info("len file list %d", len(self.file_list))
        self.logger.info("\n\n")

        if instrument_map is None:
            self.get_instruments(conversion_map=conversion_map)
        else:
            self.instruments = instrument_map
            if update_instruments:
                self.add_instruments()
        self.transcriber = transcriber
        if only_eval:
            return
        self.load_pts(self.file_list)
        self.data = []
        self.logger.info("Reading files...")
        for input_files in tqdm(self.file_list, desc="creating data list"):
            flac, _ = input_files
            audio_len = librosa.get_duration(path=flac)
            # Add one copy of the file per minute of audio so that long
            # recordings are sampled proportionally often.
            minutes = int(np.ceil(audio_len / 60))
            copies = minutes
            for _ in range(copies):
                self.data.append(input_files)
        random.shuffle(self.data)

    def flac_to_pt_path(self, flac):
        pt_fname = os.path.basename(flac).replace(".flac", ".pt")
        pt_path = os.path.join(self.labels_path, pt_fname)
        return pt_path

    def __len__(self):
        return len(self.data)

    def files(
        self,
        groups,
        pitch_shift=True,
        keep_eval_files=False,
        n_eval=1,
        evaluation_list=None,
        pitch_shift_limit=5,
    ):
        self.path = self.audio_path
        tsvs_path = self.tsv_path
        self.logger.info("tsv path: %s", tsvs_path)
        self.logger.info("Evaluation list: %s", evaluation_list)
        res = []
        self.logger.info("keep eval files: %s", keep_eval_files)
        self.logger.info("n eval: %d", n_eval)
        for group in groups:
            tsvs = os.listdir(tsvs_path + os.sep + group)
            tsvs = sorted(tsvs)
            if keep_eval_files and evaluation_list is None:
                eval_tsvs = tsvs[:n_eval]
                tsvs = tsvs[n_eval:]
            elif keep_eval_files and evaluation_list is not None:
                eval_tsvs_names = [
                    i.split("#")[0].split(".flac")[0].split(".tsv")[0]
                    for i in evaluation_list
                ]
                eval_tsvs = [
                    i
                    for i in tsvs
                    if i.split("#")[0].split(".tsv")[0] in eval_tsvs_names
                ]
                tsvs = [i for i in tsvs if i not in eval_tsvs]
            else:
                eval_tsvs = []
            self.logger.info("len tsvs: %d", len(tsvs))

            tsvs_names = [t.split(".tsv")[0].split("#")[0] for t in tsvs]
            eval_tsvs_names = [t.split(".tsv")[0].split("#")[0] for t in eval_tsvs]
            for shft in range(-5, 6):
                if (shft != 0 and not pitch_shift) or abs(shft) > pitch_shift_limit:
                    continue
                curr_fls_pth = self.path + os.sep + group + "#{}".format(shft)

                fls = os.listdir(curr_fls_pth)
                orig_files = fls
                fls = [
                    i for i in fls if i.split("#")[0] in tsvs_names
                ]  # in case we don't have the corresponding MIDI
                missing_fls = [i for i in orig_files if i not in fls]
                if len(missing_fls) > 0:
                    self.logger.warning("missing files: %s", missing_fls)
                fls_names = [i.split("#")[0].split(".flac")[0] for i in fls]
                tsvs = [
                    i for i in tsvs if i.split(".tsv")[0].split("#")[0] in fls_names
                ]
                assert len(tsvs) == len(fls)
                fls = sorted(fls)

                if shft == 0:
                    eval_fls = os.listdir(curr_fls_pth)
                    eval_fls = [
                        i for i in eval_fls if i.split("#")[0] in eval_tsvs_names
                    ]  # in case we don't have the corresponding MIDI
                    eval_fls_names = [i.split("#")[0] for i in eval_fls]
                    eval_tsvs = [
                        i
                        for i in eval_tsvs
                        if i.split(".tsv")[0].split("#")[0] in eval_fls_names
                    ]
                    assert len(eval_fls_names) == len(eval_tsvs_names)
                    eval_fls = sorted(eval_fls)
                    for f, t in zip(eval_fls, eval_tsvs):
                        self.eval_file_list.append(
                            (
                                curr_fls_pth + os.sep + f,
                                tsvs_path + os.sep + group + os.sep + t,
                            )
                        )

                for f, t in zip(fls, tsvs):
                    res.append(
                        (
                            curr_fls_pth + os.sep + f,
                            tsvs_path + os.sep + group + os.sep + t,
                        )
                    )

        for flac, tsv in res:
            if (
                os.path.basename(flac).split("#")[0].split(".flac")[0]
                != os.path.basename(tsv).split("#")[0].split(".tsv")[0]
            ):
                self.logger.warning("found mismatch in the files: ")
                self.logger.warning("flac: %s", os.path.basename(flac).split("#")[0])
                self.logger.warning("tsv: %s", os.path.basename(tsv).split("#")[0])
                self.logger.warning("please check the input files")
                exit(1)
        return res

    def get_instruments(self, conversion_map=None):
        instruments = set()
        for _, f in self.file_list:
            events = np.loadtxt(f, delimiter="\t", skiprows=1)
            curr_instruments = set(events[:, -1])
            if conversion_map is not None:
                curr_instruments = {
                    conversion_map[c] if c in conversion_map else c
                    for c in curr_instruments
                }
            instruments = instruments.union(curr_instruments)
        instruments = [int(elem) for elem in instruments if elem < 115]
        if conversion_map is not None:
            instruments = [i for i in instruments if i in conversion_map]
        instruments = list(set(instruments))
        if 0 in instruments:
            # Move piano (program 0) to the front of the list.
            piano_ind = instruments.index(0)
            instruments.pop(piano_ind)
            instruments.insert(0, 0)
        self.instruments = instruments
        self.instruments = list(
            set(self.instruments) - set(range(88, 104)) - set(range(112, 150))
        )
        self.logger.info("Dataset instruments: %s", self.instruments)
        self.logger.info("Total: %d instruments", len(self.instruments))

    def add_instruments(self):
        for _, f in self.file_list:
            events = np.loadtxt(f, delimiter="\t", skiprows=1)
            curr_instruments = set(events[:, -1])
            new_instruments = curr_instruments - set(self.instruments)
            self.instruments += list(new_instruments)
        instruments = [int(elem) for elem in self.instruments if elem < 115]
        self.instruments = instruments

    def __getitem__(self, index):
        data = self.load(*self.data[index])
        midi_length = len(data["label"])
        n_steps = self.sequence_length // constants.HOP_LENGTH
        if midi_length < n_steps:
            step_begin = 0
            step_end = midi_length
        else:
            step_begin = self.random.randint(max(midi_length - n_steps, 1))
            step_end = step_begin + n_steps
        begin = step_begin * constants.HOP_LENGTH
        end = begin + self.sequence_length

        audio = (
            data["audio"][begin:end].float().div_(32768.0)
        )  # torch.ShortTensor → float
        label = data["label"][step_begin:step_end].clone()  # torch.Tensor

        if audio.shape[0] < self.sequence_length:
            pad_amt = self.sequence_length - audio.shape[0]
            audio = torch.cat([audio, torch.zeros(pad_amt, dtype=audio.dtype)], dim=0)

        if label.shape[0] < n_steps:
            pad_amt = n_steps - label.shape[0]
            label = torch.cat(
                [label, torch.zeros((pad_amt, *label.shape[1:]), dtype=label.dtype)],
                dim=0,
            )

        audio = torch.clamp(audio, -1.0, 1.0)
        result = {"path": data["path"], "audio": audio, "label": label}
        if "velocity" in data:
            result["velocity"] = data["velocity"][step_begin:step_end, ...]
            result["velocity"] = result["velocity"].float() / 128.0

        # Label encoding: 3 = onset, 2 = frame, 1 = offset, 0 = silence.
        if result["label"].max() < 3:
            result["onset"] = result["label"].float()
        else:
            result["onset"] = (result["label"] == 3).float()

        result["offset"] = (result["label"] == 1).float()
        result["frame"] = (result["label"] > 1).float()

        if self.smooth_labels:
            result["onset"] = smooth_labels(result["onset"])
        if self.use_onset_mask:
            if "onset_mask" in data:
                result["onset_mask"] = data["onset_mask"][
                    step_begin:step_end, ...
                ].float()
            else:
                result["onset_mask"] = torch.ones_like(result["onset"]).float()
            if "frame_mask" in data:
                result["frame_mask"] = data["frame_mask"][
                    step_begin:step_end, ...
                ].float()
            else:
                result["frame_mask"] = torch.ones_like(result["frame"]).float()

        # Collapse per-instrument classes into a pitch-only view, keeping the
        # full per-instrument rolls under the "big_" keys.
        shape = result["frame"].shape
        keys = N_KEYS
        new_shape = shape[:-1] + (shape[-1] // keys, keys)
        result["big_frame"] = result["frame"]
        result["frame"], _ = result["frame"].reshape(new_shape).max(axis=-2)

        result["big_offset"] = result["offset"]
        result["offset"], _ = result["offset"].reshape(new_shape).max(axis=-2)
        result["group"] = self.data[index][0].split(os.sep)[-2].split("#")[0]

        return result

    def load(self, audio_path, tsv_path):
        if self.save_to_memory:
            data = self.pts[audio_path]
        else:
            data = torch.load(self.flac_to_pt_path(audio_path))
        if len(data["audio"].shape) > 1:
            data["audio"] = (data["audio"].float().mean(dim=-1)).short()
        if "label" in data:
            return data
        else:
            # Pitch-shifted copies store only audio; derive their labels by
            # shifting the labels of the unshifted ("#0") original.
            piece, part = audio_path.split(os.sep)[-2:]
            piece_split = piece.split("#")
            if len(piece_split) == 2:
                piece, shift1 = piece_split
            else:
                piece, shift1 = "#".join(piece_split[:2]), piece_split[-1]
            part_split = part.split("#")
            if len(part_split) == 2:
                part, shift2 = part_split
            else:
                part, shift2 = "#".join(part_split[:2]), part_split[-1]
            shift2, _ = shift2.split(".")
            assert shift1 == shift2
            shift = shift1
            # shift is parsed from the filename as a string; compare its value
            assert int(shift) != 0
            orig = audio_path.replace("#{}".format(shift), "#0")
            if self.save_to_memory:
                orig_data = self.pts[orig]
            else:
                orig_data = torch.load(self.flac_to_pt_path(orig))
            res = {}
            res["label"] = shift_label(orig_data["label"], int(shift))
            res["path"] = audio_path
            res["audio"] = data["audio"]
            if "velocity" in orig_data:
                res["velocity"] = shift_label(orig_data["velocity"], int(shift))
            if "onset_mask" in orig_data:
                res["onset_mask"] = shift_label(orig_data["onset_mask"], int(shift))
            if "frame_mask" in orig_data:
                res["frame_mask"] = shift_label(orig_data["frame_mask"], int(shift))
            return res

    def load_pts(self, files):
        self.pts = {}
        self.logger.info("loading pts...")
        for flac, tsv in tqdm(files, desc="loading pts"):
            if os.path.isfile(
                self.labels_path
                + os.sep
                + flac.split(os.sep)[-1].replace(".flac", ".pt")
            ):
                if self.save_to_memory:
                    self.pts[flac] = torch.load(
                        self.labels_path
                        + os.sep
                        + flac.split(os.sep)[-1].replace(".flac", ".pt")
                    )
            else:
                if flac.count("#") != 2:
                    self.logger.debug("unexpected number of '#' in filename: %s", flac)
                audio, sr = soundfile.read(flac, dtype="int16")
                if len(audio.shape) == 2:
                    audio = audio.astype(float).mean(axis=1)
                else:
                    audio = audio.astype(float)
                audio = audio.astype(np.int16)
                self.logger.debug("audio len: %d", len(audio))
                assert sr == SAMPLE_RATE
                audio = torch.ShortTensor(audio)
                if "#0" not in flac:
                    # Pitch-shifted copy: store only the audio; labels are
                    # derived from the "#0" original in load().
                    assert "#" in flac
                    data = {"audio": audio}
                    if self.save_to_memory:
                        self.pts[flac] = data
                    torch.save(data, self.flac_to_pt_path(flac))
                    continue
                midi = np.loadtxt(tsv, delimiter="\t", skiprows=1)
                unaligned_label = midi_to_frames(
                    midi, self.instruments, conversion_map=self.conversion_map
                )
                if len(self.instruments) == 1:
                    unaligned_label = unaligned_label[:, -N_KEYS:]
                if len(unaligned_label) < self.sequence_length // constants.HOP_LENGTH:
                    diff = self.sequence_length // constants.HOP_LENGTH - len(
                        unaligned_label
                    )
                    pad = torch.zeros(
                        (diff, unaligned_label.shape[1]), dtype=unaligned_label.dtype
                    )
                    unaligned_label = torch.cat((unaligned_label, pad), dim=0)

                group = flac.split(os.sep)[-2].split("#")[0]
                data = dict(
                    path=self.labels_path + os.sep + flac.split(os.sep)[-1],
                    audio=audio,
                    unaligned_label=unaligned_label,
                    group=group,
                    BON=float("inf"),
                    BON_VEC=np.full(unaligned_label.shape[1], float("inf")),
                )

                torch.save(data, self.flac_to_pt_path(flac))
                if self.save_to_memory:
                    self.pts[flac] = data

    def update_pts_counting(
        self,
        transcriber,
        counting_window_length,
        POS=1.1,
        NEG=-0.001,
        FRAME_POS=0.5,
        to_save=None,
        first=False,
        update=True,
        BEST_DIST=False,
        peak_size=3,
        BEST_DIST_VEC=False,
        counting_window_hop=0,
    ):
        self.logger.info("Updating pts...")
        self.logger.info("First %s", first)
        total_counting_time = 0.0  # total time spent on counting-based alignment

        self.logger.info("POS, NEG: %s, %s", POS, NEG)
        if to_save is not None:
            os.makedirs(to_save, exist_ok=True)
        self.logger.info("There are %d pts", len(self.pts))
        update_count = 0
        sys.stdout.flush()
        only_pitch_0_files = [f for f in self.file_list if "#0" in f[0]]
        for input_files in tqdm(only_pitch_0_files, desc="updating pts"):
            flac, tsv = input_files
            data = torch.load(self.flac_to_pt_path(flac))
            if "unaligned_label" not in data:
                self.logger.warning("No unaligned labels for %s", flac)
                continue
            audio_inp = data["audio"].float() / 32768.0
            MAX_TIME = 5 * 60 * SAMPLE_RATE
            audio_inp_len = len(audio_inp)
            if audio_inp_len > MAX_TIME:
                n_segments = int(np.ceil(audio_inp_len / MAX_TIME))
                self.logger.debug("Long audio, splitting to %d segments", n_segments)
                seg_len = MAX_TIME
                onsets_preds = []
                offset_preds = []
                frame_preds = []
                for i_s in range(n_segments):
                    curr = (
                        audio_inp[i_s * seg_len : (i_s + 1) * seg_len]
                        .unsqueeze(0)
                        .cuda()
                    )
                    curr_mel = melspectrogram(
                        curr.reshape(-1, curr.shape[-1])[:, :-1]
                    ).transpose(-1, -2)
                    (
                        curr_onset_pred,
                        curr_offset_pred,
                        _,
                        curr_frame_pred,
                        curr_velocity_pred,
                    ) = transcriber(curr_mel)
                    onsets_preds.append(curr_onset_pred)
                    offset_preds.append(curr_offset_pred)
                    frame_preds.append(curr_frame_pred)
                onset_pred = torch.cat(onsets_preds, dim=1)
                offset_pred = torch.cat(offset_preds, dim=1)
                frame_pred = torch.cat(frame_preds, dim=1)
            else:
                audio_inp = audio_inp.unsqueeze(0).cuda()
                mel = melspectrogram(
                    audio_inp.reshape(-1, audio_inp.shape[-1])[:, :-1]
                ).transpose(-1, -2)
                onset_pred, offset_pred, _, frame_pred, _ = transcriber(mel)
            self.logger.debug("Done predicting.")

            # We assume onset predictions have N_KEYS * (len(instruments) + 1)
            # classes: the first N_KEYS classes belong to the first instrument,
            # the next N_KEYS to the next instrument, etc., and the last N_KEYS
            # classes are pitch regardless of instrument. Currently, frame and
            # offset predictions are only N_KEYS classes.
            onset_pred = onset_pred.detach().squeeze().cpu()
            frame_pred = frame_pred.detach().squeeze().cpu()

            PEAK_SIZE = peak_size
            self.logger.debug("PEAK_SIZE: %d", PEAK_SIZE)
            # Peak-pick the onset predictions to keep only local maxima.
            if peak_size > 0:
                peaks = get_peaks(
                    onset_pred, PEAK_SIZE
                )  # keep only local peaks, in a 7-frame neighborhood, 3 to each side
                onset_pred[~peaks] = 0

            unaligned_onsets = (data["unaligned_label"] == 3).float().numpy()

            onset_pred_np = onset_pred.numpy()
            frame_pred_np = frame_pred.numpy()

            pred_bag_of_notes = (onset_pred_np[:, -N_KEYS:] >= 0.5).sum(axis=0)
            gt_bag_of_notes = unaligned_onsets[:, -N_KEYS:].astype(bool).sum(axis=0)
            bon_dist = (((pred_bag_of_notes - gt_bag_of_notes) ** 2).sum()) ** 0.5

            pred_bag_of_notes_with_inst = (onset_pred_np >= 0.5).sum(axis=0)
            gt_bag_of_notes_with_inst = unaligned_onsets.astype(bool).sum(axis=0)
            bon_dist_vec = np.abs(
                pred_bag_of_notes_with_inst - gt_bag_of_notes_with_inst
            )

            bon_dist /= gt_bag_of_notes.sum()
            self.logger.debug("bag of notes dist: %f", bon_dist)

            aligned_onsets = np.zeros(onset_pred_np.shape, dtype=bool)
            aligned_frames = np.zeros(onset_pred_np.shape, dtype=bool)

            # This block is the main difference between the counting approach
            # and the DTW approach. In the counting approach we label the audio
            # by counting note onsets: for each onset pitch class, denote by K
            # the number of times it occurs in the unaligned label. We simply
            # take the K highest local peaks predicted by the current model.
            # Split unaligned onsets into chunks of size counting_window_length.
            self.logger.debug(
                "unaligned onsets shape: %s, counting window length: %d, counting window hop: %d",
                unaligned_onsets.shape,
                counting_window_length,
                counting_window_hop,
            )
            assert counting_window_hop <= counting_window_length
            if counting_window_hop == 0:
                counting_window_hop = counting_window_length

            num_chunks = (
                1
                if counting_window_length == 0
                else int(np.ceil(len(unaligned_onsets) / counting_window_hop))
            )

            self.logger.debug("number of chunks: %d", num_chunks)
            start_time = time.time()
            for chunk_idx in range(num_chunks):
                start_idx = chunk_idx * counting_window_hop
                if counting_window_length == 0:
                    end_idx = max(len(unaligned_onsets), len(onset_pred_np))
                else:
                    end_idx = min(
                        start_idx + counting_window_length, len(unaligned_onsets)
                    )
                chunk_onsets = unaligned_onsets[start_idx:end_idx]
                chunk_onsets_count = (
                    (data["unaligned_label"][start_idx:end_idx, :] == 3)
                    .sum(dim=0)
                    .numpy()
                )

                for f, f_count in enumerate(chunk_onsets_count):
                    if f_count == 0:
                        continue
                    f_most_likely = np.sort(
                        onset_pred_np[start_idx:end_idx, f].argsort()[::-1][:f_count]
                    )
                    f_most_likely += start_idx  # adjust indices to the original size
                    aligned_onsets[f_most_likely, f] = 1

                    f_unaligned = chunk_onsets[:, f].nonzero()
                    assert len(f_unaligned) == 1  # nonzero() returns a 1-tuple for a 1-D array
                    f_unaligned = f_unaligned[0]

            counting_duration = time.time() - start_time
            total_counting_time += counting_duration
            self.logger.debug(
                "Counting alignment for file '%s' took %.2f seconds.",
                flac,
                counting_duration,
            )

            # Pseudo labels; POS greater than 1 is equivalent to not using pseudo labels.
            pseudo_onsets = (onset_pred_np >= POS) & (~aligned_onsets)

            onset_label = np.maximum(pseudo_onsets, aligned_onsets)

            # In this project we do not train the frame stack, but we compute the labels anyway.
            pseudo_frames = np.zeros(pseudo_onsets.shape, dtype=pseudo_onsets.dtype)
            pseudo_offsets = np.zeros(pseudo_onsets.shape, dtype=pseudo_onsets.dtype)
            for t, f in zip(*onset_label.nonzero()):
                t_off = t
                while (
                    t_off < len(pseudo_frames)
                    and frame_pred[t_off, f % N_KEYS] >= FRAME_POS
                ):
                    t_off += 1
                pseudo_frames[t:t_off, f] = 1
                if t_off < len(pseudo_offsets):
                    pseudo_offsets[t_off, f] = 1
            frame_label = np.maximum(pseudo_frames, aligned_frames)
            offset_label = get_diff(frame_label, offset=True)

            label = np.maximum(2 * frame_label, offset_label)
            label = np.maximum(3 * onset_label, label).astype(np.uint8)

            if to_save is not None:
                save_midi_alignments_and_predictions(
                    to_save,
                    data["path"],
                    self.instruments,
                    aligned_onsets,
                    aligned_frames,
                    onset_pred_np,
                    frame_pred_np,
                    prefix="",
                    group=data["group"],
                )
            prev_bon_dist = data.get("BON", float("inf"))
            prev_bon_dist_vec = data.get("BON_VEC", None)
            if update:
                if BEST_DIST_VEC:
                    self.logger.debug("Updated Labels")
                    if prev_bon_dist_vec is None:
                        raise ValueError(
                            "BEST_DIST_VEC is True but no previous BON_VEC found"
                        )
                    prev_label = data["label"]
                    new_label = torch.from_numpy(label).byte()
                    if first:
                        prev_label = new_label
                        update_count += 1
                    else:
                        updated_flag = False
                        num_pitches_updated = 0
                        for k in range(prev_label.shape[1]):
                            if prev_bon_dist_vec[k] > bon_dist_vec[k]:
                                prev_label[:, k] = new_label[:, k]
                                prev_bon_dist_vec[k] = bon_dist_vec[k]
                                num_pitches_updated += 1
                                updated_flag = True
                        if updated_flag:
                            update_count += 1
                        self.logger.debug("Updated %d pitches", num_pitches_updated)
                    data["label"] = prev_label
                    data["BON_VEC"] = prev_bon_dist_vec
                    self.logger.debug("saved updated pt")
                    torch.save(
                        data,
                        self.labels_path
                        + os.sep
                        + flac.split(os.sep)[-1]
                        .replace(".flac", ".pt")
                        .replace(".mp3", ".pt"),
                    )

                elif not BEST_DIST or bon_dist < prev_bon_dist:
                    update_count += 1
                    self.logger.debug("Updated Labels")

                    data["label"] = torch.from_numpy(label).byte()

                    data["BON"] = bon_dist
                    self.logger.debug("saved updated pt")
                    torch.save(
                        data,
                        self.labels_path
                        + os.sep
                        + flac.split(os.sep)[-1]
                        .replace(".flac", ".pt")
                        .replace(".mp3", ".pt"),
                    )

            if bon_dist < prev_bon_dist:
                self.logger.debug(
                    "Bag of notes distance improved from %f to %f",
                    prev_bon_dist,
                    bon_dist,
                )
                data["BON"] = bon_dist

                if to_save is not None and BEST_DIST:
                    os.makedirs(to_save + "/BEST_BON", exist_ok=True)
                    save_midi_alignments_and_predictions(
                        to_save + "/BEST_BON",
                        data["path"],
                        self.instruments,
                        aligned_onsets,
                        aligned_frames,
                        onset_pred_np,
                        frame_pred_np,
                        prefix="BEST_BON",
                        group=data["group"],
                        use_time=False,
                    )

        self.logger.info(
            "Updated %d pts out of %d", update_count, len(only_pitch_0_files)
        )
        self.logger.info(
            "Total counting alignment time for all files: %.2f seconds.",
            total_counting_time,
        )
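The counting alignment in update_pts_counting reduces, per pitch class and per window, to "keep the K strongest onset peaks". A minimal standalone sketch of that core step, on hypothetical toy data (the function name and arrays are illustrative, not from the commit):

```python
import numpy as np

def count_align(onset_probs: np.ndarray, counts: np.ndarray) -> np.ndarray:
    """For each pitch class p, keep the counts[p] highest-scoring frames.

    onset_probs: (n_frames, n_pitches) peak-picked model onset probabilities
    counts:      (n_pitches,) note counts per pitch from the unaligned score
    """
    aligned = np.zeros(onset_probs.shape, dtype=bool)
    for p, k in enumerate(counts):
        if k == 0:
            continue
        top_frames = np.sort(onset_probs[:, p].argsort()[::-1][:k])
        aligned[top_frames, p] = True
    return aligned

# Toy example: pitch 0 occurs twice in the score, pitch 1 once.
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])
print(count_align(probs, np.array([2, 1])))  # frames {0, 2} for pitch 0, frame 1 for pitch 1
```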
onsets_and_frames/decoding.py
ADDED
import numpy as np
import torch


def extract_notes(onsets, frames, velocity, onset_threshold=0.5, frame_threshold=0.5):
    """
    Finds the note timings based on the onsets and frames information.

    Parameters
    ----------
    onsets: torch.FloatTensor, shape = [frames, bins]
    frames: torch.FloatTensor, shape = [frames, bins]
    velocity: torch.FloatTensor, shape = [frames, bins]
    onset_threshold: float
    frame_threshold: float

    Returns
    -------
    pitches: np.ndarray of bin indices
    intervals: np.ndarray of rows containing (onset_index, offset_index)
    velocities: np.ndarray of velocity values
    """
    onsets = (onsets > onset_threshold).cpu().to(torch.uint8)
    frames = (frames > frame_threshold).cpu().to(torch.uint8)
    # Rising edges of the thresholded onset activations mark note starts.
    onset_diff = torch.cat([onsets[:1, :], onsets[1:, :] - onsets[:-1, :]], dim=0) == 1

    pitches = []
    intervals = []
    velocities = []

    for nonzero in onset_diff.nonzero(as_tuple=False):
        frame = nonzero[0].item()
        pitch = nonzero[1].item()

        onset = frame
        offset = frame
        velocity_samples = []

        # Extend the note while either the onset or the frame activation is on.
        while onsets[offset, pitch].item() or frames[offset, pitch].item():
            if onsets[offset, pitch].item():
                velocity_samples.append(velocity[offset, pitch].item())
            offset += 1
            if offset == onsets.shape[0]:
                break

        if offset > onset:
            pitches.append(pitch)
            intervals.append([onset, offset])
            velocities.append(
                np.mean(velocity_samples) if len(velocity_samples) > 0 else 0
            )

    return np.array(pitches), np.array(intervals), np.array(velocities)


def notes_to_frames(pitches, intervals, shape, mask=None):
    """
    Takes lists specifying note sequences and returns a frame-level representation.

    Parameters
    ----------
    pitches: list of pitch bin indices
    intervals: list of [onset, offset] ranges of bin indices
    shape: the shape of the original piano roll, [n_frames, n_bins]
    mask: optional binary array multiplied into the roll

    Returns
    -------
    time: np.ndarray containing the frame indices
    freqs: list of np.ndarray, each containing the frequency bin indices
    """
    roll = np.zeros(tuple(shape))
    for pitch, (onset, offset) in zip(pitches, intervals):
        roll[onset:offset, pitch] = 1
    if mask is not None:
        roll *= mask
    time = np.arange(roll.shape[0])
    freqs = [roll[t, :].nonzero()[0] for t in time]
    return time, freqs
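For orientation, a toy call to extract_notes (tensor values are illustrative): one pitch with an onset at frame 1 that sustains through frame 3 yields a single note spanning frames [1, 4).

```python
import torch
from onsets_and_frames.decoding import extract_notes

# Toy piano roll with a single pitch bin: onset fires at frame 1,
# the frame activation sustains through frame 3.
onsets = torch.tensor([[0.0], [0.9], [0.1], [0.0]])
frames = torch.tensor([[0.0], [0.9], [0.8], [0.7]])
velocity = torch.full((4, 1), 0.5)

pitches, intervals, velocities = extract_notes(onsets, frames, velocity)
print(pitches, intervals, velocities)  # one note: pitch 0, interval [1, 4], velocity 0.5
```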
onsets_and_frames/hf_model.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face Hub-compatible wrapper for CountEM music transcription models.
|
| 3 |
+
"""
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Union, Tuple
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 10 |
+
|
| 11 |
+
from onsets_and_frames.transcriber import OnsetsAndFrames
|
| 12 |
+
from onsets_and_frames.mel import MelSpectrogram
|
| 13 |
+
from onsets_and_frames.midi_utils import frames2midi
|
| 14 |
+
from onsets_and_frames.constants import (
|
| 15 |
+
N_MELS,
|
| 16 |
+
MIN_MIDI,
|
| 17 |
+
MAX_MIDI,
|
| 18 |
+
HOP_LENGTH,
|
| 19 |
+
SAMPLE_RATE,
|
| 20 |
+
WINDOW_LENGTH,
|
| 21 |
+
MEL_FMIN,
|
| 22 |
+
MEL_FMAX,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CountEMModel(
|
| 27 |
+
OnsetsAndFrames,
|
| 28 |
+
PyTorchModelHubMixin,
|
| 29 |
+
# Optional metadata that gets pushed to model card
|
| 30 |
+
library_name="countem",
|
| 31 |
+
tags=["audio", "music-transcription", "automatic-music-transcription", "midi"],
|
| 32 |
+
license="cc-by-4.0",
|
| 33 |
+
repo_url="https://github.com/Yoni-Yaffe/count-the-notes",
|
| 34 |
+
paper_url="https://arxiv.org/abs/2511.14250",
|
| 35 |
+
):
|
| 36 |
+
"""
|
| 37 |
+
Hugging Face Hub-compatible wrapper for CountEM automatic music transcription models.
|
| 38 |
+
|
| 39 |
+
This model performs automatic music transcription (AMT) from audio to MIDI.
|
| 40 |
+
It uses the Onsets & Frames architecture trained with the CountEM framework,
|
| 41 |
+
which enables training with weak, unordered note count histograms.
|
| 42 |
+
|
| 43 |
+
Example usage:
|
| 44 |
+
```python
|
| 45 |
+
from onsets_and_frames.hf_model import CountEMModel
|
| 46 |
+
import soundfile as sf
|
| 47 |
+
|
| 48 |
+
# Load model from Hub
|
| 49 |
+
model = CountEMModel.from_pretrained("Yoni-Yaffe/countem-musicnet")
|
| 50 |
+
|
| 51 |
+
# Load audio (must be 16kHz)
|
| 52 |
+
audio, sr = sf.read("audio.flac")
|
| 53 |
+
assert sr == 16000, "Audio must be 16kHz"
|
| 54 |
+
|
| 55 |
+
# Transcribe to MIDI
|
| 56 |
+
model.transcribe_to_midi(audio, "output.mid")
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
model_complexity: Complexity multiplier for the model (default: 64)
|
| 61 |
+
onset_complexity: Complexity multiplier for onset stack (default: 1.5)
|
| 62 |
+
n_instruments: Number of instruments to transcribe (default: 1)
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
model_complexity: int = 64,
|
| 68 |
+
onset_complexity: float = 1.5,
|
| 69 |
+
n_instruments: int = 1,
|
| 70 |
+
**kwargs
|
| 71 |
+
):
|
| 72 |
+
# Initialize the base OnsetsAndFrames model
|
| 73 |
+
n_keys = MAX_MIDI - MIN_MIDI + 1
|
| 74 |
+
OnsetsAndFrames.__init__(
|
| 75 |
+
self,
|
| 76 |
+
input_features=N_MELS,
|
| 77 |
+
output_features=n_keys,
|
| 78 |
+
model_complexity=model_complexity,
|
| 79 |
+
onset_complexity=onset_complexity,
|
| 80 |
+
n_instruments=n_instruments,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Store config for HF Hub
|
| 84 |
+
self.config = {
|
| 85 |
+
"model_complexity": model_complexity,
|
| 86 |
+
"onset_complexity": onset_complexity,
|
| 87 |
+
"n_instruments": n_instruments,
|
| 88 |
+
"n_mels": N_MELS,
|
| 89 |
+
"n_keys": n_keys,
|
| 90 |
+
"sample_rate": SAMPLE_RATE,
|
| 91 |
+
"hop_length": HOP_LENGTH,
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
# Add mel spectrogram as a submodule for proper device management
|
| 95 |
+
# This ensures the mel transform moves with the model when calling .to(device)
|
| 96 |
+
self.melspectrogram = MelSpectrogram(
|
| 97 |
+
n_mels=N_MELS,
|
| 98 |
+
sample_rate=SAMPLE_RATE,
|
| 99 |
+
filter_length=WINDOW_LENGTH,
|
| 100 |
+
hop_length=HOP_LENGTH,
|
| 101 |
+
mel_fmin=MEL_FMIN,
|
| 102 |
+
mel_fmax=MEL_FMAX,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def forward(self, audio: Union[np.ndarray, torch.Tensor]):
|
| 106 |
+
"""
|
| 107 |
+
Forward pass that accepts raw audio waveforms.
|
| 108 |
+
|
| 109 |
+
Unlike the parent OnsetsAndFrames which expects mel spectrograms,
|
| 110 |
+
this forward method accepts raw audio and converts it internally.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
audio: Raw audio waveform, shape (batch, n_samples) or (n_samples,)
|
| 114 |
+
Should be normalized to [-1, 1] or will be normalized automatically
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
Tuple of (onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred)
|
| 118 |
+
"""
|
| 119 |
+
# Convert to torch tensor if needed
|
| 120 |
+
if isinstance(audio, np.ndarray):
|
| 121 |
+
audio = torch.from_numpy(audio).float()
|
| 122 |
+
|
| 123 |
+
# Ensure audio is in range [-1, 1]
|
| 124 |
+
if audio.dtype == torch.int16:
|
| 125 |
+
audio = audio.float() / 32768.0
|
| 126 |
+
elif audio.max() > 1.0 or audio.min() < -1.0:
|
| 127 |
+
audio = audio / max(abs(audio.max()), abs(audio.min()))
|
| 128 |
+
|
| 129 |
+
# Add batch dimension if needed
|
| 130 |
+
if audio.dim() == 1:
|
| 131 |
+
audio = audio.unsqueeze(0)
|
| 132 |
+
|
| 133 |
+
device = next(self.parameters()).device
|
| 134 |
+
audio = audio.to(device)
|
| 135 |
+
|
| 136 |
+
# Remove last sample to fix frame count mismatch
|
| 137 |
+
audio = audio[:, :-1]
|
| 138 |
+
|
| 139 |
+
mel = self.melspectrogram(audio)
|
| 140 |
+
|
| 141 |
+
# Transpose to (batch, time, features) format expected by parent model
|
| 142 |
+
mel = mel.transpose(-1, -2)
|
| 143 |
+
|
| 144 |
+
return super().forward(mel)
|
| 145 |
+
|
| 146 |
+
@torch.no_grad()
|
| 147 |
+
def transcribe(
|
| 148 |
+
self,
|
| 149 |
+
audio: Union[np.ndarray, torch.Tensor],
|
| 150 |
+
onset_threshold: float = 0.5,
|
| 151 |
+
frame_threshold: float = 0.5,
|
| 152 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
| 153 |
+
"""
|
| 154 |
+
Transcribe audio to note predictions.
|
| 155 |
+
|
| 156 |
+
Automatically handles long audio by splitting into segments (max 5 minutes each)
|
| 157 |
+
to avoid memory issues.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
audio: Audio waveform, shape (n_samples,), normalized to [-1, 1]
|
| 161 |
+
onset_threshold: Threshold for onset detection (default: 0.5)
|
| 162 |
+
frame_threshold: Threshold for frame detection (default: 0.5)
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
Tuple of (onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred)
|
| 166 |
+
All are numpy arrays of shape (n_frames, 88) except velocity which may vary
|
| 167 |
+
"""
|
| 168 |
+
self.eval()
|
| 169 |
+
|
| 170 |
+
# Convert to torch tensor if needed
|
| 171 |
+
if isinstance(audio, np.ndarray):
|
| 172 |
+
audio = torch.from_numpy(audio).float()
|
| 173 |
+
|
| 174 |
+
# Ensure audio is 1D (convert stereo to mono if needed)
|
| 175 |
+
if audio.dim() > 1:
|
| 176 |
+
# If stereo or multi-channel, take mean across channels
|
| 177 |
+
audio = audio.mean(dim=-1 if audio.shape[-1] <=2 else 0)
|
| 178 |
+
|
| 179 |
+
# Normalize audio
|
| 180 |
+
if audio.dtype == torch.int16:
|
| 181 |
+
audio = audio.float() / 32768.0
|
| 182 |
+
elif audio.max() > 1.0 or audio.min() < -1.0:
|
| 183 |
+
audio = audio / max(abs(audio.max()), abs(audio.min()))
|
| 184 |
+
|
| 185 |
+
device = next(self.parameters()).device
|
| 186 |
+
audio = audio.to(device)
|
| 187 |
+
|
| 188 |
+
# Handle long audio by segmenting
|
| 189 |
+
MAX_TIME = 5 * 60 * SAMPLE_RATE # 5 minutes
|
| 190 |
+
audio_len = len(audio)
|
| 191 |
+
|
| 192 |
+
if audio_len > MAX_TIME:
|
| 193 |
+
# Split into segments
|
| 194 |
+
n_segments = int(np.ceil(audio_len / MAX_TIME))
|
| 195 |
+
seg_len = MAX_TIME
|
| 196 |
+
|
| 197 |
+
onset_preds = []
|
| 198 |
+
offset_preds = []
|
| 199 |
+
activation_preds = []
|
| 200 |
+
frame_preds = []
|
| 201 |
+
velocity_preds = []
|
| 202 |
+
|
| 203 |
+
for i_s in range(n_segments):
|
| 204 |
+
start = i_s * seg_len
|
| 205 |
+
end = min((i_s + 1) * seg_len, audio_len)
|
| 206 |
+
segment = audio[start:end]
|
| 207 |
+
|
| 208 |
+
# Forward pass on segment
|
| 209 |
+
onset_seg, offset_seg, activation_seg, frame_seg, velocity_seg = self(segment)
|
| 210 |
+
|
| 211 |
+
onset_preds.append(onset_seg)
|
| 212 |
+
offset_preds.append(offset_seg)
|
| 213 |
+
activation_preds.append(activation_seg)
|
| 214 |
+
frame_preds.append(frame_seg)
|
| 215 |
+
velocity_preds.append(velocity_seg)
|
| 216 |
+
|
| 217 |
+
# Concatenate along time dimension (dim=1)
|
| 218 |
+
onset_pred = torch.cat(onset_preds, dim=1)
|
| 219 |
+
offset_pred = torch.cat(offset_preds, dim=1)
|
| 220 |
+
activation_pred = torch.cat(activation_preds, dim=1)
|
| 221 |
+
frame_pred = torch.cat(frame_preds, dim=1)
|
| 222 |
+
velocity_pred = torch.cat(velocity_preds, dim=1)
|
| 223 |
+
else:
|
| 224 |
+
# Short audio, process directly
|
| 225 |
+
onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred = self(audio)
|
| 226 |
+
|
| 227 |
+
# Convert to numpy and remove batch dimension
|
| 228 |
+
onset_pred = onset_pred.squeeze(0).cpu().numpy()
|
| 229 |
+
offset_pred = offset_pred.squeeze(0).cpu().numpy()
|
| 230 |
+
activation_pred = activation_pred.squeeze(0).cpu().numpy()
|
| 231 |
+
frame_pred = frame_pred.squeeze(0).cpu().numpy()
|
| 232 |
+
velocity_pred = velocity_pred.squeeze(0).cpu().numpy()
|
| 233 |
+
|
| 234 |
+
return onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred
|
| 235 |
+
|
| 236 |
+
def transcribe_to_midi(
|
| 237 |
+
self,
|
| 238 |
+
audio: Union[np.ndarray, torch.Tensor, str, Path],
|
| 239 |
+
output_path: Union[str, Path],
|
| 240 |
+
onset_threshold: float = 0.5,
|
| 241 |
+
frame_threshold: float = 0.5,
|
| 242 |
+
) -> None:
|
| 243 |
+
"""
|
| 244 |
+
Transcribe audio to MIDI file.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
audio: Audio waveform, numpy array, torch tensor, or path to audio file
|
| 248 |
+
output_path: Path to save MIDI file
|
| 249 |
+
onset_threshold: Threshold for onset detection (default: 0.5)
|
| 250 |
+
frame_threshold: Threshold for frame detection (default: 0.5)
|
| 251 |
+
"""
|
| 252 |
+
# Load audio from file if path is provided
|
| 253 |
+
if isinstance(audio, (str, Path)):
|
| 254 |
+
audio, sr = sf.read(audio, dtype="float32")
|
| 255 |
+
if sr != SAMPLE_RATE:
|
| 256 |
+
raise ValueError(
|
| 257 |
+
f"Audio must be {SAMPLE_RATE}Hz, got {sr}Hz. "
|
| 258 |
+
f"Please resample to {SAMPLE_RATE}Hz first."
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# Get predictions
|
| 262 |
+
onset_pred, offset_pred, _, frame_pred, velocity_pred = self.transcribe(
|
| 263 |
+
audio, onset_threshold, frame_threshold
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Default instrument mapping (piano)
|
| 267 |
+
inst_mapping = {0: 0} # instrument 0 -> MIDI program 0 (Acoustic Grand Piano)
|
| 268 |
+
|
| 269 |
+
# Convert predictions to MIDI
|
| 270 |
+
frames2midi(
|
| 271 |
+
str(output_path),
|
| 272 |
+
onset_pred,
|
| 273 |
+
frame_pred,
|
| 274 |
+
velocity_pred,
|
| 275 |
+
onset_threshold=onset_threshold,
|
| 276 |
+
frame_threshold=frame_threshold,
|
| 277 |
+
scaling=HOP_LENGTH / SAMPLE_RATE,
|
| 278 |
+
inst_mapping=inst_mapping,
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
def to_legacy(self) -> OnsetsAndFrames:
|
| 282 |
+
"""
|
| 283 |
+
        Convert this HuggingFace-compatible model to a legacy OnsetsAndFrames instance.

        This is useful for:
        - Fine-tuning models downloaded from HuggingFace Hub using existing training code
        - Using HF models with existing inference scripts that expect OnsetsAndFrames

        The legacy model will use the global melspectrogram from mel.py instead of
        the instance-specific one in this model.

        Returns:
            OnsetsAndFrames instance with copied weights
        """
        # Create legacy model with same architecture
        legacy_model = OnsetsAndFrames(
            input_features=self.config['n_mels'],
            output_features=self.config['n_keys'],
            model_complexity=self.config['model_complexity'],
            onset_complexity=self.config['onset_complexity'],
            n_instruments=self.config['n_instruments']
        )

        # Get the state dict and filter out melspectrogram keys
        state_dict = self.state_dict()
        legacy_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('melspectrogram.')}

        # Copy state dict (only model weights, not mel spectrogram)
        # The legacy model will use the global melspectrogram
        legacy_model.load_state_dict(legacy_state_dict)

        return legacy_model

    @classmethod
    def from_legacy_checkpoint(
        cls,
        checkpoint_path: Union[str, Path],
        **kwargs
    ) -> "CountEMModel":
        """
        Load a model from a legacy checkpoint (saved with torch.save(model)).

        This is useful for converting old checkpoints to the new HF-compatible format.

        Args:
            checkpoint_path: Path to the legacy .pt checkpoint file
            **kwargs: Additional arguments for model initialization

        Returns:
            CountEMModel instance with loaded weights
        """
        # Load the legacy checkpoint
        legacy_model = torch.load(checkpoint_path, map_location="cpu")

        # Extract configuration from the loaded model.
        # Infer model_complexity from the model structure:
        # ConvStack.cnn[0] is the first Conv2d layer with out_channels = model_size // 16
        first_conv_channels = legacy_model.offset_stack[0].cnn[0].out_channels
        model_size = first_conv_channels * 16
        model_complexity = model_size // 16

        # Infer onset_complexity
        onset_first_conv_channels = legacy_model.onset_stack[0].cnn[0].out_channels
        onset_model_size = onset_first_conv_channels * 16
        onset_complexity = onset_model_size / model_size

        # Infer n_instruments from output layer
        # (onset_stack[2] is the Linear layer)
        onset_out_features = legacy_model.onset_stack[2].out_features
        n_keys = MAX_MIDI - MIN_MIDI + 1
        n_instruments = onset_out_features // n_keys

        # Create new model with the same configuration
        model = cls(
            model_complexity=model_complexity,
            onset_complexity=onset_complexity,
            n_instruments=n_instruments,
            **kwargs
        )

        # Copy the state dict (strict=False because new model has melspectrogram submodule)
        model.load_state_dict(legacy_model.state_dict(), strict=False)

        return model
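As a usage sketch of the two conversion paths (the checkpoint path below is hypothetical, and CountEMModel is assumed to be importable from onsets_and_frames/hf_model.py, where the methods above live):

from onsets_and_frames.hf_model import CountEMModel

# legacy .pt checkpoint (saved with torch.save(model)) -> HF-compatible model
model = CountEMModel.from_legacy_checkpoint("runs/transcriber.pt")  # hypothetical path

# ... and back to a legacy OnsetsAndFrames for the existing training/inference scripts
legacy = model.to_legacy_model()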
onsets_and_frames/lstm.py
ADDED
@@ -0,0 +1,96 @@
import torch
from torch import nn


class BiLSTM(nn.Module):
    inference_chunk_length = 512

    def __init__(self, input_features, recurrent_features, use_gru=False, dropout=0.0):
        super().__init__()
        self.rnn = (nn.LSTM if not use_gru else nn.GRU)(
            input_features,
            recurrent_features,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )

    def forward(self, x):
        if self.training:
            return self.rnn(x)[0]
        else:
            # evaluation mode: support for longer sequences that do not fit in memory
            batch_size, sequence_length, input_features = x.shape
            hidden_size = self.rnn.hidden_size
            num_directions = 2 if self.rnn.bidirectional else 1

            h = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
            c = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
            output = torch.zeros(
                batch_size,
                sequence_length,
                num_directions * hidden_size,
                device=x.device,
            )

            # forward direction
            slices = range(0, sequence_length, self.inference_chunk_length)
            for start in slices:
                end = start + self.inference_chunk_length
                output[:, start:end, :], (h, c) = self.rnn(x[:, start:end, :], (h, c))

            # reverse direction
            if self.rnn.bidirectional:
                h.zero_()
                c.zero_()

                for start in reversed(slices):
                    end = start + self.inference_chunk_length
                    result, (h, c) = self.rnn(x[:, start:end, :], (h, c))
                    output[:, start:end, hidden_size:] = result[:, :, hidden_size:]

            return output


class UniLSTM(nn.Module):
    inference_chunk_length = 512

    def __init__(self, input_features, recurrent_features):
        super().__init__()
        self.rnn = nn.LSTM(input_features, recurrent_features, batch_first=True)

    def forward(self, x):
        if self.training:
            return self.rnn(x)[0]
        else:
            # evaluation mode: support for longer sequences that do not fit in memory
            batch_size, sequence_length, input_features = x.shape
            hidden_size = self.rnn.hidden_size
            num_directions = 2 if self.rnn.bidirectional else 1

            h = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
            c = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
            output = torch.zeros(
                batch_size,
                sequence_length,
                num_directions * hidden_size,
                device=x.device,
            )

            # forward direction
            slices = range(0, sequence_length, self.inference_chunk_length)
            for start in slices:
                end = start + self.inference_chunk_length
                output[:, start:end, :], (h, c) = self.rnn(x[:, start:end, :], (h, c))

            # reverse direction
            if self.rnn.bidirectional:
                h.zero_()
                c.zero_()

                for start in reversed(slices):
                    end = start + self.inference_chunk_length
                    result, (h, c) = self.rnn(x[:, start:end, :], (h, c))
                    output[:, start:end, hidden_size:] = result[:, :, hidden_size:]

            return output
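A minimal sketch of the chunked evaluation path (shapes are illustrative; in eval mode, sequences longer than inference_chunk_length are processed 512 frames at a time, carrying the hidden state across chunk boundaries):

import torch
from onsets_and_frames.lstm import BiLSTM

rnn = BiLSTM(input_features=768, recurrent_features=384).eval()
x = torch.randn(1, 2000, 768)  # (batch, frames, features)
with torch.no_grad():
    y = rnn(x)  # chunked internally, state carried between chunks
print(y.shape)  # torch.Size([1, 2000, 768]): forward and backward halves concatenated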
onsets_and_frames/mel.py
ADDED
@@ -0,0 +1,136 @@
import numpy as np
import torch
import torch.nn.functional as F
from librosa.filters import mel
from librosa.util import pad_center
from scipy.signal import get_window
from torch.autograd import Variable

from onsets_and_frames.constants import (
    DEFAULT_DEVICE,
    HOP_LENGTH,
    MEL_FMAX,
    MEL_FMIN,
    N_MELS,
    SAMPLE_RATE,
    WINDOW_LENGTH,
)


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(self, filter_length, hop_length, win_length=None, window="hann"):
        super(STFT, self).__init__()
        if win_length is None:
            win_length = filter_length

        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window

        self.register_buffer("forward_basis", forward_basis.float())

    def forward(self, input_data):
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        # print('inp before', input_data.shape)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)
        # print('inp after', input_data.shape)

        forward_transform = F.conv1d(
            input_data,
            Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )
        # print('fwd', forward_transform.shape)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase


class MelSpectrogram(torch.nn.Module):
    def __init__(
        self,
        n_mels,
        sample_rate,
        filter_length,
        hop_length,
        win_length=None,
        mel_fmin=0.0,
        mel_fmax=None,
    ):
        super(MelSpectrogram, self).__init__()
        self.stft = STFT(filter_length, hop_length, win_length)

        mel_basis = mel(
            sr=sample_rate,
            n_fft=filter_length,
            n_mels=n_mels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def forward(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mels, T')
        (channel-first; callers transpose to (B, T', n_mels))
        """
        assert torch.min(y.data) >= -1
        assert torch.max(y.data) <= 1

        magnitudes, phases = self.stft(y)
        magnitudes = magnitudes.data

        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = torch.log(torch.clamp(mel_output, min=1e-5))
        return mel_output


# the default melspectrogram converter across the project
melspectrogram = MelSpectrogram(
    N_MELS, SAMPLE_RATE, WINDOW_LENGTH, HOP_LENGTH, mel_fmin=MEL_FMIN, mel_fmax=MEL_FMAX
)
melspectrogram.to(DEFAULT_DEVICE)
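A quick sketch of the module-level converter's contract: the input is a 16 kHz waveform batch in [-1, 1] on DEFAULT_DEVICE, and the output is a log-scaled, channel-first mel batch that callers transpose to (batch, frames, n_mels):

import torch
from onsets_and_frames.mel import melspectrogram
from onsets_and_frames.constants import DEFAULT_DEVICE, SAMPLE_RATE

audio = torch.zeros(2, 10 * SAMPLE_RATE, device=DEFAULT_DEVICE)  # 10 s of silence
mel = melspectrogram(audio)   # (2, 229, ~313 frames)
mel = mel.transpose(-1, -2)   # (2, frames, 229), as the transcriber expects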
onsets_and_frames/midi_utils.py
ADDED
@@ -0,0 +1,655 @@
import os
from datetime import datetime

import mido
import numpy as np
import torch
from mido import Message, MidiFile, MidiTrack

from onsets_and_frames.constants import (
    DRUM_CHANNEL,
    HOP_LENGTH,
    HOPS_IN_OFFSET,
    HOPS_IN_ONSET,
    MAX_MIDI,
    MIN_MIDI,
    N_KEYS,
    SAMPLE_RATE,
)

from .utils import max_inst


def midi_to_hz(m):
    return 440.0 * (2.0 ** ((m - 69.0) / 12.0))


def hz_to_midi(h):
    return 12.0 * np.log2(h / (440.0)) + 69.0


def midi_to_frames(midi, instruments, conversion_map=None):
    n_keys = MAX_MIDI - MIN_MIDI + 1
    midi_length = int((max(midi[:, 1]) + 1) * SAMPLE_RATE)
    n_steps = (midi_length - 1) // HOP_LENGTH + 1
    n_channels = len(instruments) + 1
    label = torch.zeros(n_steps, n_keys * n_channels, dtype=torch.uint8)
    for onset, offset, note, vel, instrument in midi:
        f = int(note) - MIN_MIDI
        if 104 > instrument > 87 or instrument > 111:
            continue
        if f >= n_keys or f < 0:
            continue
        assert 0 < vel < 128
        instrument = int(instrument)
        if conversion_map is not None:
            if instrument not in conversion_map:
                continue
            instrument = conversion_map[instrument]
        left = int(round(onset * SAMPLE_RATE / HOP_LENGTH))
        onset_right = min(n_steps, left + HOPS_IN_ONSET)
        frame_right = int(round(offset * SAMPLE_RATE / HOP_LENGTH))
        frame_right = min(n_steps, frame_right)
        offset_right = min(n_steps, frame_right + HOPS_IN_OFFSET)
        if int(instrument) not in instruments:
            continue
        chan = instruments.index(int(instrument))
        label[left:onset_right, n_keys * chan + f] = 3
        label[onset_right:frame_right, n_keys * chan + f] = 2
        label[frame_right:offset_right, n_keys * chan + f] = 1

        inv_chan = len(instruments)
        label[left:onset_right, n_keys * inv_chan + f] = 3
        label[onset_right:frame_right, n_keys * inv_chan + f] = 2
        label[frame_right:offset_right, n_keys * inv_chan + f] = 1

    return label
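A worked example of the encoding above: per key and per channel, value 3 marks onset frames, 2 the sustained frames, and 1 a short offset tail; the last n_keys columns pool all instruments into one instrument-agnostic channel.

import numpy as np
# hypothetical single note: pitch 60 from 1.0 s to 1.5 s, velocity 80, instrument 0
midi = np.array([[1.0, 1.5, 60.0, 80.0, 0.0]])
label = midi_to_frames(midi, instruments=[0])
f = 60 - 21  # key index for MIDI pitch 60 (MIN_MIDI = 21)
# label[:, f] is 3 at the onset frame, 2 while held, 1 just after the offset;
# label[:, 88 + f] carries the same pattern on the pooled channel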
"""
Convert piano roll to list of notes, pitch only.
"""


def extract_notes_np_pitch(
    onsets, frames, velocity, onset_threshold=0.5, frame_threshold=0.5
):
    onsets = (onsets > onset_threshold).astype(np.uint8)
    frames = (frames > frame_threshold).astype(np.uint8)
    onset_diff = (
        np.concatenate([onsets[:1, :], onsets[1:, :] - onsets[:-1, :]], axis=0) == 1
    )

    pitches = []
    intervals = []
    velocities = []

    for nonzero in np.transpose(np.nonzero(onset_diff)):
        frame = nonzero[0].item()
        pitch = nonzero[1].item()

        onset = frame
        offset = frame
        velocity_samples = []

        while onsets[offset, pitch] or frames[offset, pitch]:
            if onsets[offset, pitch]:
                velocity_samples.append(velocity[offset, pitch])
            offset += 1
            if offset == onsets.shape[0]:
                break

        if offset > onset:
            pitches.append(pitch)
            intervals.append([onset, offset])
            velocities.append(
                np.mean(velocity_samples) if len(velocity_samples) > 0 else 0
            )
    return np.array(pitches), np.array(intervals), np.array(velocities)


def extract_notes_np_rescaled(
    onsets, frames, velocity, onset_threshold=0.5, frame_threshold=0.5
):
    pitches, intervals, velocities, instruments = extract_notes_np(
        onsets, frames, velocity, onset_threshold, frame_threshold
    )
    pitches += MIN_MIDI
    scaling = HOP_LENGTH / SAMPLE_RATE
    intervals = (intervals * scaling).reshape(-1, 2)
    return pitches, intervals, velocities, instruments


"""
Convert piano roll to list of notes, pitch and instrument.
"""


def extract_notes_np(
    onsets,
    frames,
    velocity,
    onset_threshold=0.5,
    frame_threshold=0.5,
    onset_threshold_vec=None,
):
    if onset_threshold_vec is not None:
        onsets = (onsets > np.array(onset_threshold_vec)).astype(np.uint8)
    else:
        onsets = (onsets > onset_threshold).astype(np.uint8)

    frames = (frames > frame_threshold).astype(np.uint8)
    onset_diff = (
        np.concatenate([onsets[:1, :], onsets[1:, :] - onsets[:-1, :]], axis=0) == 1
    )

    if onsets.shape[-1] != frames.shape[-1]:
        num_instruments = onsets.shape[1] / frames.shape[1]
        assert num_instruments.is_integer()
        num_instruments = int(num_instruments)
        frames = np.tile(frames, (1, num_instruments))

    pitches = []
    intervals = []
    velocities = []
    instruments = []

    for nonzero in np.transpose(np.nonzero(onset_diff)):
        frame = nonzero[0].item()
        pitch = nonzero[1].item()

        onset = frame
        offset = frame
        velocity_samples = []

        while onsets[offset, pitch] or frames[offset, pitch]:
            if onsets[offset, pitch]:
                velocity_samples.append(velocity[offset, pitch])
            offset += 1
            if offset == onsets.shape[0]:
                break

        if offset > onset:
            pitch, instrument = pitch % N_KEYS, pitch // N_KEYS

            pitches.append(pitch)
            intervals.append([onset, offset])
            velocities.append(
                np.mean(velocity_samples) if len(velocity_samples) > 0 else 0
            )
            instruments.append(instrument)
    return (
        np.array(pitches),
        np.array(intervals),
        np.array(velocities),
        np.array(instruments),
    )
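A sketch of the decoding contract (onset_pred, frame_pred and velocity_pred stand in for thresholdable (frames, 88 * n_instruments) arrays from the model):

pitches, intervals, velocities, instruments = extract_notes_np(
    onset_pred, frame_pred, velocity_pred,
    onset_threshold=0.5, frame_threshold=0.5,
)
# pitches: key indices in [0, 87] (add MIN_MIDI for MIDI numbers);
# intervals: (onset_frame, offset_frame) rows; instruments: pitch // N_KEYS.
# extract_notes_np_rescaled additionally converts frames to seconds via HOP_LENGTH / SAMPLE_RATE.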
def append_track_multi(file, pitches, intervals, velocities, ins, single_ins=False):
    track = MidiTrack()
    file.tracks.append(track)
    chan = len(file.tracks) - 1
    if chan >= DRUM_CHANNEL:
        chan += 1
    if chan > 15:
        print(f"invalid chan {chan}")
        chan = 15
    track.append(
        Message(
            "program_change", channel=chan, program=ins if not single_ins else 0, time=0
        )
    )

    ticks_per_second = file.ticks_per_beat * 2.0

    events = []
    for i in range(len(pitches)):
        events.append(
            dict(
                type="on",
                pitch=pitches[i],
                time=intervals[i][0],
                velocity=velocities[i],
            )
        )
        events.append(
            dict(
                type="off",
                pitch=pitches[i],
                time=intervals[i][1],
                velocity=velocities[i],
            )
        )
    events.sort(key=lambda row: row["time"])

    last_tick = 0
    for event in events:
        current_tick = int(event["time"] * ticks_per_second)
        velocity = int(event["velocity"] * 127)
        if velocity > 127:
            velocity = 127
        pitch = int(round(hz_to_midi(event["pitch"])))
        track.append(
            Message(
                "note_" + event["type"],
                channel=chan,
                note=pitch,
                velocity=velocity,
                time=current_tick - last_tick,
            )
        )
        # try:
        #     track.append(Message('note_' + event['type'], channel=chan, note=pitch, velocity=velocity, time=current_tick - last_tick))
        # except Exception as e:
        #     print('Err Message', 'note_' + event['type'], pitch, velocity, current_tick - last_tick)
        #     track.append(Message('note_' + event['type'], channel=chan, note=pitch, velocity=max(0, velocity), time=current_tick - last_tick))
        #     if velocity >= 0:
        #         raise e
        last_tick = current_tick


def append_track(file, pitches, intervals, velocities):
    track = MidiTrack()
    file.tracks.append(track)
    ticks_per_second = file.ticks_per_beat * 2.0

    events = []
    for i in range(len(pitches)):
        events.append(
            dict(
                type="on",
                pitch=pitches[i],
                time=intervals[i][0],
                velocity=velocities[i],
            )
        )
        events.append(
            dict(
                type="off",
                pitch=pitches[i],
                time=intervals[i][1],
                velocity=velocities[i],
            )
        )
    events.sort(key=lambda row: row["time"])

    last_tick = 0
    for event in events:
        current_tick = int(event["time"] * ticks_per_second)
        velocity = int(event["velocity"] * 127)
        if velocity > 127:
            velocity = 127
        pitch = int(round(hz_to_midi(event["pitch"])))
        try:
            track.append(
                Message(
                    "note_" + event["type"],
                    note=pitch,
                    velocity=velocity,
                    time=current_tick - last_tick,
                )
            )
        except Exception as e:
            print(
                "Err Message",
                "note_" + event["type"],
                pitch,
                velocity,
                current_tick - last_tick,
            )
            track.append(
                Message(
                    "note_" + event["type"],
                    note=pitch,
                    velocity=max(0, velocity),
                    time=current_tick - last_tick,
                )
            )
            if velocity >= 0:
                raise e
        last_tick = current_tick


def save_midi(path, pitches, intervals, velocities, insts=None):
    """
    Save extracted notes as a MIDI file
    Parameters
    ----------
    path: the path to save the MIDI file
    pitches: np.ndarray of bin_indices
    intervals: list of (onset_index, offset_index)
    velocities: list of velocity values
    """
    file = MidiFile()
    if isinstance(pitches, list):
        for p, i, v, ins in zip(pitches, intervals, velocities, insts):
            append_track_multi(file, p, i, v, ins)
    else:
        append_track(file, pitches, intervals, velocities)
    file.save(path)
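Both track writers rely on the same tick arithmetic: with mido's default 480 ticks per beat and the MIDI default tempo of 120 BPM, one second corresponds to ticks_per_beat * 2 ticks. A minimal sanity check of that assumption:

from mido import MidiFile
ticks_per_second = MidiFile().ticks_per_beat * 2.0  # 960.0 with mido defaults
assert int(0.5 * ticks_per_second) == 480  # an event at 0.5 s lands 480 ticks in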
def frames2midi(
    save_path,
    onsets,
    frames,
    vels,
    onset_threshold=0.5,
    frame_threshold=0.5,
    scaling=HOP_LENGTH / SAMPLE_RATE,
    inst_mapping=None,
    onset_threshold_vec=None,
):
    p_est, i_est, v_est, inst_est = extract_notes_np(
        onsets,
        frames,
        vels,
        onset_threshold,
        frame_threshold,
        onset_threshold_vec=onset_threshold_vec,
    )
    i_est = (i_est * scaling).reshape(-1, 2)

    p_est = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_est])

    inst_set = set(inst_est)
    inst_set = sorted(list(inst_set))

    p_est_lst = {}
    i_est_lst = {}
    v_est_lst = {}
    assert len(p_est) == len(i_est) == len(v_est) == len(inst_est)
    for p, i, v, ins in zip(p_est, i_est, v_est, inst_est):
        if ins in p_est_lst:
            p_est_lst[ins].append(p)
        else:
            p_est_lst[ins] = [p]
        if ins in i_est_lst:
            i_est_lst[ins].append(i)
        else:
            i_est_lst[ins] = [i]
        if ins in v_est_lst:
            v_est_lst[ins].append(v)
        else:
            v_est_lst[ins] = [v]
    for elem in [p_est_lst, i_est_lst, v_est_lst]:
        for k, v in elem.items():
            elem[k] = np.array(v)
    inst_set = [e for e in inst_set if e in p_est_lst]
    # inst_set = [INSTRUMENT_MAPPING[e] for e in inst_set if e in p_est_lst]
    p_est_lst = [p_est_lst[ins] for ins in inst_set if ins in p_est_lst]
    i_est_lst = [i_est_lst[ins] for ins in inst_set if ins in i_est_lst]
    v_est_lst = [v_est_lst[ins] for ins in inst_set if ins in v_est_lst]
    assert len(p_est_lst) == len(i_est_lst) == len(v_est_lst) == len(inst_set)
    inst_set = [inst_mapping[e] for e in inst_set]
    save_midi(save_path, p_est_lst, i_est_lst, v_est_lst, inst_set)


def frames2midi_pitch(
    save_path,
    onsets,
    frames,
    vels,
    onset_threshold=0.5,
    frame_threshold=0.5,
    scaling=HOP_LENGTH / SAMPLE_RATE,
):
    p_est, i_est, v_est = extract_notes_np_pitch(
        onsets, frames, vels, onset_threshold, frame_threshold
    )
    i_est = (i_est * scaling).reshape(-1, 2)
    p_est = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_est])
    print("Saving midi in", save_path)
    save_midi(save_path, p_est, i_est, v_est)
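A hedged sketch of writing a multi-instrument MIDI straight from thresholded predictions (the arrays and the program-number mapping are placeholders):

frames2midi(
    "output.mid",
    onset_pred, frame_pred, 64.0 * (onset_pred >= 0.5),
    onset_threshold=0.5, frame_threshold=0.5,
    inst_mapping=[0, 40],  # e.g. channel 0 -> piano, channel 1 -> violin
)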
def parse_midi_multi(path, force_instrument=None):
    """open midi file and return np.array of (onset, offset, note, velocity, instrument) rows"""
    try:
        midi = mido.MidiFile(path)
    except:
        print("could not open midi", path)
        return

    time = 0

    events = []

    control_changes = []
    program_changes = []

    sustain = {}

    all_channels = set()

    instruments = {}  # mapping of channel: instrument

    for message in midi:
        time += message.time
        if hasattr(message, "channel"):
            if message.channel == DRUM_CHANNEL:
                continue

        if (
            message.type == "control_change"
            and message.control == 64
            and (message.value >= 64) != sustain.get(message.channel, False)
        ):
            sustain[message.channel] = message.value >= 64
            event_type = "sustain_on" if sustain[message.channel] else "sustain_off"
            event = dict(
                index=len(events), time=time, type=event_type, note=None, velocity=0
            )
            event["channel"] = message.channel
            event["sustain"] = sustain[message.channel]
            events.append(event)

        if message.type == "control_change" and message.control != 64:
            control_changes.append(
                (time, message.control, message.value, message.channel)
            )

        if message.type == "program_change":
            program_changes.append((time, message.program, message.channel))
            instruments[message.channel] = instruments.get(message.channel, []) + [
                (message.program, time)
            ]

        if "note" in message.type:
            # MIDI offsets can be either 'note_off' events or 'note_on' with zero velocity
            velocity = message.velocity if message.type == "note_on" else 0
            event = dict(
                index=len(events),
                time=time,
                type="note",
                note=message.note,
                velocity=velocity,
                sustain=sustain.get(message.channel, False),
            )
            event["channel"] = message.channel
            events.append(event)

        if hasattr(message, "channel"):
            all_channels.add(message.channel)

    if len(instruments) == 0:
        instruments = {c: [(0, 0)] for c in all_channels}
    if len(all_channels) > len(instruments):
        for e in all_channels - set(instruments.keys()):
            instruments[e] = [(0, 0)]

    if force_instrument is not None:
        instruments = {c: [(force_instrument, 0)] for c in all_channels}

    this_instruments = set()
    for v in instruments.values():
        this_instruments = this_instruments.union(set(x[0] for x in v))

    notes = []
    for i, onset in enumerate(events):
        if onset["velocity"] == 0:
            continue
        offset = next(
            n
            for n in events[i + 1 :]
            if (n["note"] == onset["note"] and n["channel"] == onset["channel"])
            or n is events[-1]
        )
        if "sustain" not in offset:
            print("offset without sustain", offset)
        if offset["sustain"] and offset is not events[-1]:
            # if the sustain pedal is active at offset, find when the sustain ends
            offset = next(
                n
                for n in events[offset["index"] + 1 :]
                if (n["type"] == "sustain_off" and n["channel"] == onset["channel"])
                or n is events[-1]
            )
        for k, v in instruments.items():
            if len(set(v)) == 1 and len(v) > 1:
                instruments[k] = list(set(v))
        for k, v in instruments.items():
            instruments[k] = sorted(v, key=lambda x: x[1])
        if len(instruments[onset["channel"]]) == 1:
            instrument = instruments[onset["channel"]][0][0]
        else:
            ind = 0
            while (
                ind < len(instruments[onset["channel"]])
                and onset["time"] >= instruments[onset["channel"]][ind][1]
            ):
                ind += 1
            if ind > 0:
                ind -= 1
            instrument = instruments[onset["channel"]][ind][0]
        if onset["channel"] == DRUM_CHANNEL:
            print("skipping drum note")
            continue
        note = (
            onset["time"],
            offset["time"],
            onset["note"],
            onset["velocity"],
            instrument,
        )
        notes.append(note)

    res = np.array(notes)
    return res
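A sketch of the parser's output contract (the path is a placeholder):

notes = parse_midi_multi("performance.mid")
# notes is an (N, 5) array of (onset_sec, offset_sec, note, velocity, instrument) rows,
# with drum-channel events skipped and sustain-pedal releases extending note offsets
for onset, offset, note, vel, inst in notes[:3]:
    print(f"pitch {int(note)} on program {int(inst)}: {onset:.2f}-{offset:.2f} s")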
def save_midi_alignments_and_predictions(
    save_path,
    data_path,
    inst_mapping,
    aligned_onsets,
    aligned_frames,
    onset_pred_np,
    frame_pred_np,
    prefix="",
    use_time=True,
    group=None,
):
    inst_only = len(inst_mapping) * N_KEYS
    time_now = datetime.now().strftime("%y%m%d-%H%M%S") if use_time else ""
    if len(prefix) > 0:
        prefix = "_{}".format(prefix)

    # Save the aligned label. If training on a small dataset or a single performance in order to label it for later adding it
    # to a large dataset, it is recommended to use this MIDI as a label.
    frames2midi(
        save_path
        + os.sep
        + data_path.replace(".flac", "").split(os.sep)[-1]
        + prefix
        + "_alignment_"
        + time_now
        + ".mid",
        aligned_onsets[:, :inst_only],
        aligned_frames[:, :inst_only],
        64.0 * aligned_onsets[:, :inst_only],
        inst_mapping=inst_mapping,
    )
    return
    # NOTE: the early return above disables everything below; the prediction and
    # pseudo-label MIDIs are only written if it is removed.

    # # Aligned label, pitch-only, on the piano.
    # frames2midi_pitch(save_path + os.sep + data_path.replace('.flac', '').split(os.sep)[-1] + prefix + '_alignment_pitch_' + time_now + '.mid',
    #                   aligned_onsets[:, -N_KEYS:], aligned_frames[:, -N_KEYS:],
    #                   64. * aligned_onsets[:, -N_KEYS:])

    predicted_onsets = onset_pred_np >= 0.5
    predicted_frames = frame_pred_np >= 0.5

    # # Raw pitch with instrument prediction - will probably have lower recall, depending on the model's strength.
    # frames2midi(save_path + os.sep + data_path.replace('.flac', '').split(os.sep)[-1] + prefix + '_pred_' + time_now + '.mid',
    #             predicted_onsets[:, : inst_only], predicted_frames[:, : inst_only],
    #             64. * predicted_onsets[:, : inst_only],
    #             inst_mapping=inst_mapping)

    # Pitch prediction played on the piano - will have high recall, since it does not differentiate between instruments.
    frames2midi_pitch(
        save_path
        + os.sep
        + data_path.replace(".flac", "").split(os.sep)[-1]
        + prefix
        + "_pred_pitch_"
        + time_now
        + ".mid",
        predicted_onsets[:, -N_KEYS:],
        predicted_frames[:, -N_KEYS:],
        64.0 * predicted_onsets[:, -N_KEYS:],
    )

    # Pitch prediction, with choice of most likely instrument for each detected note.
    if len(inst_mapping) > 1:
        max_pred_onsets = max_inst(onset_pred_np)
        frames2midi(
            save_path
            + os.sep
            + data_path.replace(".flac", "").split(os.sep)[-1]
            + prefix
            + "_pred_inst_"
            + time_now
            + ".mid",
            max_pred_onsets[:, :inst_only],
            predicted_frames[:, :inst_only],
            64.0 * max_pred_onsets[:, :inst_only],
            inst_mapping=inst_mapping,
        )

    pseudo_onsets = (onset_pred_np >= 0.5) & (~aligned_onsets)
    onset_label = np.maximum(pseudo_onsets, aligned_onsets)

    pseudo_frames = np.zeros(pseudo_onsets.shape, dtype=pseudo_onsets.dtype)
    for t, f in zip(*onset_label.nonzero()):
        t_off = t
        while t_off < len(pseudo_frames) and frame_pred_np[t_off, f % N_KEYS] >= 0.5:
            t_off += 1
        pseudo_frames[t:t_off, f] = 1
    frame_label = np.maximum(pseudo_frames, aligned_frames)

    # pseudo_frames = (frame_pred_np >= 0.5) & (~aligned_frames)
    # frame_label = np.maximum(pseudo_frames, aligned_frames)

    frames2midi(
        save_path
        + os.sep
        + data_path.replace(".flac", "").split(os.sep)[-1]
        + prefix
        + "_pred_align_max_"
        + time_now
        + ".mid",
        onset_label[:, :inst_only],
        frame_label[:, :inst_only],
        64.0 * onset_label[:, :inst_only],
        inst_mapping=inst_mapping,
    )
    # if group is not None:
    #     gorup_path = os.path.join(save_path, 'pred_alignment_max', group)
    #     file_name = os.path.basename(data_path).replace('.flac', '_pred_align_max.mid')
    #     os.makedirs(gorup_path, exist_ok=True)
    #     frames2midi(os.path.join(gorup_path, file_name),
    #                 onset_label[:, : inst_only], frame_label[:, : inst_only],
    #                 64. * onset_label[:, : inst_only],
    #                 inst_mapping=inst_mapping)
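Given the early return, the helper currently writes only the aligned label. A sketch of a typical call, with placeholder arrays and paths:

save_midi_alignments_and_predictions(
    "labels", "data/performance.flac", inst_mapping=[0],
    aligned_onsets=aligned_onsets, aligned_frames=aligned_frames,
    onset_pred_np=onset_pred, frame_pred_np=frame_pred,
    prefix="iter1",
)
# -> labels/performance_iter1_alignment_<timestamp>.mid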
onsets_and_frames/transcriber.py
ADDED
@@ -0,0 +1,276 @@
import torch
import torch.nn.functional as F
from torch import nn

from onsets_and_frames.constants import MAX_MIDI, MIN_MIDI, N_KEYS

from .lstm import BiLSTM
from .mel import melspectrogram


class ConvStack(nn.Module):
    def __init__(self, input_features, output_features):
        super().__init__()

        # input is batch_size * 1 channel * frames * input_features
        self.cnn = nn.Sequential(
            # layer 0
            nn.Conv2d(1, output_features // 16, (3, 3), padding=1),
            nn.BatchNorm2d(output_features // 16),
            nn.ReLU(),
            # layer 1
            nn.Conv2d(output_features // 16, output_features // 16, (3, 3), padding=1),
            nn.BatchNorm2d(output_features // 16),
            nn.ReLU(),
            # layer 2
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.25),
            nn.Conv2d(output_features // 16, output_features // 8, (3, 3), padding=1),
            nn.BatchNorm2d(output_features // 8),
            nn.ReLU(),
            # layer 3
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.25),
        )
        self.fc = nn.Sequential(
            nn.Linear((output_features // 8) * (input_features // 4), output_features),
            nn.Dropout(0.5),
        )

    def forward(self, mel):
        x = mel.view(mel.size(0), 1, mel.size(1), mel.size(2))
        x = self.cnn(x)
        x = x.transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x


class OnsetsAndFrames(nn.Module):
    def __init__(
        self,
        input_features,
        output_features,
        model_complexity=48,
        onset_complexity=1,
        n_instruments=13,
    ):
        nn.Module.__init__(self)
        model_size = model_complexity * 16
        sequence_model = lambda input_size, output_size: BiLSTM(
            input_size, output_size // 2
        )

        onset_model_size = int(onset_complexity * model_size)
        self.onset_stack = nn.Sequential(
            ConvStack(input_features, onset_model_size),
            sequence_model(onset_model_size, onset_model_size),
            nn.Linear(onset_model_size, output_features * n_instruments),
            nn.Sigmoid(),
        )
        self.offset_stack = nn.Sequential(
            ConvStack(input_features, model_size),
            sequence_model(model_size, model_size),
            nn.Linear(model_size, output_features),
            nn.Sigmoid(),
        )
        self.frame_stack = nn.Sequential(
            ConvStack(input_features, model_size),
            nn.Linear(model_size, output_features),
            nn.Sigmoid(),
        )
        self.combined_stack = nn.Sequential(
            sequence_model(output_features * 3, model_size),
            nn.Linear(model_size, output_features),
            nn.Sigmoid(),
        )
        self.velocity_stack = nn.Sequential(
            ConvStack(input_features, model_size),
            nn.Linear(model_size, output_features * n_instruments),
        )

    def forward(self, mel):
        onset_pred = self.onset_stack(mel)
        offset_pred = self.offset_stack(mel)
        activation_pred = self.frame_stack(mel)

        onset_detached = onset_pred.detach()
        shape = onset_detached.shape
        keys = MAX_MIDI - MIN_MIDI + 1
        new_shape = shape[:-1] + (shape[-1] // keys, keys)
        onset_detached = onset_detached.reshape(new_shape)
        onset_detached, _ = onset_detached.max(axis=-2)

        offset_detached = offset_pred.detach()

        combined_pred = torch.cat(
            [onset_detached, offset_detached, activation_pred], dim=-1
        )
        frame_pred = self.combined_stack(combined_pred)
        velocity_pred = self.velocity_stack(mel)
        return onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred

    def run_on_batch(
        self,
        batch,
        parallel_model=None,
        positive_weight=2.0,
        inv_positive_weight=2.0,
        with_onset_mask=False,
    ):
        audio_label = batch["audio"]

        onset_label = batch["onset"]
        offset_label = batch["offset"]
        frame_label = batch["frame"]
        if "velocity" in batch:
            velocity_label = batch["velocity"]
        mel = melspectrogram(
            audio_label.reshape(-1, audio_label.shape[-1])[:, :-1]
        ).transpose(-1, -2)

        if not parallel_model:
            onset_pred, offset_pred, _, frame_pred, velocity_pred = self(mel)
        else:
            onset_pred, offset_pred, _, frame_pred, velocity_pred = parallel_model(mel)

        predictions = {
            "onset": onset_pred.reshape(*onset_label.shape),
            "offset": offset_pred.reshape(*offset_label.shape),
            "frame": frame_pred.reshape(*frame_label.shape),
            # 'velocity': velocity_pred.reshape(*velocity_label.shape)
        }

        if "velocity" in batch:
            predictions["velocity"] = velocity_pred.reshape(*velocity_label.shape)

        losses = {
            "loss/onset": F.binary_cross_entropy(
                predictions["onset"], onset_label, reduction="none"
            ),
            "loss/offset": F.binary_cross_entropy(
                predictions["offset"], offset_label, reduction="none"
            ),
            "loss/frame": F.binary_cross_entropy(
                predictions["frame"], frame_label, reduction="none"
            ),
            # 'loss/velocity': self.velocity_loss(predictions['velocity'], velocity_label, onset_label)
        }
        if "velocity" in batch:
            losses["loss/velocity"] = self.velocity_loss(
                predictions["velocity"], velocity_label, onset_label
            )

        onset_mask = 1.0 * onset_label
        onset_mask[..., :-N_KEYS] *= positive_weight - 1
        onset_mask[..., -N_KEYS:] *= inv_positive_weight - 1
        onset_mask += 1
        if with_onset_mask:
            if "onset_mask" in batch:
                onset_mask = onset_mask * batch["onset_mask"]
            # if 'onset_mask' in batch:
            #     onset_mask += batch['onset_mask']

        offset_mask = 1.0 * offset_label
        offset_positive_weight = 2.0
        offset_mask *= offset_positive_weight - 1
        offset_mask += 1.0

        frame_mask = 1.0 * frame_label
        frame_positive_weight = 2.0
        frame_mask *= frame_positive_weight - 1
        frame_mask += 1.0

        for loss_key, mask in zip(
            ["onset", "offset", "frame"], [onset_mask, offset_mask, frame_mask]
        ):
            losses["loss/" + loss_key] = (mask * losses["loss/" + loss_key]).mean()

        return predictions, losses

    def velocity_loss(self, velocity_pred, velocity_label, onset_label):
        denominator = onset_label.sum()
        if denominator.item() == 0:
            return denominator
        else:
            return (
                onset_label * (velocity_label - velocity_pred) ** 2
            ).sum() / denominator
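A hedged end-to-end sketch of the five-head forward pass with the project defaults (229 mel bins, 88 keys, 13 instrument channels):

import torch
from onsets_and_frames.constants import DEFAULT_DEVICE, N_MELS, N_KEYS
from onsets_and_frames.mel import melspectrogram
from onsets_and_frames.transcriber import OnsetsAndFrames

model = OnsetsAndFrames(N_MELS, N_KEYS, n_instruments=13).to(DEFAULT_DEVICE).eval()
audio = torch.zeros(1, 327680, device=DEFAULT_DEVICE)  # ~20 s at 16 kHz, in [-1, 1]
mel = melspectrogram(audio[:, :-1]).transpose(-1, -2)  # (1, frames, 229)
with torch.no_grad():
    onset, offset, activation, frame, velocity = model(mel)
# onset and velocity: (1, frames, 88 * 13); offset, activation and frame: (1, frames, 88)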
# same implementation as OnsetsAndFrames, but with only onset stack
class OnsetsNoFrames(nn.Module):
    def __init__(
        self,
        input_features,
        output_features,
        model_complexity=48,
        onset_complexity=1,
        n_instruments=13,
    ):
        nn.Module.__init__(self)
        model_size = model_complexity * 16
        sequence_model = lambda input_size, output_size: BiLSTM(
            input_size, output_size // 2
        )

        onset_model_size = int(onset_complexity * model_size)
        self.onset_stack = nn.Sequential(
            ConvStack(input_features, onset_model_size),
            sequence_model(onset_model_size, onset_model_size),
            nn.Linear(onset_model_size, output_features * n_instruments),
            nn.Sigmoid(),
        )

    def forward(self, mel):
        onset_pred = self.onset_stack(mel)

        onset_detached = onset_pred.detach()
        shape = onset_detached.shape
        keys = MAX_MIDI - MIN_MIDI + 1
        new_shape = shape[:-1] + (shape[-1] // keys, keys)
        onset_detached = onset_detached.reshape(new_shape)
        onset_detached, _ = onset_detached.max(axis=-2)

        return onset_pred

    def run_on_batch(
        self,
        batch,
        parallel_model=None,
        positive_weight=2.0,
        inv_positive_weight=2.0,
        with_onset_mask=False,
    ):
        audio_label = batch["audio"]

        onset_label = batch["onset"]
        mel = melspectrogram(
            audio_label.reshape(-1, audio_label.shape[-1])[:, :-1]
        ).transpose(-1, -2)

        if not parallel_model:
            onset_pred = self(mel)
        else:
            onset_pred = parallel_model(mel)

        predictions = {
            "onset": onset_pred,
        }

        losses = {
            "loss/onset": F.binary_cross_entropy(
                predictions["onset"], onset_label, reduction="none"
            ),
        }

        onset_mask = 1.0 * onset_label
        onset_mask[..., :-N_KEYS] *= positive_weight - 1
        onset_mask[..., -N_KEYS:] *= inv_positive_weight - 1
        onset_mask += 1
        if with_onset_mask:
            if "onset_mask" in batch:
                onset_mask = onset_mask * batch["onset_mask"]

        losses["loss/onset"] = (onset_mask * losses["loss/onset"]).mean()

        return predictions, losses
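In both run_on_batch implementations, the mask construction is a per-element positive re-weighting of the BCE loss: positive label cells receive weight positive_weight (inv_positive_weight on the last, instrument-pooled N_KEYS columns) and negatives weight 1. Reduced to a single weight, the pattern is:

import torch.nn.functional as F
# pred, label and positive_weight are placeholders; label is 0/1
loss = F.binary_cross_entropy(pred, label, reduction="none")
weight = 1 + (positive_weight - 1) * label
loss = (weight * loss).mean()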
onsets_and_frames/utils.py
ADDED
@@ -0,0 +1,245 @@
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F

from onsets_and_frames.constants import (
    DTW_FACTOR,
    HOP_LENGTH,
    MAX_MIDI,
    MIN_MIDI,
    N_KEYS,
)


def cycle(iterable):
    while True:
        for item in iterable:
            yield item


def shift_label(label, shift):
    if shift == 0:
        return label
    assert len(label.shape) == 2
    t, p = label.shape
    keys, instruments = N_KEYS, p // N_KEYS
    label_zero_pad = torch.zeros(t, instruments, abs(shift), dtype=label.dtype)
    label = label.reshape(t, instruments, keys)
    to_cat = (
        (label_zero_pad, label[:, :, :-shift])
        if shift > 0
        else (label[:, :, -shift:], label_zero_pad)
    )
    label = torch.cat(to_cat, dim=-1)
    return label.reshape(t, p)
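A sketch of shift_label as pitch-shift augmentation on a (frames, 88 * channels) label roll; each channel's 88-key block shifts independently and pads with zeros at the edge:

import torch
label = torch.zeros(100, 88 * 2, dtype=torch.uint8)
label[10:20, 40] = 2  # a held note on key 40, channel 0
shifted = shift_label(label, 1)  # one semitone up
assert (shifted[10:20, 41] == 2).all() and (shifted[10:20, 40] == 0).all()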
def get_peaks(notes, win_size, gpu=False):
|
| 41 |
+
constraints = []
|
| 42 |
+
notes = notes.cpu()
|
| 43 |
+
for i in range(1, win_size + 1):
|
| 44 |
+
forward = torch.roll(notes, i, 0)
|
| 45 |
+
forward[:i, ...] = 0 # assume time axis is 0
|
| 46 |
+
backward = torch.roll(notes, -i, 0)
|
| 47 |
+
backward[-i:, ...] = 0
|
| 48 |
+
constraints.extend([forward, backward])
|
| 49 |
+
res = torch.ones(notes.shape, dtype=bool)
|
| 50 |
+
for elem in constraints:
|
| 51 |
+
res = res & (notes >= elem)
|
| 52 |
+
return res if not gpu else res.cuda()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_peaks_numpy(notes, win_size):
|
| 56 |
+
"""
|
| 57 |
+
Detect peaks in a NumPy array based on a window size.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
notes (np.ndarray): Input array, shape (frames, ...).
|
| 61 |
+
win_size (int): Window size for detecting peaks.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
np.ndarray: Boolean array indicating peaks, same shape as `notes`.
|
| 65 |
+
"""
|
| 66 |
+
# Initialize constraints
|
| 67 |
+
constraints = []
|
| 68 |
+
notes = np.array(notes) # Ensure input is a NumPy array
|
| 69 |
+
|
| 70 |
+
for i in range(1, win_size + 1):
|
| 71 |
+
# Roll array forward and backward
|
| 72 |
+
forward = np.roll(notes, i, axis=0)
|
| 73 |
+
backward = np.roll(notes, -i, axis=0)
|
| 74 |
+
|
| 75 |
+
# Zero out invalid regions
|
| 76 |
+
forward[:i, ...] = 0
|
| 77 |
+
backward[-i:, ...] = 0
|
| 78 |
+
|
| 79 |
+
constraints.extend([forward, backward])
|
| 80 |
+
|
| 81 |
+
# Initialize result with all True
|
| 82 |
+
res = np.ones_like(notes, dtype=bool)
|
| 83 |
+
|
| 84 |
+
# Apply constraints
|
| 85 |
+
for elem in constraints:
|
| 86 |
+
res &= notes >= elem
|
| 87 |
+
|
| 88 |
+
return res
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_diff(notes, offset=True):
|
| 92 |
+
rolled = np.roll(notes, 1, axis=0)
|
| 93 |
+
rolled[0, ...] = 0
|
| 94 |
+
return (rolled & (~notes)) if offset else (notes & (~rolled))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def compress_across_octave(notes):
|
| 98 |
+
keys = MAX_MIDI - MIN_MIDI + 1
|
| 99 |
+
time, instruments = notes.shape[0], notes.shape[1] // keys
|
| 100 |
+
notes_reshaped = notes.reshape((time, instruments, keys))
|
| 101 |
+
notes_reshaped = notes_reshaped.max(axis=1)
|
| 102 |
+
octaves = keys // 12
|
| 103 |
+
res = np.zeros((time, 12), dtype=np.uint8)
|
| 104 |
+
for i in range(octaves):
|
| 105 |
+
curr_octave = notes_reshaped[:, i * 12 : (i + 1) * 12]
|
| 106 |
+
res = np.maximum(res, curr_octave)
|
| 107 |
+
return res
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def compress_time(notes, factor):
|
| 111 |
+
t, p = notes.shape
|
| 112 |
+
res = np.zeros((t // factor, p), dtype=notes.dtype)
|
| 113 |
+
for i in range(t // factor):
|
| 114 |
+
res[i, :] = notes[i * factor : (i + 1) * factor, :].max(axis=0)
|
| 115 |
+
return res
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_matches(index1, index2):
|
| 119 |
+
matches = {}
|
| 120 |
+
for i1, i2 in zip(index1, index2):
|
| 121 |
+
# matches[i1] = matches.get(i1, []) + [i2]
|
| 122 |
+
if i1 not in matches:
|
| 123 |
+
matches[i1] = []
|
| 124 |
+
matches[i1].append(i2)
|
| 125 |
+
return matches
|
| 126 |
+


def get_margin(
    t_sources, max_len, WINDOW_SIZE_SRC=11 * (512 // HOP_LENGTH) + 2 * DTW_FACTOR
):
    """
    Extend a temporal range to WINDOW_SIZE_SRC frames if it is shorter than that.
    WINDOW_SIZE_SRC defaults to 28 frames for a 256 hop length (assuming
    DTW_FACTOR=3), which is roughly 0.5 seconds.
    """
    margin = max(0, (WINDOW_SIZE_SRC - len(t_sources)) // 2)
    t_sources_left = list(range(max(t_sources[0] - margin, 0), t_sources[0]))
    # Start the right margin one frame past the last source frame, mirroring
    # the left margin, so that t_sources[-1] is not duplicated.
    t_sources_right = list(
        range(t_sources[-1] + 1, min(t_sources[-1] + 1 + margin, max_len))
    )
    t_sources_extended = t_sources_left + t_sources + t_sources_right
    return t_sources_extended
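
# Usage sketch with this repo's constants (HOP_LENGTH=512, DTW_FACTOR=3, so
# WINDOW_SIZE_SRC defaults to 17): a 3-frame range grows to 17 frames.
#
# >>> get_margin([50, 51, 52], max_len=1000)
# [43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]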


def get_inactive_instruments(target_onsets, T):
    """
    Return a (T, instruments * keys) mask marking instruments that never
    play in target_onsets, plus the per-instrument activity vector.
    """
    keys = MAX_MIDI - MIN_MIDI + 1
    time, instruments = target_onsets.shape[0], target_onsets.shape[1] // keys
    notes_reshaped = target_onsets.reshape((time, instruments, keys))
    active_instruments = notes_reshaped.max(axis=(0, 2))
    res = np.zeros((T, instruments, keys), dtype=bool)
    for ins in range(instruments):
        if active_instruments[ins] == 0:
            res[:, ins, :] = 1
    return res.reshape((T, instruments * keys)), active_instruments
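
# Usage sketch (illustrative): instrument 1 never plays, so its block of
# the mask is all True.
#
# >>> onsets = np.zeros((10, 2 * 88), dtype=np.uint8)
# >>> onsets[0, 3] = 1                      # instrument 0 plays once
# >>> mask, active = get_inactive_instruments(onsets, T=10)
# >>> active
# array([1, 0], dtype=uint8)
# >>> mask[:, 88:].all()   # -> True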


def max_inst(probs, threshold_vec=None):
    """
    For each (frame, pitch) above threshold, keep only the most probable
    instrument; the last instrument channel is treated as a pitch-only
    aggregate and is always set alongside the winner.
    """
    if threshold_vec is None:
        threshold_vec = 0.5
    if probs.shape[-1] == N_KEYS or probs.shape[-1] == N_KEYS * 2:
        # There is only pitch information, no instrument channels
        return probs
    keys = MAX_MIDI - MIN_MIDI + 1
    instruments = probs.shape[1] // keys
    time = len(probs)
    probs = probs.reshape((time, instruments, keys))
    notes = probs.max(axis=1) >= threshold_vec
    # Argmax over the real instrument channels, excluding the last one
    max_instruments = np.argmax(probs[:, :-1, :], axis=1)
    res = np.zeros(probs.shape, dtype=np.uint8)
    for t, p in zip(*(notes.nonzero())):
        res[t, max_instruments[t, p], p] = 1
        res[t, -1, p] = 1
    return res.reshape((time, instruments * keys))
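
# Usage sketch (hypothetical shapes): two instrument channels plus the
# final pitch-only channel over 88 keys.
#
# >>> probs = np.zeros((4, 3 * 88))
# >>> probs[0, 0 * 88 + 40] = 0.9   # instrument 0, key 40
# >>> probs[0, 1 * 88 + 40] = 0.3   # instrument 1, same key
# >>> out = max_inst(probs).reshape(4, 3, 88)
# >>> out[0, 0, 40], out[0, 2, 40]  # winning instrument and pitch-only flag
# (1, 1)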


# Define the smoothing function (operates on CPU)
def smooth_labels(onset_tensor):
    """
    Smooth onset labels with a triangular kernel via 1D convolution along the time axis.

    Args:
        onset_tensor (torch.Tensor): A (T, F) tensor where T = time steps and F = pitches.

    Returns:
        torch.Tensor: Smoothed onset tensor with the same shape (T, F).
    """
    # Define the triangular smoothing kernel
    # kernel = torch.tensor([0.2, 0.4, 0.6, 0.8, 1, 0.8, 0.6, 0.4, 0.2],
    #                       dtype=onset_tensor.dtype).view(1, 1, -1)
    # kernel = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
    #                       dtype=onset_tensor.dtype).view(1, 1, -1)
    kernel = torch.tensor([0.33, 0.67, 1, 0.67, 0.33], dtype=onset_tensor.dtype).view(
        1, 1, -1
    )

    # Treat each pitch as an independent 1-channel signal: (T, F) -> (F, 1, T)
    onset_tensor = onset_tensor.T.unsqueeze(1)

    # Use 'same' padding so that the output has the same time dimension as the input.
    padding = kernel.shape[-1] // 2
    smoothed = F.conv1d(onset_tensor, kernel, padding=padding)

    # Reshape back to original shape (T, F)
    return smoothed.squeeze(1).T
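
# Usage sketch (illustrative; assumes `torch.nn.functional` is imported as
# `F`, which this function requires): a single onset spike spreads into a
# small triangle.
#
# >>> onsets = torch.zeros(10, 1)
# >>> onsets[5, 0] = 1.0
# >>> smooth_labels(onsets)[3:8, 0]
# tensor([0.3300, 0.6700, 1.0000, 0.6700, 0.3300])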


def initialize_logging_system(logdir):
    """Initialize the logging system once, with named loggers for train and dataset."""
    log_file = os.path.join(logdir, "training.log")

    # Shared formatter for all handlers
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    # File handler (shared by all loggers)
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    # Console handler (shared by all loggers)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    # Create train logger
    train_logger = logging.getLogger("train")
    train_logger.setLevel(logging.INFO)
    train_logger.handlers.clear()
    train_logger.addHandler(file_handler)
    train_logger.addHandler(console_handler)

    # Create dataset logger
    dataset_logger = logging.getLogger("dataset")
    dataset_logger.setLevel(logging.INFO)
    dataset_logger.handlers.clear()
    dataset_logger.addHandler(file_handler)
    dataset_logger.addHandler(console_handler)

    return train_logger, dataset_logger


def get_logger(name):
    """Get a named logger. Call initialize_logging_system first."""
    return logging.getLogger(name)
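
# Usage sketch ("runs/exp1" is a hypothetical log directory; it must already
# exist so logging.FileHandler can create training.log inside it):
#
# >>> train_logger, dataset_logger = initialize_logging_system("runs/exp1")
# >>> get_logger("train").info("starting training")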