File size: 1,096 Bytes
c60dea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch

class STFT:
    """Callable short-time Fourier transform using a Hann window.

    Wraps ``torch.stft`` with analysis parameters fixed at construction
    time. The result is returned as a real-valued tensor whose trailing
    dimension of size 2 holds the (real, imaginary) parts of each bin.
    """

    def __init__(self, n_fft, hop_length, win_length):
        """Store the analysis parameters and build the Hann window.

        Args:
            n_fft: FFT size forwarded to ``torch.stft``.
            hop_length: Hop between successive frames, forwarded as-is.
            win_length: Window length; also the size of the Hann window.
        """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = torch.hann_window(win_length)

    def __call__(self, y):
        """Compute the STFT of ``y``.

        Args:
            y: Input signal tensor accepted by ``torch.stft``.

        Returns:
            A real tensor equal to ``torch.view_as_real`` of the complex
            spectrogram, i.e. with one extra trailing dimension of size 2.
        """
        # Keep the window on the same device as the input; the moved copy is
        # stored so later calls on that device reuse it.
        self.window = self.window.to(y.device)
        # ``return_complex=False`` is deprecated in torch.stft, so request the
        # complex result and convert it explicitly. The values are identical.
        spectrum = torch.stft(
            y,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            pad_mode='reflect',
            return_complex=True,
        )
        return torch.view_as_real(spectrum)

class iSTFT:
    """Callable inverse short-time Fourier transform using a Hann window.

    Wraps ``torch.istft`` with synthesis parameters fixed at construction
    time. Accepts either a real tensor whose trailing dimension of size 2
    holds (real, imaginary) parts, or an already-complex spectrogram.
    """

    def __init__(self, n_fft, hop_length, win_length):
        """Store the synthesis parameters and build the Hann window.

        Args:
            n_fft: FFT size forwarded to ``torch.istft``.
            hop_length: Hop between successive frames, forwarded as-is.
            win_length: Window length; also the size of the Hann window.
        """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = torch.hann_window(win_length)

    def __call__(self, X, length=None):
        """Reconstruct a time-domain signal from the spectrogram ``X``.

        Args:
            X: Spectrogram, either complex or real with a trailing
                dimension of size 2 (real, imaginary).
            length: Optional number of output samples, forwarded to
                ``torch.istft``. ``None`` keeps its default length.

        Returns:
            The tensor returned by ``torch.istft``.
        """
        # Keep the window on the same device as the input; the moved copy is
        # stored so later calls on that device reuse it.
        self.window = self.window.to(X.device)
        if not torch.is_complex(X):
            # view_as_complex requires the (real, imag) pair to be contiguous
            # in memory, hence the explicit .contiguous().
            X = torch.view_as_complex(X.contiguous())
        return torch.istft(
            X,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            length=length,
        )