from typing import List, Optional, Tuple, Union

import librosa
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor

from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask


class FeatureTransform(torch.nn.Module):
    def __init__(
        self,
        # Mel options
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: Optional[float] = None,
        # Normalization options
        stats_file: Optional[str] = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None
    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # x: (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select one channel at random
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # Power spectrum: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        h = h.real**2 + h.imag**2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
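

# Usage sketch (illustrative, not part of the original module): run the full
# transform on a random 2-channel STFT. With n_fft=512 the spectrum has
# F = 512 // 2 + 1 = 257 bins; no global CMVN stats file is assumed here.
def _example_feature_transform():
    transform = FeatureTransform(fs=16000, n_fft=512, n_mels=80)
    B, T, C, F = 2, 100, 2, 257
    x = ComplexTensor(torch.randn(B, T, C, F), torch.randn(B, T, C, F))
    ilens = torch.LongTensor([100, 80])
    feats, feat_lens = transform(x, ilens)
    assert feats.shape == (B, T, 80)
    return feats, feat_lens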


class LogMel(torch.nn.Module):
    """Convert an STFT power spectrum to log-mel filterbank features.

    The arguments are the same as those of librosa.filters.mel.

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            If 1, divide the triangular mel weights by the width of the mel
            band (area normalization). Otherwise, leave all the triangles
            aiming for a peak value of 1.0.
    """
    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: Optional[float] = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()
        _mel_options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        self.mel_options = _mel_options

        # Note(kamo): The mel matrix of librosa differs from Kaldi's.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (n_mels, n_fft // 2 + 1) -> (n_fft // 2 + 1, n_mels)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        logmel_feat = (mel_feat + 1e-20).log()

        # Zero out the padded part
        logmel_feat = logmel_feat.masked_fill(
            make_pad_mask(ilens, logmel_feat, 1), 0.0
        )
        return logmel_feat, ilens
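

# Illustrative sketch (an assumption, not from the original repo): feed a
# power spectrum computed with torch.stft through LogMel above.
def _example_logmel():
    logmel = LogMel(fs=16000, n_fft=512, n_mels=80)
    wav = torch.randn(1, 16000)
    spec = torch.stft(
        wav, n_fft=512, window=torch.hann_window(512), return_complex=True
    )  # (1, F, T') with F = 512 // 2 + 1 = 257
    power = spec.abs().pow(2).transpose(1, 2)  # (1, T', F)
    ilens = torch.LongTensor([power.size(1)])
    feats, _ = logmel(power, ilens)
    return feats  # (1, T', 80)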


class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization.

    Args:
        stats_file(str): npy file holding a 1-dim array. The first
            (len(array) - 1) / 2 elements are treated as the sum of
            features, the following (len(array) - 1) / 2 elements as the
            sum of squared features, and the last element equals the
            number of accumulated frames.
        norm_means(bool): subtract the global mean if True
        norm_vars(bool): divide by the global std if True
        eps(float): floor for the standard deviation
    """
    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.stats_file = stats_file

        stats = np.load(stats_file).astype(float)
        assert (len(stats) - 1) % 2 == 0, stats.shape

        count = stats.flatten()[-1]
        mean = stats[: (len(stats) - 1) // 2] / count
        var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
        std = np.maximum(np.sqrt(var), eps)

        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))
    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )
    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # x: (B, T, D)
        if self.norm_means:
            x += self.bias.type_as(x)
            # Zero out the padded part in place
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens
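

# Hedged sketch (not part of the original file) of how a stats file in the
# layout GlobalMVN expects could be produced: a 1-dim array of length
# 2 * D + 1 holding [sum(feats), sum(feats ** 2), frame_count].
def _example_write_global_mvn_stats(feats_list, path="stats.npy"):
    # feats_list: iterable of (T_i, D) arrays of unnormalized features
    feats = np.concatenate(list(feats_list), axis=0)  # (sum(T_i), D)
    stats = np.concatenate(
        [feats.sum(axis=0), (feats**2).sum(axis=0), [feats.shape[0]]]
    )
    np.save(path, stats)
    return path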


class UtteranceMVN(torch.nn.Module):
    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps

    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        return utterance_mvn(
            x,
            ilens,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars,
            eps=self.eps,
        )


def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply per-utterance mean and variance normalization.

    Args:
        x: (B, T, D), with the padded part assumed to be zero
        ilens: (B,), length of the unpadded part of each utterance
        norm_means: subtract the per-utterance mean if True
        norm_vars: divide by the per-utterance std if True
        eps: floor for the variance
    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D)
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        x -= mean[:, None, :]
        x_ = x
    else:
        x_ = x - mean[:, None, :]

    # Zero out the padded part in place
    x_.masked_fill_(make_pad_mask(ilens, x_, 1), 0.0)

    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x_ /= var.sqrt()[:, None, :]
    return x_, ilens
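

# Small sanity-check sketch (illustrative, not from the original code):
# after utterance_mvn with norm_means=True, the mean over the valid frames
# of each utterance is ~0 and the padded frames remain zero.
def _example_utterance_mvn():
    x = torch.randn(2, 10, 4)
    ilens = torch.LongTensor([10, 6])
    x[1, 6:] = 0.0  # zero the padding, as the function assumes
    y, _ = utterance_mvn(x.clone(), ilens)
    assert torch.allclose(y[0].mean(dim=0), torch.zeros(4), atol=1e-5)
    assert torch.allclose(y[1, :6].mean(dim=0), torch.zeros(4), atol=1e-5)
    assert torch.all(y[1, 6:] == 0.0)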


def feature_transform_for(args, n_fft):
    return FeatureTransform(
        # Mel options
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
        # Normalization options
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
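

# Hedged construction sketch: `args` is assumed to be an argparse-style
# namespace carrying the fields the helper reads above; the values below
# are illustrative defaults, not ones prescribed by the original code.
def _example_feature_transform_for():
    import argparse

    args = argparse.Namespace(
        fbank_fs=16000,
        n_mels=80,
        fbank_fmin=0.0,
        fbank_fmax=None,
        stats_file=None,
        apply_uttmvn=True,
        uttmvn_norm_means=True,
        uttmvn_norm_vars=False,
    )
    return feature_transform_for(args, n_fft=512)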