Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,262 Bytes
05d6e12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import numpy as np
import torch
import torch.nn.functional as F
from librosa.filters import mel
from librosa.util import pad_center
from scipy.signal import get_window
from torch.autograd import Variable
from onsets_and_frames.constants import (
DEFAULT_DEVICE,
HOP_LENGTH,
MEL_FMAX,
MEL_FMIN,
N_MELS,
SAMPLE_RATE,
WINDOW_LENGTH,
)
class STFT(torch.nn.Module):
    """Short-time Fourier transform implemented as a strided 1-D convolution.

    Adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    The DFT matrix of size ``filter_length`` is split into the real and
    imaginary parts of its first ``filter_length // 2 + 1`` rows (the
    non-redundant bins for real input). It is optionally multiplied by a
    centre-zero-padded analysis window and stored as a ``conv1d`` kernel.
    Convolving with stride ``hop_length`` yields one STFT frame per output
    position.

    PARAMS
    ------
    filter_length: FFT size, i.e. the kernel width of the convolution.
    hop_length: number of samples between successive frames.
    win_length: analysis window length; defaults to ``filter_length``.
        Must not exceed ``filter_length`` when a window is used.
    window: window spec passed to ``scipy.signal.get_window``, or ``None``
        for an unwindowed (rectangular) basis.
    """

    def __init__(self, filter_length, hop_length, win_length=None, window="hann"):
        super().__init__()
        if win_length is None:
            win_length = filter_length
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        # Not used inside this module; kept so external readers of the
        # attribute keep working.
        self.forward_transform = None

        # Number of non-redundant frequency bins for a real-valued input.
        cutoff = filter_length // 2 + 1
        fourier_basis = np.fft.fft(np.eye(filter_length))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        # Shape (2 * cutoff, 1, filter_length): conv1d weight layout
        # (out_channels, in_channels, kernel_size).
        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])

        if window is not None:
            assert filter_length >= win_length
            # Build the window, then zero-pad it symmetrically to filter_length.
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()
            # Apply the window to every basis row (broadcast over the last axis).
            forward_basis *= fft_window

        # A buffer follows .to(device) but is not a trainable parameter.
        self.register_buffer("forward_basis", forward_basis.float())

    def forward(self, input_data):
        """Compute the magnitude and phase STFT of a batch of signals.

        PARAMS
        ------
        input_data: tensor of shape (batch, num_samples)

        RETURNS
        -------
        (magnitude, phase), each of shape
        (batch, filter_length // 2 + 1, num_frames). ``phase`` is detached
        from the autograd graph; ``magnitude`` is not.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)
        half = self.filter_length // 2

        # Like librosa's centred STFT, reflect-pad by half a frame on each
        # side so that frame t is centred on sample t * hop_length.
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (half, half, 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)

        # The basis is a registered buffer with requires_grad=False, so it is
        # passed directly. Wrapping it in Variable has been a no-op since
        # PyTorch 0.4.
        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0,
        )

        cutoff = half + 1
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        # Phase is computed outside autograd, as the original `.data` access did.
        phase = torch.atan2(imag_part.detach(), real_part.detach())
        return magnitude, phase
class MelSpectrogram(torch.nn.Module):
    """Log-mel spectrogram front end: STFT magnitudes projected onto a mel filter bank.

    PARAMS
    ------
    n_mels: number of mel bands
    sample_rate: audio sample rate in Hz
    filter_length: FFT size, used for both the STFT and the mel filter bank
    hop_length: number of samples between successive frames
    win_length: analysis window length; ``None`` means ``filter_length``
    mel_fmin: lowest filter-bank frequency in Hz
    mel_fmax: highest filter-bank frequency in Hz, passed straight through
        to ``librosa.filters.mel``
    """

    def __init__(
        self,
        n_mels,
        sample_rate,
        filter_length,
        hop_length,
        win_length=None,
        mel_fmin=0.0,
        mel_fmax=None,
    ):
        super().__init__()
        self.stft = STFT(filter_length, hop_length, win_length)
        # Filter bank of shape (n_mels, filter_length // 2 + 1) on the HTK mel scale.
        mel_basis = mel(
            sr=sample_rate,
            n_fft=filter_length,
            n_mels=n_mels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        # A buffer moves with .to(device) but is not trained.
        self.register_buffer("mel_basis", mel_basis)

    def forward(self, y):
        """Compute log-mel spectrograms from a batch of waveforms.

        PARAMS
        ------
        y: torch.FloatTensor with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mels, n_frames) holding
            natural-log mel magnitudes, floored at log(1e-5). No gradient
            flows back to ``y``.
        """
        assert torch.min(y.detach()) >= -1, "input audio must lie in [-1, 1]"
        assert torch.max(y.detach()) <= 1, "input audio must lie in [-1, 1]"

        # The output is detached anyway, so skip building the STFT graph.
        # The phase is not needed for a magnitude spectrogram.
        with torch.no_grad():
            magnitudes, _ = self.stft(y)

        # (n_mels, n_freq) @ (B, n_freq, n_frames) -> (B, n_mels, n_frames)
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        # Floor before the log so silent frames do not produce -inf.
        return torch.log(torch.clamp(mel_output, min=1e-5))
# Project-wide default log-mel converter, built from the shared constants and
# moved to DEFAULT_DEVICE at import time. WINDOW_LENGTH is passed as the FFT
# size (filter_length); win_length is left at its default, which is that value.
melspectrogram = MelSpectrogram(
    N_MELS,
    SAMPLE_RATE,
    WINDOW_LENGTH,
    HOP_LENGTH,
    mel_fmin=MEL_FMIN,
    mel_fmax=MEL_FMAX,
).to(DEFAULT_DEVICE)
|