import os
from datetime import datetime
import mido
import numpy as np
import torch
from mido import Message, MidiFile, MidiTrack
from onsets_and_frames.constants import (
DRUM_CHANNEL,
HOP_LENGTH,
HOPS_IN_OFFSET,
HOPS_IN_ONSET,
MAX_MIDI,
MIN_MIDI,
N_KEYS,
SAMPLE_RATE,
)
from .utils import max_inst
def midi_to_hz(m):
return 440.0 * (2.0 ** ((m - 69.0) / 12.0))
def hz_to_midi(h):
return 12.0 * np.log2(h / (440.0)) + 69.0
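# A quick worked example for the two conversions above (illustrative only, not part of
# the original module): MIDI note 69 is concert A at 440 Hz, and the two functions are
# inverses of each other.
#   midi_to_hz(69)              -> 440.0
#   midi_to_hz(81)              -> 880.0   (12 semitones up doubles the frequency)
#   hz_to_midi(midi_to_hz(60))  -> 60.0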
def midi_to_frames(midi, instruments, conversion_map=None):
n_keys = MAX_MIDI - MIN_MIDI + 1
midi_length = int((max(midi[:, 1]) + 1) * SAMPLE_RATE)
n_steps = (midi_length - 1) // HOP_LENGTH + 1
n_channels = len(instruments) + 1
label = torch.zeros(n_steps, n_keys * n_channels, dtype=torch.uint8)
for onset, offset, note, vel, instrument in midi:
f = int(note) - MIN_MIDI
        # skip notes from synth pad / synth FX programs (88-103) and from
        # percussive / sound-effect programs (> 111)
        if 87 < instrument < 104 or instrument > 111:
            continue
if f >= n_keys or f < 0:
continue
assert 0 < vel < 128
instrument = int(instrument)
if conversion_map is not None:
if instrument not in conversion_map:
continue
instrument = conversion_map[instrument]
left = int(round(onset * SAMPLE_RATE / HOP_LENGTH))
onset_right = min(n_steps, left + HOPS_IN_ONSET)
frame_right = int(round(offset * SAMPLE_RATE / HOP_LENGTH))
frame_right = min(n_steps, frame_right)
offset_right = min(n_steps, frame_right + HOPS_IN_OFFSET)
if int(instrument) not in instruments:
continue
chan = instruments.index(int(instrument))
label[left:onset_right, n_keys * chan + f] = 3
label[onset_right:frame_right, n_keys * chan + f] = 2
label[frame_right:offset_right, n_keys * chan + f] = 1
inv_chan = len(instruments)
label[left:onset_right, n_keys * inv_chan + f] = 3
label[onset_right:frame_right, n_keys * inv_chan + f] = 2
label[frame_right:offset_right, n_keys * inv_chan + f] = 1
return label
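# Illustrative sketch of the label encoding produced by midi_to_frames (inferred from
# the code above, not from separate documentation): for a note whose onset and offset
# land on frames 10 and 20, with HOPS_IN_ONSET == 1 and HOPS_IN_OFFSET == 1, the key's
# column in the note's instrument channel holds
#   3 at frame 10         (onset)
#   2 at frames 11-19     (sustained frame)
#   1 at frame 20         (offset)
#   0 elsewhere           (inactive)
# and the same pattern is duplicated in the final, instrument-agnostic channel.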
"""
Convert piano roll to list of notes, pitch only.
"""
def extract_notes_np_pitch(
onsets, frames, velocity, onset_threshold=0.5, frame_threshold=0.5
):
onsets = (onsets > onset_threshold).astype(np.uint8)
frames = (frames > frame_threshold).astype(np.uint8)
onset_diff = (
np.concatenate([onsets[:1, :], onsets[1:, :] - onsets[:-1, :]], axis=0) == 1
)
pitches = []
intervals = []
velocities = []
for nonzero in np.transpose(np.nonzero(onset_diff)):
frame = nonzero[0].item()
pitch = nonzero[1].item()
onset = frame
offset = frame
velocity_samples = []
while onsets[offset, pitch] or frames[offset, pitch]:
if onsets[offset, pitch]:
velocity_samples.append(velocity[offset, pitch])
offset += 1
if offset == onsets.shape[0]:
break
if offset > onset:
pitches.append(pitch)
intervals.append([onset, offset])
velocities.append(
np.mean(velocity_samples) if len(velocity_samples) > 0 else 0
)
return np.array(pitches), np.array(intervals), np.array(velocities)
def extract_notes_np_rescaled(
onsets, frames, velocity, onset_threshold=0.5, frame_threshold=0.5
):
pitches, intervals, velocities, instruments = extract_notes_np(
onsets, frames, velocity, onset_threshold, frame_threshold
)
pitches += MIN_MIDI
scaling = HOP_LENGTH / SAMPLE_RATE
intervals = (intervals * scaling).reshape(-1, 2)
return pitches, intervals, velocities, instruments
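# The scaling above converts frame indices to seconds. As a hedged example, with the
# common onsets-and-frames settings SAMPLE_RATE = 16000 and HOP_LENGTH = 512 (assumed
# values; see onsets_and_frames.constants for the actual ones), one frame spans
# 512 / 16000 = 0.032 s, so a note covering frames [100, 150] maps to the interval
# [3.2 s, 4.8 s].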
"""
Convert piano roll to list of notes, pitch and instrument.
"""
def extract_notes_np(
onsets,
frames,
velocity,
onset_threshold=0.5,
frame_threshold=0.5,
onset_threshold_vec=None,
):
if onset_threshold_vec is not None:
onsets = (onsets > np.array(onset_threshold_vec)).astype(np.uint8)
else:
onsets = (onsets > onset_threshold).astype(np.uint8)
frames = (frames > frame_threshold).astype(np.uint8)
onset_diff = (
np.concatenate([onsets[:1, :], onsets[1:, :] - onsets[:-1, :]], axis=0) == 1
)
if onsets.shape[-1] != frames.shape[-1]:
num_instruments = onsets.shape[1] / frames.shape[1]
assert num_instruments.is_integer()
num_instruments = int(num_instruments)
frames = np.tile(frames, (1, num_instruments))
pitches = []
intervals = []
velocities = []
instruments = []
for nonzero in np.transpose(np.nonzero(onset_diff)):
frame = nonzero[0].item()
pitch = nonzero[1].item()
onset = frame
offset = frame
velocity_samples = []
while onsets[offset, pitch] or frames[offset, pitch]:
if onsets[offset, pitch]:
velocity_samples.append(velocity[offset, pitch])
offset += 1
if offset == onsets.shape[0]:
break
if offset > onset:
pitch, instrument = pitch % N_KEYS, pitch // N_KEYS
pitches.append(pitch)
intervals.append([onset, offset])
velocities.append(
np.mean(velocity_samples) if len(velocity_samples) > 0 else 0
)
instruments.append(instrument)
return (
np.array(pitches),
np.array(intervals),
np.array(velocities),
np.array(instruments),
)
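# Hedged usage sketch for extract_notes_np (variable names and shapes are assumptions
# based on the code above): given onset/velocity matrices of shape
# (T, N_KEYS * n_channels) and a frame matrix of shape (T, N_KEYS) or
# (T, N_KEYS * n_channels),
#
#   pitches, intervals, velocities, instruments = extract_notes_np(
#       onsets, frames, velocity, onset_threshold=0.5, frame_threshold=0.5
#   )
#
# returns pitches as key indices in [0, N_KEYS), intervals as (onset, offset) frame
# index pairs, velocities as the mean velocity over each note's onset frames, and
# instruments as channel indices (column // N_KEYS). A frame matrix narrower than the
# onsets is tiled across instruments.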
def append_track_multi(file, pitches, intervals, velocities, ins, single_ins=False):
track = MidiTrack()
file.tracks.append(track)
chan = len(file.tracks) - 1
if chan >= DRUM_CHANNEL:
chan += 1
    if chan > 15:
        print(f"invalid MIDI channel {chan}, clamping to 15")
        chan = 15
track.append(
Message(
"program_change", channel=chan, program=ins if not single_ins else 0, time=0
)
)
    # mido's default tempo is 120 BPM (500000 microseconds per beat), so
    # ticks_per_beat * 2 converts note times in seconds to ticks
    ticks_per_second = file.ticks_per_beat * 2.0
events = []
for i in range(len(pitches)):
events.append(
dict(
type="on",
pitch=pitches[i],
time=intervals[i][0],
velocity=velocities[i],
)
)
events.append(
dict(
type="off",
pitch=pitches[i],
time=intervals[i][1],
velocity=velocities[i],
)
)
events.sort(key=lambda row: row["time"])
last_tick = 0
for event in events:
current_tick = int(event["time"] * ticks_per_second)
        # mido requires velocity in [0, 127]; clamp both ends so out-of-range
        # estimates do not raise
        velocity = min(127, max(0, int(event["velocity"] * 127)))
pitch = int(round(hz_to_midi(event["pitch"])))
track.append(
Message(
"note_" + event["type"],
channel=chan,
note=pitch,
velocity=velocity,
time=current_tick - last_tick,
)
)
last_tick = current_tick
def append_track(file, pitches, intervals, velocities):
track = MidiTrack()
file.tracks.append(track)
    # seconds -> ticks, assuming mido's default 120 BPM tempo
    ticks_per_second = file.ticks_per_beat * 2.0
events = []
for i in range(len(pitches)):
events.append(
dict(
type="on",
pitch=pitches[i],
time=intervals[i][0],
velocity=velocities[i],
)
)
events.append(
dict(
type="off",
pitch=pitches[i],
time=intervals[i][1],
velocity=velocities[i],
)
)
events.sort(key=lambda row: row["time"])
last_tick = 0
for event in events:
current_tick = int(event["time"] * ticks_per_second)
velocity = int(event["velocity"] * 127)
if velocity > 127:
velocity = 127
pitch = int(round(hz_to_midi(event["pitch"])))
try:
track.append(
Message(
"note_" + event["type"],
note=pitch,
velocity=velocity,
time=current_tick - last_tick,
)
)
except Exception as e:
print(
"Err Message",
"note_" + event["type"],
pitch,
velocity,
current_tick - last_tick,
)
track.append(
Message(
"note_" + event["type"],
note=pitch,
velocity=max(0, velocity),
time=current_tick - last_tick,
)
)
if velocity >= 0:
raise e
last_tick = current_tick
def save_midi(path, pitches, intervals, velocities, insts=None):
"""
Save extracted notes as a MIDI file
Parameters
----------
path: the path to save the MIDI file
pitches: np.ndarray of bin_indices
intervals: list of (onset_index, offset_index)
velocities: list of velocity values
"""
file = MidiFile()
if isinstance(pitches, list):
for p, i, v, ins in zip(pitches, intervals, velocities, insts):
append_track_multi(file, p, i, v, ins)
else:
append_track(file, pitches, intervals, velocities)
file.save(path)
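# Hedged usage sketch for save_midi (argument values are illustrative assumptions):
# pitches are frequencies in Hz, intervals are (onset, offset) pairs in seconds, and
# velocities are scaled by 127 and clamped when written.
#
#   save_midi(
#       "out.mid",
#       midi_to_hz(np.array([60, 64, 67])),
#       np.array([[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]]),
#       np.array([0.5, 0.5, 0.5]),
#   )
#
# Passing lists of per-instrument arrays (together with insts) writes one track per
# instrument via append_track_multi instead.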
def frames2midi(
save_path,
onsets,
frames,
vels,
onset_threshold=0.5,
frame_threshold=0.5,
scaling=HOP_LENGTH / SAMPLE_RATE,
inst_mapping=None,
onset_threshold_vec=None,
):
p_est, i_est, v_est, inst_est = extract_notes_np(
onsets,
frames,
vels,
onset_threshold,
frame_threshold,
onset_threshold_vec=onset_threshold_vec,
)
i_est = (i_est * scaling).reshape(-1, 2)
p_est = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_est])
    inst_set = sorted(set(inst_est))
p_est_lst = {}
i_est_lst = {}
v_est_lst = {}
assert len(p_est) == len(i_est) == len(v_est) == len(inst_est)
for p, i, v, ins in zip(p_est, i_est, v_est, inst_est):
if ins in p_est_lst:
p_est_lst[ins].append(p)
else:
p_est_lst[ins] = [p]
if ins in i_est_lst:
i_est_lst[ins].append(i)
else:
i_est_lst[ins] = [i]
if ins in v_est_lst:
v_est_lst[ins].append(v)
else:
v_est_lst[ins] = [v]
for elem in [p_est_lst, i_est_lst, v_est_lst]:
for k, v in elem.items():
elem[k] = np.array(v)
inst_set = [e for e in inst_set if e in p_est_lst]
# inst_set = [INSTRUMENT_MAPPING[e] for e in inst_set if e in p_est_lst]
p_est_lst = [p_est_lst[ins] for ins in inst_set if ins in p_est_lst]
i_est_lst = [i_est_lst[ins] for ins in inst_set if ins in i_est_lst]
v_est_lst = [v_est_lst[ins] for ins in inst_set if ins in v_est_lst]
assert len(p_est_lst) == len(i_est_lst) == len(v_est_lst) == len(inst_set)
inst_set = [inst_mapping[e] for e in inst_set]
save_midi(save_path, p_est_lst, i_est_lst, v_est_lst, inst_set)
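# Hedged usage sketch for frames2midi (variable names are assumptions): given model
# outputs with one group of N_KEYS columns per instrument and values in [0, 1],
#
#   frames2midi(
#       "song_pred.mid", onset_pred, frame_pred, onset_pred,
#       onset_threshold=0.5, frame_threshold=0.5,
#       inst_mapping=[0, 40, 41],   # e.g. piano, violin, viola (0-based GM programs)
#   )
#
# where inst_mapping translates each key-group index into the General MIDI program
# number written to the corresponding track.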
def frames2midi_pitch(
save_path,
onsets,
frames,
vels,
onset_threshold=0.5,
frame_threshold=0.5,
scaling=HOP_LENGTH / SAMPLE_RATE,
):
p_est, i_est, v_est = extract_notes_np_pitch(
onsets, frames, vels, onset_threshold, frame_threshold
)
i_est = (i_est * scaling).reshape(-1, 2)
p_est = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_est])
print("Saving midi in", save_path)
save_midi(save_path, p_est, i_est, v_est)
def parse_midi_multi(path, force_instrument=None):
"""open midi file and return np.array of (onset, offset, note, velocity, instrument) rows"""
try:
midi = mido.MidiFile(path)
    except Exception as e:
        print("could not open midi", path, e)
        return None
time = 0
events = []
control_changes = []
program_changes = []
sustain = {}
all_channels = set()
instruments = {} # mapping of channel: instrument
for message in midi:
time += message.time
if hasattr(message, "channel"):
if message.channel == DRUM_CHANNEL:
continue
if (
message.type == "control_change"
and message.control == 64
and (message.value >= 64) != sustain.get(message.channel, False)
):
sustain[message.channel] = message.value >= 64
event_type = "sustain_on" if sustain[message.channel] else "sustain_off"
event = dict(
index=len(events), time=time, type=event_type, note=None, velocity=0
)
event["channel"] = message.channel
event["sustain"] = sustain[message.channel]
events.append(event)
if message.type == "control_change" and message.control != 64:
control_changes.append(
(time, message.control, message.value, message.channel)
)
if message.type == "program_change":
program_changes.append((time, message.program, message.channel))
instruments[message.channel] = instruments.get(message.channel, []) + [
(message.program, time)
]
if "note" in message.type:
# MIDI offsets can be either 'note_off' events or 'note_on' with zero velocity
velocity = message.velocity if message.type == "note_on" else 0
event = dict(
index=len(events),
time=time,
type="note",
note=message.note,
velocity=velocity,
sustain=sustain.get(message.channel, False),
)
event["channel"] = message.channel
events.append(event)
if hasattr(message, "channel"):
all_channels.add(message.channel)
if len(instruments) == 0:
instruments = {c: [(0, 0)] for c in all_channels}
if len(all_channels) > len(instruments):
for e in all_channels - set(instruments.keys()):
instruments[e] = [(0, 0)]
if force_instrument is not None:
instruments = {c: [(force_instrument, 0)] for c in all_channels}
this_instruments = set()
for v in instruments.values():
this_instruments = this_instruments.union(set(x[0] for x in v))
notes = []
for i, onset in enumerate(events):
if onset["velocity"] == 0:
continue
offset = next(
n
for n in events[i + 1 :]
if (n["note"] == onset["note"] and n["channel"] == onset["channel"])
or n is events[-1]
)
if "sustain" not in offset:
print("offset without sustain", offset)
if offset["sustain"] and offset is not events[-1]:
# if the sustain pedal is active at offset, find when the sustain ends
offset = next(
n
for n in events[offset["index"] + 1 :]
if (n["type"] == "sustain_off" and n["channel"] == onset["channel"])
or n is events[-1]
)
for k, v in instruments.items():
if len(set(v)) == 1 and len(v) > 1:
instruments[k] = list(set(v))
for k, v in instruments.items():
instruments[k] = sorted(v, key=lambda x: x[1])
if len(instruments[onset["channel"]]) == 1:
instrument = instruments[onset["channel"]][0][0]
else:
ind = 0
while (
ind < len(instruments[onset["channel"]])
and onset["time"] >= instruments[onset["channel"]][ind][1]
):
ind += 1
if ind > 0:
ind -= 1
instrument = instruments[onset["channel"]][ind][0]
if onset["channel"] == DRUM_CHANNEL:
print("skipping drum note")
continue
note = (
onset["time"],
offset["time"],
onset["note"],
onset["velocity"],
instrument,
)
notes.append(note)
res = np.array(notes)
return res
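# Hedged usage sketch for parse_midi_multi (the file name is an assumption): the
# returned array has one row per note,
#
#   notes = parse_midi_multi("performance.mid")
#   onsets, offsets = notes[:, 0], notes[:, 1]   # note times in seconds
#   pitches = notes[:, 2]                        # MIDI note numbers
#   velocities = notes[:, 3]                     # onset velocities (1-127)
#   programs = notes[:, 4]                       # per-note instrument program
#
# and midi_to_frames above consumes exactly this row layout.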
def save_midi_alignments_and_predictions(
save_path,
data_path,
inst_mapping,
aligned_onsets,
aligned_frames,
onset_pred_np,
frame_pred_np,
prefix="",
use_time=True,
group=None,
):
inst_only = len(inst_mapping) * N_KEYS
time_now = datetime.now().strftime("%y%m%d-%H%M%S") if use_time else ""
if len(prefix) > 0:
prefix = "_{}".format(prefix)
    # Save the aligned label. When training on a small dataset or a single performance
    # in order to label it and later add it to a larger dataset, it is recommended to
    # use this MIDI file as the label.
frames2midi(
save_path
+ os.sep
+ data_path.replace(".flac", "").split(os.sep)[-1]
+ prefix
+ "_alignment_"
+ time_now
+ ".mid",
aligned_onsets[:, :inst_only],
aligned_frames[:, :inst_only],
64.0 * aligned_onsets[:, :inst_only],
inst_mapping=inst_mapping,
)
    # NOTE: this early return skips all of the prediction exports below; remove it to
    # also write the prediction-based MIDI files.
    return
# # Aligned label, pitch-only, on the piano.
# frames2midi_pitch(save_path + os.sep + data_path.replace('.flac', '').split(os.sep)[-1] + prefix + '_alignment_pitch_' + time_now + '.mid',
# aligned_onsets[:, -N_KEYS:], aligned_frames[:, -N_KEYS:],
# 64. * aligned_onsets[:, -N_KEYS:])
predicted_onsets = onset_pred_np >= 0.5
predicted_frames = frame_pred_np >= 0.5
# # Raw pitch with instrument prediction - will probably have lower recall, depending on the model's strength.
# frames2midi(save_path + os.sep + data_path.replace('.flac', '').split(os.sep)[-1] + prefix + '_pred_' + time_now + '.mid',
# predicted_onsets[:, : inst_only], predicted_frames[:, : inst_only],
# 64. * predicted_onsets[:, : inst_only],
# inst_mapping=inst_mapping)
# Pitch prediction played on the piano - will have high recall, since it does not differentiate between instruments.
frames2midi_pitch(
save_path
+ os.sep
+ data_path.replace(".flac", "").split(os.sep)[-1]
+ prefix
+ "_pred_pitch_"
+ time_now
+ ".mid",
predicted_onsets[:, -N_KEYS:],
predicted_frames[:, -N_KEYS:],
64.0 * predicted_onsets[:, -N_KEYS:],
)
# Pitch prediction, with choice of most likely instrument for each detected note.
if len(inst_mapping) > 1:
max_pred_onsets = max_inst(onset_pred_np)
frames2midi(
save_path
+ os.sep
+ data_path.replace(".flac", "").split(os.sep)[-1]
+ prefix
+ "_pred_inst_"
+ time_now
+ ".mid",
max_pred_onsets[:, :inst_only],
predicted_frames[:, :inst_only],
64.0 * max_pred_onsets[:, :inst_only],
inst_mapping=inst_mapping,
)
pseudo_onsets = (onset_pred_np >= 0.5) & (~aligned_onsets)
onset_label = np.maximum(pseudo_onsets, aligned_onsets)
pseudo_frames = np.zeros(pseudo_onsets.shape, dtype=pseudo_onsets.dtype)
for t, f in zip(*onset_label.nonzero()):
t_off = t
while t_off < len(pseudo_frames) and frame_pred_np[t_off, f % N_KEYS] >= 0.5:
t_off += 1
pseudo_frames[t:t_off, f] = 1
frame_label = np.maximum(pseudo_frames, aligned_frames)
# pseudo_frames = (frame_pred_np >= 0.5) & (~aligned_frames)
# frame_label = np.maximum(pseudo_frames, aligned_frames)
frames2midi(
save_path
+ os.sep
+ data_path.replace(".flac", "").split(os.sep)[-1]
+ prefix
+ "_pred_align_max_"
+ time_now
+ ".mid",
onset_label[:, :inst_only],
frame_label[:, :inst_only],
64.0 * onset_label[:, :inst_only],
inst_mapping=inst_mapping,
)
# if group is not None:
    #     group_path = os.path.join(save_path, 'pred_alignment_max', group)
    #     file_name = os.path.basename(data_path).replace('.flac', '_pred_align_max.mid')
    #     os.makedirs(group_path, exist_ok=True)
# frames2midi(os.path.join(gorup_path, file_name),
# onset_label[:, : inst_only], frame_label[:, : inst_only],
# 64. * onset_label[:, : inst_only],
# inst_mapping=inst_mapping)
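# Hedged usage sketch for save_midi_alignments_and_predictions (paths and array names
# are assumptions): given aligned label matrices and raw network predictions for one
# recording,
#
#   save_midi_alignments_and_predictions(
#       "results", "data/performance.flac", inst_mapping,
#       aligned_onsets, aligned_frames, onset_pred_np, frame_pred_np,
#       prefix="epoch10",
#   )
#
# writes "performance_epoch10_alignment_<timestamp>.mid" into "results"; note that the
# early return above currently skips the additional prediction exports.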