"""Spectral-gating noise reduction implemented with PyTorch.

``TorchGate`` builds a time-frequency mask from the STFT magnitude of the
input. The mask is either a stationary per-frequency threshold or a
non-stationary soft threshold based on a moving average. The mask is
optionally smoothed, applied to the spectrogram, and the waveform is
resynthesised.
"""

import os
import sys

import torch
from torch.nn.functional import conv1d, conv2d

# Lets the optional custom STFT backend (``main.library.backends.utils``) be
# imported; presumably the working directory is the project root - confirm.
sys.path.append(os.getcwd())


@torch.no_grad()
def temperature_sigmoid(x, x0, temp_coeff):
    """Return ``sigmoid((x - x0) / temp_coeff)`` element-wise.

    ``x0`` is where the output crosses 0.5; a smaller ``temp_coeff`` makes the
    transition sharper.
    """
    return ((x - x0) / temp_coeff).sigmoid()


@torch.no_grad()
def linspace(start, stop, num=50, endpoint=True, **kwargs):
    """NumPy-style ``linspace`` built on :func:`torch.linspace`.

    With ``endpoint=False`` the ``stop`` value is excluded: ``num + 1`` points
    are generated and the last one is dropped. Extra keyword arguments are
    forwarded to :func:`torch.linspace`.
    """
    if endpoint:
        return torch.linspace(start, stop, num, **kwargs)
    return torch.linspace(start, stop, num + 1, **kwargs)[:-1]


@torch.no_grad()
def amp_to_db(x, eps=torch.finfo(torch.float32).eps, top_db=40):
    """Convert amplitudes to dB, flooring each row ``top_db`` below its peak.

    ``eps`` avoids ``log10(0)``. The floor is taken per row along the last
    dimension: each value becomes at least ``max(x_db, dim=-1) - top_db``.
    """
    x_db = 20 * (x + eps).log10()
    floor = (x_db.max(-1).values - top_db).unsqueeze(-1)
    return x_db.max(floor)


class TorchGate(torch.nn.Module):
    """Spectral-gating noise reduction as a ``torch.nn.Module``.

    Args:
        sr: Sample rate of the input, in Hz.
        nonstationary: Use the non-stationary (moving-average, sigmoid) mask
            instead of the stationary per-frequency dB threshold.
        n_std_thresh_stationary: Standard deviations above the per-frequency
            mean (in dB) that a bin must exceed to be kept by the stationary
            mask.
        n_thresh_nonstationary: Relative excess over the moving average at
            which the non-stationary sigmoid mask crosses 0.5.
        temp_coeff_nonstationary: Temperature of that sigmoid.
        n_movemean_nonstationary: Moving-average length, in STFT frames, used
            by the non-stationary mask.
        prop_decrease: Fraction of the gated content to remove, in ``[0, 1]``.
            ``1.0`` applies the mask fully; ``0.0`` makes the mask all ones.
        n_fft: FFT size.
        win_length: Window length; defaults to ``n_fft``.
        hop_length: Hop length; defaults to ``win_length // 4``.
        freq_mask_smooth_hz: Mask smoothing width along frequency (Hz), or
            ``None`` to disable frequency smoothing.
        time_mask_smooth_ms: Mask smoothing width along time (ms), or ``None``
            to disable time smoothing.
    """

    @torch.no_grad()
    def __init__(
        self,
        sr,
        nonstationary=False,
        n_std_thresh_stationary=1.5,
        n_thresh_nonstationary=1.3,
        temp_coeff_nonstationary=0.1,
        n_movemean_nonstationary=20,
        prop_decrease=1.0,
        n_fft=1024,
        win_length=None,
        hop_length=None,
        freq_mask_smooth_hz=500,
        time_mask_smooth_ms=50,
    ):
        super().__init__()
        self.sr = sr
        self.nonstationary = nonstationary
        assert 0.0 <= prop_decrease <= 1.0
        self.prop_decrease = prop_decrease
        self.n_fft = n_fft
        self.win_length = self.n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length
        self.n_std_thresh_stationary = n_std_thresh_stationary
        self.temp_coeff_nonstationary = temp_coeff_nonstationary
        self.n_movemean_nonstationary = n_movemean_nonstationary
        self.n_thresh_nonstationary = n_thresh_nonstationary
        self.freq_mask_smooth_hz = freq_mask_smooth_hz
        self.time_mask_smooth_ms = time_mask_smooth_ms
        # Registered as a buffer so it follows the module across .to(device);
        # it may be None when smoothing is disabled.
        self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())

    @torch.no_grad()
    def _generate_mask_smoothing_filter(self):
        """Build the mask-smoothing kernel, or return ``None`` if not needed.

        The kernel is the outer product of two triangular ramps (frequency x
        time), normalised to sum to one. Its shape is ``(1, 1, F, T)`` so it
        can be passed straight to :func:`conv2d`.

        Raises:
            ValueError: if a requested smoothing width is narrower than one
                frequency step or one hop.
        """
        if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None:
            return None

        freq_step_hz = self.sr / (self.n_fft / 2)
        n_grad_freq = 1 if self.freq_mask_smooth_hz is None else int(self.freq_mask_smooth_hz / freq_step_hz)
        if n_grad_freq < 1:
            raise ValueError(f"freq_mask_smooth_hz needs to be at least {int(freq_step_hz)} Hz")

        hop_ms = (self.hop_length / self.sr) * 1000
        n_grad_time = 1 if self.time_mask_smooth_ms is None else int(self.time_mask_smooth_ms / hop_ms)
        if n_grad_time < 1:
            raise ValueError(f"time_mask_smooth_ms needs to be at least {int(hop_ms)} ms")

        # A 1x1 kernel would be a no-op, so skip smoothing entirely.
        if n_grad_time == 1 and n_grad_freq == 1:
            return None

        def triangular_ramp(n):
            # Rising then falling ramp of length 2n + 1 with a single peak of 1;
            # the leading and trailing zeros are dropped by the [1:-1] slice.
            return torch.cat([linspace(0, 1, n + 1, endpoint=False), linspace(1, 0, n + 2)])[1:-1]

        smoothing_filter = torch.outer(triangular_ramp(n_grad_freq), triangular_ramp(n_grad_time)).unsqueeze(0).unsqueeze(0)
        return smoothing_filter / smoothing_filter.sum()

    @torch.no_grad()
    def _stationary_mask(self, X_db):
        """Keep bins louder than ``mean + n_std * std`` of their frequency row.

        Statistics are taken over the last axis. ``unsqueeze(2)`` assumes
        ``X_db`` is 3-D, as produced by ``torch.stft`` on a 2-D batch
        (batch, freq, frames). Returns a boolean tensor.
        """
        std_freq_noise, mean_freq_noise = torch.std_mean(X_db, dim=-1)
        threshold = mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary
        return X_db > threshold.unsqueeze(2)

    @torch.no_grad()
    def _nonstationary_mask(self, X_abs):
        """Soft mask from each bin's relative excess over its moving average.

        A moving average of ``n_movemean_nonstationary`` frames is taken along
        the last axis. The relative excess ``(X_abs - avg) / avg`` is passed
        through a temperature sigmoid centred on ``n_thresh_nonstationary``.
        Returns a float mask in ``[0, 1]`` with the same shape as ``X_abs``.
        """
        kernel = torch.ones(self.n_movemean_nonstationary, dtype=X_abs.dtype, device=X_abs.device).view(1, 1, -1)
        X_smoothed = conv1d(X_abs.reshape(-1, 1, X_abs.shape[-1]), kernel, padding="same").view(X_abs.shape) / self.n_movemean_nonstationary
        # The moving average is exactly zero only where every frame in the
        # window is zero (e.g. digital silence). Dividing by it would give
        # 0/0 = NaN, which the mask smoothing and istft overlap-add would then
        # spread into neighbouring audio. Use a unit denominator there instead:
        # the numerator is also zero (X_abs is non-negative and the window
        # includes the bin itself), so the ratio becomes 0.
        denom = torch.where(X_smoothed > 0, X_smoothed, torch.ones_like(X_smoothed))
        return temperature_sigmoid((X_abs - X_smoothed) / denom, self.n_thresh_nonstationary, self.temp_coeff_nonstationary)

    def forward(self, x):
        """Apply spectral gating to ``x`` and return the resynthesised signal.

        Args:
            x: 2-D waveform tensor (``torch.stft`` treats the first dim as a
                batch) of at least ``2 * win_length`` samples.

        Raises:
            AssertionError: if ``x`` is not 2-D.
            Exception: if ``x`` is shorter than ``2 * win_length`` samples.
        """
        assert x.ndim == 2
        if x.shape[-1] < self.win_length * 2:
            raise Exception(f"x must be bigger than {self.win_length * 2}")

        # Decide the backend once and use the same decision for the inverse
        # transform. A cached self.stft from an earlier call on another device
        # must not redirect the inverse of a torch.stft forward pass.
        use_custom_stft = str(x.device).startswith(("ocl", "privateuseone"))

        if use_custom_stft:
            # These backends use the project's own STFT module, created lazily
            # once; presumably torch.stft is unsupported there - confirm.
            if not hasattr(self, "stft"):
                from main.library.backends.utils import STFT

                self.stft = STFT(
                    filter_length=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    pad_mode="constant",
                ).to(x.device)
            X, phase = self.stft.transform(x, eps=1e-9, return_phase=True)
        else:
            window = torch.hann_window(self.win_length).to(x.device)
            X = torch.stft(
                x,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                return_complex=True,
                pad_mode="constant",
                center=True,
                window=window,
            )

        X_abs = X.abs()
        sig_mask = self._nonstationary_mask(X_abs) if self.nonstationary else self._stationary_mask(amp_to_db(X_abs))
        # Blend toward 1 by (1 - prop_decrease): mask m becomes p * (m - 1) + 1.
        sig_mask = self.prop_decrease * (sig_mask.float() - 1.0) + 1.0

        if self.smoothing_filter is not None:
            sig_mask = conv2d(sig_mask.unsqueeze(1), self.smoothing_filter.to(sig_mask.dtype), padding="same")

        Y = X * sig_mask.squeeze(1)

        if use_custom_stft:
            return self.stft.inverse(Y, phase)

        return torch.istft(
            Y,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=window,
        ).to(dtype=x.dtype)