Create utils/noisereduce.py
Browse files- tools/utils/noisereduce.py +196 -0
tools/utils/noisereduce.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import tempfile
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from joblib import Parallel, delayed
|
| 6 |
+
from torch.nn.functional import conv1d, conv2d
|
| 7 |
+
|
| 8 |
+
@torch.no_grad()
|
| 9 |
+
def amp_to_db(x, eps = torch.finfo(torch.float32).eps, top_db = 40):
|
| 10 |
+
x_db = 20 * torch.log10(x.abs() + eps)
|
| 11 |
+
return torch.max(x_db, (x_db.max(-1).values - top_db).unsqueeze(-1))
|
| 12 |
+
|
| 13 |
+
@torch.no_grad()
|
| 14 |
+
def temperature_sigmoid(x, x0, temp_coeff):
|
| 15 |
+
return torch.sigmoid((x - x0) / temp_coeff)
|
| 16 |
+
|
| 17 |
+
@torch.no_grad()
|
| 18 |
+
def linspace(start, stop, num = 50, endpoint = True, **kwargs):
|
| 19 |
+
return torch.linspace(start, stop, num, **kwargs) if endpoint else torch.linspace(start, stop, num + 1, **kwargs)[:-1]
|
| 20 |
+
|
| 21 |
+
def _smoothing_filter(n_grad_freq, n_grad_time):
|
| 22 |
+
smoothing_filter = np.outer(np.concatenate([np.linspace(0, 1, n_grad_freq + 1, endpoint=False), np.linspace(1, 0, n_grad_freq + 2)])[1:-1], np.concatenate([np.linspace(0, 1, n_grad_time + 1, endpoint=False), np.linspace(1, 0, n_grad_time + 2)])[1:-1])
|
| 23 |
+
return smoothing_filter / np.sum(smoothing_filter)
|
| 24 |
+
|
| 25 |
+
class SpectralGate:
|
| 26 |
+
def __init__(self, y, sr, prop_decrease, chunk_size, padding, n_fft, win_length, hop_length, time_constant_s, freq_mask_smooth_hz, time_mask_smooth_ms, tmp_folder, use_tqdm, n_jobs):
|
| 27 |
+
self.sr = sr
|
| 28 |
+
self.flat = False
|
| 29 |
+
y = np.array(y)
|
| 30 |
+
|
| 31 |
+
if len(y.shape) == 1:
|
| 32 |
+
self.y = np.expand_dims(y, 0)
|
| 33 |
+
self.flat = True
|
| 34 |
+
elif len(y.shape) > 2: raise ValueError
|
| 35 |
+
else: self.y = y
|
| 36 |
+
|
| 37 |
+
self._dtype = y.dtype
|
| 38 |
+
self.n_channels, self.n_frames = self.y.shape
|
| 39 |
+
self._chunk_size = chunk_size
|
| 40 |
+
self.padding = padding
|
| 41 |
+
self.n_jobs = n_jobs
|
| 42 |
+
self.use_tqdm = use_tqdm
|
| 43 |
+
self._tmp_folder = tmp_folder
|
| 44 |
+
self._n_fft = n_fft
|
| 45 |
+
self._win_length = self._n_fft if win_length is None else win_length
|
| 46 |
+
self._hop_length = (self._win_length // 4) if hop_length is None else hop_length
|
| 47 |
+
self._time_constant_s = time_constant_s
|
| 48 |
+
self._prop_decrease = prop_decrease
|
| 49 |
+
|
| 50 |
+
if (freq_mask_smooth_hz is None) & (time_mask_smooth_ms is None): self.smooth_mask = False
|
| 51 |
+
else: self._generate_mask_smoothing_filter(freq_mask_smooth_hz, time_mask_smooth_ms)
|
| 52 |
+
|
| 53 |
+
def _generate_mask_smoothing_filter(self, freq_mask_smooth_hz, time_mask_smooth_ms):
|
| 54 |
+
if freq_mask_smooth_hz is None: n_grad_freq = 1
|
| 55 |
+
else:
|
| 56 |
+
n_grad_freq = int(freq_mask_smooth_hz / (self.sr / (self._n_fft / 2)))
|
| 57 |
+
if n_grad_freq < 1: raise ValueError
|
| 58 |
+
|
| 59 |
+
if time_mask_smooth_ms is None: n_grad_time = 1
|
| 60 |
+
else:
|
| 61 |
+
n_grad_time = int(time_mask_smooth_ms / ((self._hop_length / self.sr) * 1000))
|
| 62 |
+
if n_grad_time < 1: raise ValueError
|
| 63 |
+
|
| 64 |
+
if (n_grad_time == 1) & (n_grad_freq == 1): self.smooth_mask = False
|
| 65 |
+
else:
|
| 66 |
+
self.smooth_mask = True
|
| 67 |
+
self._smoothing_filter = _smoothing_filter(n_grad_freq, n_grad_time)
|
| 68 |
+
|
| 69 |
+
def _read_chunk(self, i1, i2):
|
| 70 |
+
i1b = 0 if i1 < 0 else i1
|
| 71 |
+
i2b = self.n_frames if i2 > self.n_frames else i2
|
| 72 |
+
chunk = np.zeros((self.n_channels, i2 - i1))
|
| 73 |
+
chunk[:, i1b - i1: i2b - i1] = self.y[:, i1b:i2b]
|
| 74 |
+
return chunk
|
| 75 |
+
|
| 76 |
+
def filter_chunk(self, start_frame, end_frame):
|
| 77 |
+
i1 = start_frame - self.padding
|
| 78 |
+
return self._do_filter(self._read_chunk(i1, (end_frame + self.padding)))[:, start_frame - i1: end_frame - i1]
|
| 79 |
+
|
| 80 |
+
def _get_filtered_chunk(self, ind):
|
| 81 |
+
start0 = ind * self._chunk_size
|
| 82 |
+
end0 = (ind + 1) * self._chunk_size
|
| 83 |
+
return self.filter_chunk(start_frame=start0, end_frame=end0)
|
| 84 |
+
|
| 85 |
+
def _do_filter(self, chunk):
|
| 86 |
+
pass
|
| 87 |
+
|
| 88 |
+
def _iterate_chunk(self, filtered_chunk, pos, end0, start0, ich):
|
| 89 |
+
filtered_chunk[:, pos: pos + end0 - start0] = self._get_filtered_chunk(ich)[:, start0:end0]
|
| 90 |
+
pos += end0 - start0
|
| 91 |
+
|
| 92 |
+
def get_traces(self, start_frame=None, end_frame=None):
|
| 93 |
+
if start_frame is None: start_frame = 0
|
| 94 |
+
if end_frame is None: end_frame = self.n_frames
|
| 95 |
+
|
| 96 |
+
if self._chunk_size is not None:
|
| 97 |
+
if end_frame - start_frame > self._chunk_size:
|
| 98 |
+
ich1 = int(start_frame / self._chunk_size)
|
| 99 |
+
ich2 = int((end_frame - 1) / self._chunk_size)
|
| 100 |
+
|
| 101 |
+
with tempfile.NamedTemporaryFile(prefix=self._tmp_folder) as fp:
|
| 102 |
+
filtered_chunk = np.memmap(fp, dtype=self._dtype, shape=(self.n_channels, int(end_frame - start_frame)), mode="w+")
|
| 103 |
+
pos_list, start_list, end_list = [], [], []
|
| 104 |
+
pos = 0
|
| 105 |
+
|
| 106 |
+
for ich in range(ich1, ich2 + 1):
|
| 107 |
+
start0 = (start_frame - ich * self._chunk_size) if ich == ich1 else 0
|
| 108 |
+
end0 = end_frame - ich * self._chunk_size if ich == ich2 else self._chunk_size
|
| 109 |
+
pos_list.append(pos)
|
| 110 |
+
start_list.append(start0)
|
| 111 |
+
end_list.append(end0)
|
| 112 |
+
pos += end0 - start0
|
| 113 |
+
|
| 114 |
+
Parallel(n_jobs=self.n_jobs)(delayed(self._iterate_chunk)(filtered_chunk, pos, end0, start0, ich) for pos, start0, end0, ich in zip(pos_list, start_list, end_list, range(ich1, ich2 + 1)))
|
| 115 |
+
return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)
|
| 116 |
+
|
| 117 |
+
filtered_chunk = self.filter_chunk(start_frame=0, end_frame=end_frame)
|
| 118 |
+
return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)
|
| 119 |
+
|
| 120 |
+
class TG(torch.nn.Module):
|
| 121 |
+
@torch.no_grad()
|
| 122 |
+
def __init__(self, sr, nonstationary = False, n_std_thresh_stationary = 1.5, n_thresh_nonstationary = 1.3, temp_coeff_nonstationary = 0.1, n_movemean_nonstationary = 20, prop_decrease = 1.0, n_fft = 1024, win_length = None, hop_length = None, freq_mask_smooth_hz = 500, time_mask_smooth_ms = 50):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.sr = sr
|
| 125 |
+
self.nonstationary = nonstationary
|
| 126 |
+
assert 0.0 <= prop_decrease <= 1.0
|
| 127 |
+
self.prop_decrease = prop_decrease
|
| 128 |
+
self.n_fft = n_fft
|
| 129 |
+
self.win_length = self.n_fft if win_length is None else win_length
|
| 130 |
+
self.hop_length = self.win_length // 4 if hop_length is None else hop_length
|
| 131 |
+
self.n_std_thresh_stationary = n_std_thresh_stationary
|
| 132 |
+
self.temp_coeff_nonstationary = temp_coeff_nonstationary
|
| 133 |
+
self.n_movemean_nonstationary = n_movemean_nonstationary
|
| 134 |
+
self.n_thresh_nonstationary = n_thresh_nonstationary
|
| 135 |
+
self.freq_mask_smooth_hz = freq_mask_smooth_hz
|
| 136 |
+
self.time_mask_smooth_ms = time_mask_smooth_ms
|
| 137 |
+
self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())
|
| 138 |
+
|
| 139 |
+
@torch.no_grad()
|
| 140 |
+
def _generate_mask_smoothing_filter(self):
|
| 141 |
+
if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None: return None
|
| 142 |
+
n_grad_freq = (1 if self.freq_mask_smooth_hz is None else int(self.freq_mask_smooth_hz / (self.sr / (self.n_fft / 2))))
|
| 143 |
+
if n_grad_freq < 1: raise ValueError
|
| 144 |
+
|
| 145 |
+
n_grad_time = (1 if self.time_mask_smooth_ms is None else int(self.time_mask_smooth_ms / ((self.hop_length / self.sr) * 1000)))
|
| 146 |
+
if n_grad_time < 1: raise ValueError
|
| 147 |
+
if n_grad_time == 1 and n_grad_freq == 1: return None
|
| 148 |
+
|
| 149 |
+
smoothing_filter = torch.outer(torch.cat([linspace(0, 1, n_grad_freq + 1, endpoint=False), linspace(1, 0, n_grad_freq + 2)])[1:-1], torch.cat([linspace(0, 1, n_grad_time + 1, endpoint=False), linspace(1, 0, n_grad_time + 2)])[1:-1]).unsqueeze(0).unsqueeze(0)
|
| 150 |
+
return smoothing_filter / smoothing_filter.sum()
|
| 151 |
+
|
| 152 |
+
@torch.no_grad()
|
| 153 |
+
def _stationary_mask(self, X_db, xn = None):
|
| 154 |
+
XN_db = amp_to_db(torch.stft(xn, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(xn.device))).to(dtype=X_db.dtype) if xn is not None else X_db
|
| 155 |
+
std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)
|
| 156 |
+
return torch.gt(X_db, (mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary).unsqueeze(2))
|
| 157 |
+
|
| 158 |
+
@torch.no_grad()
|
| 159 |
+
def _nonstationary_mask(self, X_abs):
|
| 160 |
+
X_smoothed = (conv1d(X_abs.reshape(-1, 1, X_abs.shape[-1]), torch.ones(self.n_movemean_nonstationary, dtype=X_abs.dtype, device=X_abs.device).view(1, 1, -1), padding="same").view(X_abs.shape) / self.n_movemean_nonstationary)
|
| 161 |
+
return temperature_sigmoid(((X_abs - X_smoothed) / X_smoothed), self.n_thresh_nonstationary, self.temp_coeff_nonstationary)
|
| 162 |
+
|
| 163 |
+
def forward(self, x, xn = None):
|
| 164 |
+
assert x.ndim == 2
|
| 165 |
+
if x.shape[-1] < self.win_length * 2: raise Exception
|
| 166 |
+
assert xn is None or xn.ndim == 1 or xn.ndim == 2
|
| 167 |
+
if xn is not None and xn.shape[-1] < self.win_length * 2: raise Exception
|
| 168 |
+
|
| 169 |
+
X = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(x.device))
|
| 170 |
+
sig_mask = self._nonstationary_mask(X.abs()) if self.nonstationary else self._stationary_mask(amp_to_db(X), xn)
|
| 171 |
+
|
| 172 |
+
sig_mask = self.prop_decrease * (sig_mask * 1.0 - 1.0) + 1.0
|
| 173 |
+
if self.smoothing_filter is not None: sig_mask = conv2d(sig_mask.unsqueeze(1), self.smoothing_filter.to(sig_mask.dtype), padding="same")
|
| 174 |
+
|
| 175 |
+
Y = X * sig_mask.squeeze(1)
|
| 176 |
+
return torch.istft(Y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, center=True, window=torch.hann_window(self.win_length).to(Y.device)).to(dtype=x.dtype)
|
| 177 |
+
|
| 178 |
+
class StreamedTorchGate(SpectralGate):
|
| 179 |
+
def __init__(self, y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, n_std_thresh_stationary=1.5, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, n_jobs=1, device="cpu"):
|
| 180 |
+
super().__init__(y=y, sr=sr, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, tmp_folder=tmp_folder, prop_decrease=prop_decrease, use_tqdm=use_tqdm, n_jobs=n_jobs)
|
| 181 |
+
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
| 182 |
+
|
| 183 |
+
if y_noise is not None:
|
| 184 |
+
if y_noise.shape[-1] > y.shape[-1] and clip_noise_stationary: y_noise = y_noise[: y.shape[-1]]
|
| 185 |
+
y_noise = torch.from_numpy(y_noise).to(device)
|
| 186 |
+
if len(y_noise.shape) == 1: y_noise = y_noise.unsqueeze(0)
|
| 187 |
+
|
| 188 |
+
self.y_noise = y_noise
|
| 189 |
+
self.tg = TG(sr=sr, nonstationary=not stationary, n_std_thresh_stationary=n_std_thresh_stationary, n_thresh_nonstationary=thresh_n_mult_nonstationary, temp_coeff_nonstationary=1 / sigmoid_slope_nonstationary, n_movemean_nonstationary=int(time_constant_s / self._hop_length * sr), prop_decrease=prop_decrease, n_fft=self._n_fft, win_length=self._win_length, hop_length=self._hop_length, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms).to(device)
|
| 190 |
+
|
| 191 |
+
def _do_filter(self, chunk):
|
| 192 |
+
if type(chunk) is np.ndarray: chunk = torch.from_numpy(chunk).to(self.device)
|
| 193 |
+
return self.tg(x=chunk, xn=self.y_noise).cpu().detach().numpy()
|
| 194 |
+
|
| 195 |
+
def reduce_noise(y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, device="cpu"):
|
| 196 |
+
return StreamedTorchGate(y=y, sr=sr, stationary=stationary, y_noise=y_noise, prop_decrease=prop_decrease, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, thresh_n_mult_nonstationary=thresh_n_mult_nonstationary, sigmoid_slope_nonstationary=sigmoid_slope_nonstationary, tmp_folder=tmp_folder, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, clip_noise_stationary=clip_noise_stationary, use_tqdm=use_tqdm, n_jobs=1, device=device).get_traces()
|