Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from librosa.filters import mel | |
| from librosa.util import pad_center | |
| from scipy.signal import get_window | |
| from torch.autograd import Variable | |
| from onsets_and_frames.constants import ( | |
| DEFAULT_DEVICE, | |
| HOP_LENGTH, | |
| MEL_FMAX, | |
| MEL_FMIN, | |
| N_MELS, | |
| SAMPLE_RATE, | |
| WINDOW_LENGTH, | |
| ) | |
class STFT(torch.nn.Module):
    """Short-time Fourier transform implemented as one strided 1-D convolution.

    Adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    The real and imaginary parts of the first ``filter_length // 2 + 1`` DFT
    rows are stacked into a fixed convolution kernel (optionally tapered by an
    analysis window that is zero-padded to ``filter_length``), so a single
    ``conv1d`` call produces every frame of the transform.

    Args:
        filter_length: FFT size; also the width of the convolution kernel.
        hop_length: stride, in samples, between consecutive frames.
        win_length: analysis window length; ``None`` means ``filter_length``.
            Must not exceed ``filter_length`` when a window is used.
        window: window spec for ``scipy.signal.get_window``, or ``None`` to
            leave the Fourier basis untapered.
    """

    def __init__(self, filter_length, hop_length, win_length=None, window="hann"):
        super().__init__()
        if win_length is None:
            win_length = filter_length
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        # Kept for attribute compatibility; nothing in this class populates it.
        self.forward_transform = None

        n_bins = filter_length // 2 + 1
        dft_rows = np.fft.fft(np.eye(filter_length))[:n_bins]
        # Output channels 0..n_bins-1 give the real part, n_bins.. the imaginary part.
        kernel = np.concatenate((dft_rows.real, dft_rows.imag), axis=0)
        basis = torch.tensor(kernel[:, np.newaxis, :], dtype=torch.float32)

        if window is not None:
            assert filter_length >= win_length
            # Window of win_length samples, zero-centred-padded to filter_length.
            taper = pad_center(get_window(window, win_length, fftbins=True), size=filter_length)
            basis = basis * torch.from_numpy(taper).float()

        self.register_buffer("forward_basis", basis)

    def forward(self, input_data):
        """Return ``(magnitude, phase)`` of the STFT of a batch of signals.

        The signal is reflect-padded by ``filter_length // 2`` samples on each
        side (similar to librosa's centred framing) before framing.

        Args:
            input_data: tensor viewable as ``(batch, samples)``.

        Returns:
            Tuple ``(magnitude, phase)``, each shaped
            ``(batch, filter_length // 2 + 1, n_frames)``. ``phase`` is computed
            from detached inputs, so no gradient flows through it.
        """
        batch_size, n_samples = input_data.size(0), input_data.size(1)
        half = self.filter_length // 2

        # Reflect padding is applied to a 4-D view (pad only the last axis),
        # then the temporary axis is squeezed back out.
        padded = F.pad(
            input_data.view(batch_size, 1, n_samples).unsqueeze(1),
            (half, half, 0, 0),
            mode="reflect",
        ).squeeze(1)

        spectrum = F.conv1d(padded, self.forward_basis, stride=self.hop_length, padding=0)

        n_bins = self.filter_length // 2 + 1
        real_part = spectrum[:, :n_bins, :]
        imag_part = spectrum[:, n_bins:, :]
        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        phase = torch.atan2(imag_part.detach(), real_part.detach())
        return magnitude, phase
class MelSpectrogram(torch.nn.Module):
    """Log mel spectrogram built on the convolutional ``STFT`` module.

    Magnitudes from ``STFT`` are projected onto an HTK-style mel filter bank
    from ``librosa.filters.mel`` and then log-compressed.

    Args:
        n_mels: number of mel bands.
        sample_rate: sampling rate of the input audio, in Hz.
        filter_length: FFT size, used both for ``STFT`` and for the filter bank.
        hop_length: hop size, in samples, between frames.
        win_length: analysis window length; ``None`` means ``filter_length``.
        mel_fmin: lowest filter-bank frequency, in Hz.
        mel_fmax: highest filter-bank frequency, in Hz; ``None`` defers to
            librosa's default (``sample_rate / 2``).
    """

    def __init__(
        self,
        n_mels,
        sample_rate,
        filter_length,
        hop_length,
        win_length=None,
        mel_fmin=0.0,
        mel_fmax=None,
    ):
        super(MelSpectrogram, self).__init__()
        self.stft = STFT(filter_length, hop_length, win_length)
        mel_basis = mel(
            sr=sample_rate,
            n_fft=filter_length,
            n_mels=n_mels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        # Filter bank of shape (n_mels, filter_length // 2 + 1), stored as a
        # buffer so it follows the module across .to(device) calls.
        self.register_buffer("mel_basis", torch.from_numpy(mel_basis).float())

    def forward(self, y):
        """Compute log mel spectrograms from a batch of waveforms.

        Args:
            y: float tensor of shape ``(B, T)`` with samples in ``[-1, 1]``.

        Returns:
            Tensor of shape ``(B, n_mels, n_frames)`` (mel bands before
            frames) containing ``log(max(mel_energy, 1e-5))``. The STFT
            magnitudes are detached, so no gradient flows back to ``y``.

        Raises:
            AssertionError: if any sample of ``y`` lies outside ``[-1, 1]``.
        """
        assert torch.min(y.data) >= -1, "input waveform has samples below -1"
        assert torch.max(y.data) <= 1, "input waveform has samples above 1"
        magnitudes, _ = self.stft(y)
        magnitudes = magnitudes.detach()
        # (n_mels, n_freq) @ (B, n_freq, n_frames) -> (B, n_mels, n_frames)
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        # Clamp before the log so silent bins map to log(1e-5) instead of -inf.
        return torch.log(torch.clamp(mel_output, min=1e-5))
# Project-wide default mel-spectrogram converter, created once at import time
# from the constants module and moved to DEFAULT_DEVICE. WINDOW_LENGTH serves
# as the FFT size; the analysis window length defaults to the same value.
melspectrogram = MelSpectrogram(
    n_mels=N_MELS,
    sample_rate=SAMPLE_RATE,
    filter_length=WINDOW_LENGTH,
    hop_length=HOP_LENGTH,
    mel_fmin=MEL_FMIN,
    mel_fmax=MEL_FMAX,
).to(DEFAULT_DEVICE)