|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
|
|
|
from torch.nn.functional import conv1d, conv2d |
|
|
|
|
|
# Make packages under the current working directory importable. Presumably this
# is what lets the lazy ``from main.library.backends.utils import STFT`` inside
# TorchGate.forward resolve.
# NOTE(review): this only works when the process is started from the project
# root, and it is a module-import side effect — confirm this is intended.
sys.path.append(os.getcwd())
|
|
|
|
|
@torch.no_grad()
def temperature_sigmoid(x, x0, temp_coeff):
    """Logistic sigmoid centred at ``x0`` with temperature ``temp_coeff``.

    Computes ``sigmoid((x - x0) / temp_coeff)``. A smaller ``temp_coeff``
    gives a sharper transition around ``x0``.
    """
    scaled = (x - x0) / temp_coeff
    return torch.sigmoid(scaled)
|
|
|
|
|
@torch.no_grad()
def linspace(start, stop, num=50, endpoint=True, **kwargs):
    """NumPy-style ``linspace`` built on ``torch.linspace``.

    With ``endpoint=True`` this is exactly ``torch.linspace(start, stop, num)``.
    With ``endpoint=False`` the interval is sampled with ``num + 1`` points and
    the last one (``stop``) is dropped, matching ``numpy.linspace``.
    Extra keyword arguments (``dtype``, ``device``, ...) are forwarded to
    ``torch.linspace``.
    """
    if endpoint:
        return torch.linspace(start, stop, num, **kwargs)

    samples = torch.linspace(start, stop, num + 1, **kwargs)
    return samples[:-1]
|
|
|
|
|
@torch.no_grad()
def amp_to_db(x, eps=torch.finfo(torch.float32).eps, top_db=40):
    """Convert amplitudes to decibels with a per-row dynamic-range floor.

    Args:
        x: Amplitude tensor. Its magnitude (``abs()``) is used, so signed or
            complex inputs are accepted instead of producing NaN.
        eps: Small constant added before the logarithm to avoid ``log10(0)``.
        top_db: Maximum dynamic range in dB. Values more than ``top_db``
            below the maximum along the last dimension are raised to that
            floor.

    Returns:
        Real tensor of ``20 * log10(|x| + eps)`` values with the same shape
        as ``x``, floored per row at ``row_max - top_db``.
    """
    x_db = 20 * (x.abs() + eps).log10()

    # Per-row floor: max over the last dimension minus top_db, kept
    # broadcastable against x_db.
    floor = (x_db.max(-1).values - top_db).unsqueeze(-1)
    return x_db.max(floor)
|
|
|
|
|
class TorchGate(torch.nn.Module):
    """Spectral-gating noise reduction implemented with PyTorch.

    ``forward`` computes an STFT of the input and estimates a time-frequency
    mask. It optionally smooths the mask with a normalised 2-D triangular
    kernel, multiplies the mask into the STFT and inverts back to the time
    domain.

    Two mask estimators are available:

    * stationary (default): a bin is kept when its dB magnitude exceeds the
      mean plus ``n_std_thresh_stationary`` standard deviations, both
      computed along the last (time) axis of the spectrogram.
    * non-stationary: a soft mask from a temperature sigmoid of how far each
      bin's magnitude rises above a moving average along the time axis.

    Args:
        sr: Sample rate of the input in Hz.
        nonstationary: Use the non-stationary estimator instead of the
            stationary one.
        n_std_thresh_stationary: Standard deviations above the mean a bin must
            exceed (stationary mode).
        n_thresh_nonstationary: Sigmoid centre for the relative-excess ratio
            (non-stationary mode).
        temp_coeff_nonstationary: Sigmoid temperature (non-stationary mode).
        n_movemean_nonstationary: Moving-average length in STFT frames
            (non-stationary mode).
        prop_decrease: Fraction of the mask to apply, in ``[0, 1]``. At 0 the
            mask is all ones; at 1 the raw mask is used.
        n_fft: FFT size.
        win_length: Window length; defaults to ``n_fft``.
        hop_length: Hop length; defaults to ``win_length // 4``.
        freq_mask_smooth_hz: Frequency extent of mask smoothing, or None.
        time_mask_smooth_ms: Time extent of mask smoothing, or None.
    """

    @torch.no_grad()
    def __init__(
        self,
        sr,
        nonstationary=False,
        n_std_thresh_stationary=1.5,
        n_thresh_nonstationary=1.3,
        temp_coeff_nonstationary=0.1,
        n_movemean_nonstationary=20,
        prop_decrease=1.0,
        n_fft=1024,
        win_length=None,
        hop_length=None,
        freq_mask_smooth_hz=500,
        time_mask_smooth_ms=50,
    ):
        super().__init__()
        self.sr = sr
        self.nonstationary = nonstationary

        assert 0.0 <= prop_decrease <= 1.0, f"prop_decrease must be in [0, 1], got {prop_decrease}"
        self.prop_decrease = prop_decrease

        self.n_fft = n_fft
        self.win_length = self.n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length

        self.n_std_thresh_stationary = n_std_thresh_stationary
        self.temp_coeff_nonstationary = temp_coeff_nonstationary
        self.n_movemean_nonstationary = n_movemean_nonstationary
        self.n_thresh_nonstationary = n_thresh_nonstationary
        self.freq_mask_smooth_hz = freq_mask_smooth_hz
        self.time_mask_smooth_ms = time_mask_smooth_ms

        # Registered as a buffer (possibly None) so Module.to()/.cuda() move it
        # together with the module.
        self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())

    @torch.no_grad()
    def _generate_mask_smoothing_filter(self):
        """Build the normalised 2-D triangular kernel used to smooth the mask.

        Returns:
            A tensor of shape ``(1, 1, 2 * n_grad_freq + 1, 2 * n_grad_time + 1)``
            that sums to 1. Returns None when both smoothing settings are None,
            or when both computed half-widths are 1.

        Raises:
            ValueError: If a smoothing extent is too small to cover one step.
        """
        if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None:
            return None

        # freq_mask_smooth_hz divided by sr / (n_fft / 2), i.e. twice the STFT
        # bin spacing sr / n_fft.
        freq_step_hz = self.sr / (self.n_fft / 2)
        n_grad_freq = 1 if self.freq_mask_smooth_hz is None else int(self.freq_mask_smooth_hz / freq_step_hz)
        if n_grad_freq < 1:
            raise ValueError(f"freq_mask_smooth_hz needs to be at least {freq_step_hz:g} Hz")

        # time_mask_smooth_ms expressed in STFT frames (one hop = hop_length / sr seconds).
        hop_ms = (self.hop_length / self.sr) * 1000
        n_grad_time = 1 if self.time_mask_smooth_ms is None else int(self.time_mask_smooth_ms / hop_ms)
        if n_grad_time < 1:
            raise ValueError(f"time_mask_smooth_ms needs to be at least {hop_ms:g} ms")

        if n_grad_time == 1 and n_grad_freq == 1:
            return None

        def triangle(n):
            # Symmetric ramp of length 2n + 1 peaking at 1, e.g. n=1 -> [0.5, 1, 0.5].
            return torch.cat([linspace(0, 1, n + 1, endpoint=False), linspace(1, 0, n + 2)])[1:-1]

        smoothing_filter = torch.outer(triangle(n_grad_freq), triangle(n_grad_time)).unsqueeze(0).unsqueeze(0)

        return smoothing_filter / smoothing_filter.sum()

    @torch.no_grad()
    def _stationary_mask(self, X_db):
        """Boolean mask of bins above a mean + k*std threshold along the last axis.

        Args:
            X_db: dB-scaled magnitude spectrogram. The statistics are taken over
                the last axis (time frames in ``forward``).

        Returns:
            Bool tensor shaped like ``X_db``.
        """
        std_freq_noise, mean_freq_noise = torch.std_mean(X_db, dim=-1)
        noise_thresh = mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary
        return X_db > noise_thresh.unsqueeze(-1)

    @torch.no_grad()
    def _nonstationary_mask(self, X_abs):
        """Soft mask from each bin's excess over a moving average along the last axis.

        Args:
            X_abs: Magnitude spectrogram (``forward`` passes ``X.abs()``). The
                moving average of length ``n_movemean_nonstationary`` runs over
                the last axis with "same" padding.

        Returns:
            Float tensor shaped like ``X_abs`` with values in ``[0, 1]`` for
            finite input.
        """
        kernel = torch.ones(
            self.n_movemean_nonstationary,
            dtype=X_abs.dtype,
            device=X_abs.device,
        ).view(1, 1, -1)

        X_smoothed = (
            conv1d(X_abs.reshape(-1, 1, X_abs.shape[-1]), kernel, padding="same").view(X_abs.shape)
            / self.n_movemean_nonstationary
        )

        slowness_ratio = (X_abs - X_smoothed) / X_smoothed
        # Digitally silent regions give 0 / 0 = NaN here. Left alone, that NaN
        # spreads through the mask smoothing and the inverse STFT into the
        # output, so treat such bins as having no excess over the local mean.
        slowness_ratio = slowness_ratio.masked_fill(slowness_ratio.isnan(), 0.0)

        return temperature_sigmoid(slowness_ratio, self.n_thresh_nonstationary, self.temp_coeff_nonstationary)

    def forward(self, x):
        """Apply spectral gating to ``x`` and return the processed waveform.

        Args:
            x: 2-D waveform tensor whose last dimension is time. Presumably
                ``(batch, samples)``; NOTE(review): confirm against callers.

        Returns:
            The gated waveform. On the ``torch.stft`` path it is cast back to
            ``x.dtype``. NOTE(review): ``torch.istft`` is called without
            ``length``, so the output can be shorter than ``x`` when
            ``x.shape[-1]`` is not a multiple of ``hop_length``.

        Raises:
            AssertionError: If ``x`` is not 2-D.
            ValueError: If ``x`` has fewer than ``2 * win_length`` samples.
        """
        assert x.ndim == 2, f"TorchGate expects a 2-D input, got {x.ndim} dimension(s)"

        if x.shape[-1] < self.win_length * 2:
            raise ValueError(
                f"input must have at least {self.win_length * 2} samples, got {x.shape[-1]}"
            )

        # Choose the STFT backend once from the input device and use the same
        # choice for the inverse. Keying the inverse on hasattr(self, "stft")
        # would leave `phase` unbound once the custom STFT has been built and
        # the module is later called on a regular CPU/CUDA tensor.
        use_custom_stft = str(x.device).startswith(("ocl", "privateuseone"))

        if use_custom_stft:
            if not hasattr(self, "stft"):
                from main.library.backends.utils import STFT

                self.stft = STFT(
                    filter_length=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    pad_mode="constant",
                ).to(x.device)

            X, phase = self.stft.transform(x, eps=1e-9, return_phase=True)
        else:
            window = torch.hann_window(self.win_length).to(x.device)
            X = torch.stft(
                x,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                return_complex=True,
                pad_mode="constant",
                center=True,
                window=window,
            )

        X_abs = X.abs()
        sig_mask = self._nonstationary_mask(X_abs) if self.nonstationary else self._stationary_mask(amp_to_db(X_abs))

        # Interpolate between "no gating" (all ones) and the raw mask.
        sig_mask = self.prop_decrease * (sig_mask.float() - 1.0) + 1.0

        if self.smoothing_filter is not None:
            sig_mask = conv2d(
                sig_mask.unsqueeze(1),
                self.smoothing_filter.to(device=sig_mask.device, dtype=sig_mask.dtype),
                padding="same",
            ).squeeze(1)

        Y = X * sig_mask

        if use_custom_stft:
            return self.stft.inverse(Y, phase)

        return torch.istft(
            Y,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=window,
        ).to(dtype=x.dtype)