Spaces: Running on Zero
import torch


class Singleton(type):
    """Metaclass that caches and reuses a single shared instance per class."""
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]
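

# A minimal illustration (a sketch added here, not part of the original file):
# any class that uses Singleton as its metaclass hands back one shared instance,
# no matter how many times it is constructed. The _SingletonDemo name is only
# for demonstration.
class _SingletonDemo(metaclass=Singleton):
    pass

assert _SingletonDemo() is _SingletonDemo()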


class STFT(torch.nn.Module, metaclass=Singleton):
    def __init__(self, filter_length=1024, hop_length=512):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_len = hop_length
        self.win_len = filter_length
        self.window = torch.hann_window(self.win_len)
        # Original signal length; set by transform() and used by inverse()
        # to trim the padding added before the STFT.
        self.num_samples = -1

    def transform(self, x):
        # Remember the original length so inverse() can undo the padding.
        self.num_samples = x.shape[1]
        # Right-pad the signal so its length is a multiple of the window length.
        x = torch.nn.functional.pad(x, (0, self.win_len - x.shape[1] % self.win_len))
        fft = torch.stft(x, self.filter_length, self.hop_len, self.win_len,
                         window=self.window.to(x.device), return_complex=True)
        real_part, imag_part = fft.real, fft.imag
        squared = real_part ** 2 + imag_part ** 2
        # Add a tiny epsilon only where the power is exactly zero so sqrt()
        # has a finite gradient there, then subtract its contribution back out.
        additive_epsilon = torch.ones_like(squared) * (squared == 0).float() * 1e-24
        magnitude = torch.sqrt(squared + additive_epsilon) - torch.sqrt(additive_epsilon)
        # atan2 on detached tensors replaces the deprecated torch.autograd.Variable
        # wrapper; the phase stays out of the autograd graph, as in the original.
        phase = torch.atan2(imag_part.detach(), real_part.detach())
        return magnitude, phase

    def inverse(self, magnitude, phase):
        # Rebuild the complex spectrogram from magnitude and phase.
        recombine_magnitude_phase = magnitude * torch.cos(phase) + 1j * magnitude * torch.sin(phase)
        inverse_transform = torch.istft(recombine_magnitude_phase, self.filter_length,
                                        hop_length=self.hop_len, win_length=self.win_len,
                                        window=self.window.to(magnitude.device)).unsqueeze(1)
        # Drop the samples added by the padding in transform() (cf. length=self.num_samples).
        padding = self.win_len - (self.num_samples % self.win_len)
        inverse_transform = inverse_transform[:, :, :-padding]
        return inverse_transform
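

# Usage sketch (an assumed example, not from the original file): round-trip a
# random batch of audio through transform()/inverse() and compare shapes and
# values; the class is constructed once and reused via the Singleton metaclass.
if __name__ == "__main__":
    stft = STFT(filter_length=1024, hop_length=512)

    signal = torch.randn(2, 16000)                  # (batch, num_samples)
    magnitude, phase = stft.transform(signal)       # (batch, n_fft // 2 + 1, frames)
    reconstructed = stft.inverse(magnitude, phase)  # (batch, 1, num_samples)

    print(magnitude.shape, phase.shape, reconstructed.shape)
    print(torch.allclose(signal, reconstructed.squeeze(1), atol=1e-4))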