Spaces:
Running
on
Zero
Running
on
Zero
| import librosa | |
| import torch | |
| from typing import Tuple | |
| from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask | |
class LogMel(torch.nn.Module):
    """Convert STFT features to log-Mel filterbank features.

    The constructor arguments are the same as those of
    ``librosa.filters.mel``.

    Args:
        fs: sampling rate of the incoming signal (> 0).
        n_fft: number of FFT components (> 0).
        n_mels: number of Mel bands to generate (> 0).
        fmin: lowest frequency in Hz (>= 0). If ``None``, ``0`` is used.
        fmax: highest frequency in Hz (>= 0). If ``None``, ``fs / 2`` is used.
        htk: use the HTK formula instead of Slaney.
        log_base: base of the logarithm applied to the Mel energies.
            ``None`` selects the natural logarithm; ``2.0`` and ``10.0`` use
            the dedicated ``log2`` / ``log10`` kernels; any other value is
            applied via change of base.
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = None,
        fmax: float = None,
        htk: bool = False,
        log_base: float = None,
    ):
        super().__init__()

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.mel_options = _mel_options
        self.log_base = log_base

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project ``feat`` onto the Mel filterbank and take the logarithm.

        Args:
            feat: input features; the last dimension must equal the number of
                rows of ``melmat`` (see the shape comment below).
            ilens: optional lengths. When given, the positions flagged by
                ``make_pad_mask(ilens, logmel_feat, 1)`` are filled with 0.

        Returns:
            ``(logmel_feat, ilens)``. When ``ilens`` is ``None``, the returned
            lengths are a ``torch.long`` tensor of size ``feat.size(0)``,
            filled with ``feat.size(1)``.
        """
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        # Floor the energies so the logarithm never sees 0 (-> -inf).
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            # Change of base: log_b(x) = ln(x) / ln(b). ``log_base`` is a
            # Python number, and torch.log only accepts tensors, so wrap it
            # in a tensor matching mel_feat's dtype and device first.
            logmel_feat = mel_feat.log() / mel_feat.new_tensor(self.log_base).log()

        # Zero padding
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(
                make_pad_mask(ilens, logmel_feat, 1), 0.0
            )
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens