import torch
import torch.nn.functional as F
from torch import nn

from onsets_and_frames.constants import MAX_MIDI, MIN_MIDI, N_KEYS

from .lstm import BiLSTM
from .mel import melspectrogram


class ConvStack(nn.Module):
    def __init__(self, input_features, output_features):
        super().__init__()

        # input is batch_size * 1 channel * frames * input_features
        self.cnn = nn.Sequential(
            # layer 0
            nn.Conv2d(1, output_features // 16, (3, 3), padding=1),
            nn.BatchNorm2d(output_features // 16),
            nn.ReLU(),
            # layer 1
            nn.Conv2d(output_features // 16, output_features // 16, (3, 3), padding=1),
            nn.BatchNorm2d(output_features // 16),
            nn.ReLU(),
            # layer 2
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.25),
            nn.Conv2d(output_features // 16, output_features // 8, (3, 3), padding=1),
            nn.BatchNorm2d(output_features // 8),
            nn.ReLU(),
            # layer 3
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.25),
        )
        # the two (1, 2) max-pools halve the mel axis twice, hence input_features // 4
        self.fc = nn.Sequential(
            nn.Linear((output_features // 8) * (input_features // 4), output_features),
            nn.Dropout(0.5),
        )

    def forward(self, mel):
        x = mel.view(mel.size(0), 1, mel.size(1), mel.size(2))
        x = self.cnn(x)
        x = x.transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x


class OnsetsAndFrames(nn.Module):
    def __init__(
        self,
        input_features,
        output_features,
        model_complexity=48,
        onset_complexity=1,
        n_instruments=13,
    ):
        nn.Module.__init__(self)
        model_size = model_complexity * 16
        sequence_model = lambda input_size, output_size: BiLSTM(
            input_size, output_size // 2
        )

        onset_model_size = int(onset_complexity * model_size)
        # onset and velocity stacks predict one value per (instrument, pitch) pair
        self.onset_stack = nn.Sequential(
            ConvStack(input_features, onset_model_size),
            sequence_model(onset_model_size, onset_model_size),
            nn.Linear(onset_model_size, output_features * n_instruments),
            nn.Sigmoid(),
        )
        self.offset_stack = nn.Sequential(
            ConvStack(input_features, model_size),
            sequence_model(model_size, model_size),
            nn.Linear(model_size, output_features),
            nn.Sigmoid(),
        )
        self.frame_stack = nn.Sequential(
            ConvStack(input_features, model_size),
            nn.Linear(model_size, output_features),
            nn.Sigmoid(),
        )
        self.combined_stack = nn.Sequential(
            sequence_model(output_features * 3, model_size),
            nn.Linear(model_size, output_features),
            nn.Sigmoid(),
        )
        self.velocity_stack = nn.Sequential(
            ConvStack(input_features, model_size),
            nn.Linear(model_size, output_features * n_instruments),
        )

    def forward(self, mel):
        onset_pred = self.onset_stack(mel)
        offset_pred = self.offset_stack(mel)
        activation_pred = self.frame_stack(mel)

        # reshape the onset output to (..., groups, keys) and take the max over
        # groups, yielding a pitch-only onset estimate for the combined stack
        onset_detached = onset_pred.detach()
        shape = onset_detached.shape
        keys = MAX_MIDI - MIN_MIDI + 1
        new_shape = shape[:-1] + (shape[-1] // keys, keys)
        onset_detached = onset_detached.reshape(new_shape)
        onset_detached, _ = onset_detached.max(dim=-2)

        offset_detached = offset_pred.detach()

        combined_pred = torch.cat(
            [onset_detached, offset_detached, activation_pred], dim=-1
        )
        frame_pred = self.combined_stack(combined_pred)
        velocity_pred = self.velocity_stack(mel)
        return onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred

    def run_on_batch(
        self,
        batch,
        parallel_model=None,
        positive_weight=2.0,
        inv_positive_weight=2.0,
        with_onset_mask=False,
    ):
        audio_label = batch["audio"]
        onset_label = batch["onset"]
        offset_label = batch["offset"]
        frame_label = batch["frame"]
        if "velocity" in batch:
            velocity_label = batch["velocity"]

        mel = melspectrogram(
            audio_label.reshape(-1, audio_label.shape[-1])[:, :-1]
        ).transpose(-1, -2)

        if not parallel_model:
            onset_pred, offset_pred, _, frame_pred, velocity_pred = self(mel)
        else:
            onset_pred, offset_pred, _, frame_pred, velocity_pred = parallel_model(mel)

        predictions = {
            "onset": onset_pred.reshape(*onset_label.shape),
            "offset": offset_pred.reshape(*offset_label.shape),
            "frame": frame_pred.reshape(*frame_label.shape),
            # 'velocity': velocity_pred.reshape(*velocity_label.shape)
        }
        if "velocity" in batch:
            predictions["velocity"] = velocity_pred.reshape(*velocity_label.shape)

        losses = {
            "loss/onset": F.binary_cross_entropy(
                predictions["onset"], onset_label, reduction="none"
            ),
            "loss/offset": F.binary_cross_entropy(
                predictions["offset"], offset_label, reduction="none"
            ),
            "loss/frame": F.binary_cross_entropy(
                predictions["frame"], frame_label, reduction="none"
            ),
            # 'loss/velocity': self.velocity_loss(predictions['velocity'], velocity_label, onset_label)
        }
        if "velocity" in batch:
            losses["loss/velocity"] = self.velocity_loss(
                predictions["velocity"], velocity_label, onset_label
            )

        # emphasize positive targets in the element-wise BCE: columns before the
        # last N_KEYS use positive_weight, the final N_KEYS use inv_positive_weight
        onset_mask = 1.0 * onset_label
        onset_mask[..., :-N_KEYS] *= positive_weight - 1
        onset_mask[..., -N_KEYS:] *= inv_positive_weight - 1
        onset_mask += 1

        if with_onset_mask:
            if "onset_mask" in batch:
                onset_mask = onset_mask * batch["onset_mask"]
        # if 'onset_mask' in batch:
        #     onset_mask += batch['onset_mask']

        offset_mask = 1.0 * offset_label
        offset_positive_weight = 2.0
        offset_mask *= offset_positive_weight - 1
        offset_mask += 1.0

        frame_mask = 1.0 * frame_label
        frame_positive_weight = 2.0
        frame_mask *= frame_positive_weight - 1
        frame_mask += 1.0

        for loss_key, mask in zip(
            ["onset", "offset", "frame"], [onset_mask, offset_mask, frame_mask]
        ):
            losses["loss/" + loss_key] = (mask * losses["loss/" + loss_key]).mean()

        return predictions, losses

    def velocity_loss(self, velocity_pred, velocity_label, onset_label):
        # onset-weighted mean squared error; returns zero when there are no onsets
        denominator = onset_label.sum()
        if denominator.item() == 0:
            return denominator
        else:
            return (
                onset_label * (velocity_label - velocity_pred) ** 2
            ).sum() / denominator


# same implementation as OnsetsAndFrames, but with only the onset stack
class OnsetsNoFrames(nn.Module):
    def __init__(
        self,
        input_features,
        output_features,
        model_complexity=48,
        onset_complexity=1,
        n_instruments=13,
    ):
        nn.Module.__init__(self)
        model_size = model_complexity * 16
        sequence_model = lambda input_size, output_size: BiLSTM(
            input_size, output_size // 2
        )

        onset_model_size = int(onset_complexity * model_size)
        self.onset_stack = nn.Sequential(
            ConvStack(input_features, onset_model_size),
            sequence_model(onset_model_size, onset_model_size),
            nn.Linear(onset_model_size, output_features * n_instruments),
            nn.Sigmoid(),
        )

    def forward(self, mel):
        onset_pred = self.onset_stack(mel)

        # pitch-only onset map (max over instrument groups); kept for parity with
        # OnsetsAndFrames.forward, though only onset_pred is returned here
        onset_detached = onset_pred.detach()
        shape = onset_detached.shape
        keys = MAX_MIDI - MIN_MIDI + 1
        new_shape = shape[:-1] + (shape[-1] // keys, keys)
        onset_detached = onset_detached.reshape(new_shape)
        onset_detached, _ = onset_detached.max(dim=-2)

        return onset_pred

    def run_on_batch(
        self,
        batch,
        parallel_model=None,
        positive_weight=2.0,
        inv_positive_weight=2.0,
        with_onset_mask=False,
    ):
        audio_label = batch["audio"]
        onset_label = batch["onset"]

        mel = melspectrogram(
            audio_label.reshape(-1, audio_label.shape[-1])[:, :-1]
        ).transpose(-1, -2)

        if not parallel_model:
            onset_pred = self(mel)
        else:
            onset_pred = parallel_model(mel)

        predictions = {
            "onset": onset_pred,
        }

        losses = {
            "loss/onset": F.binary_cross_entropy(
                predictions["onset"], onset_label, reduction="none"
            ),
        }

        onset_mask = 1.0 * onset_label
        onset_mask[..., :-N_KEYS] *= positive_weight - 1
        onset_mask[..., -N_KEYS:] *= inv_positive_weight - 1
        onset_mask += 1

        if with_onset_mask:
            if "onset_mask" in batch:
                onset_mask = onset_mask * batch["onset_mask"]

        losses["loss/onset"] = (onset_mask * losses["loss/onset"]).mean()

        return predictions, losses
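

# Minimal usage sketch: assumes 229 mel bins and the pitch range from
# onsets_and_frames.constants, which are typical Onsets and Frames settings;
# substitute the values your mel frontend actually produces. Because of the
# relative imports above, run it as a module (e.g. `python -m <package>.transcriber`),
# not as a standalone script.
if __name__ == "__main__":
    n_mels = 229  # assumed mel-spectrogram bin count
    n_keys = MAX_MIDI - MIN_MIDI + 1  # pitch range defined in constants
    model = OnsetsAndFrames(n_mels, n_keys, model_complexity=48, n_instruments=13)

    # fake batch of 2 clips, 100 frames each, already in mel space
    mel = torch.randn(2, 100, n_mels)
    onset, offset, activation, frame, velocity = model(mel)

    print(onset.shape)  # (2, 100, n_keys * 13): per-instrument onset probabilities
    print(frame.shape)  # (2, 100, n_keys): pitch-only frame activations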