from pathlib import Path
from typing import Tuple
from typing import Union

import numpy as np
import torch

from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
from funasr_detach.register import tables

# Registration decorator assumed from the otherwise-unused `tables` import;
# FunASR exposes normalize classes to configs by name through this table.
@tables.register("normalize_classes", "GlobalMVN")
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization.

    TODO(kamo): Make this class portable somehow

    Args:
        stats_file: npy/npz file holding the accumulated feature statistics
        norm_means: Apply mean normalization
        norm_vars: Apply variance normalization
        eps: floor for the variance, so std never reaches zero
    """
    def __init__(
        self,
        stats_file: Union[Path, str],
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps
        stats_file = Path(stats_file)
        self.stats_file = stats_file

        stats = np.load(stats_file)
        if isinstance(stats, np.ndarray):
            # Kaldi-like stats: row 0 holds the per-dim feature sums with the
            # frame count in the last column; row 1 holds the sums of squares.
            # (The usage sketch at the end of this file builds both formats.)
            count = stats[0].flatten()[-1]
            mean = stats[0, :-1] / count
            var = stats[1, :-1] / count - mean * mean
        else:
            # New style: npz file with explicit count/sum/sum_square entries
            count = stats["count"]
            sum_v = stats["sum"]
            sum_square_v = stats["sum_square"]
            mean = sum_v / count
            var = sum_square_v / count - mean * mean
        # Floor the variance with eps before taking the square root
        std = np.sqrt(np.maximum(var, eps))
        # Buffers travel with state_dict and .to(), but are never trained
        self.register_buffer("mean", torch.from_numpy(mean))
        self.register_buffer("std", torch.from_numpy(std))
    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )
    def forward(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function

        Args:
            x: (B, L, ...)
            ilens: (B,)
        """
        if ilens is None:
            # No lengths given: treat every sequence as full length
            ilens = x.new_full([x.size(0)], x.size(1))
        norm_means = self.norm_means
        norm_vars = self.norm_vars
        self.mean = self.mean.to(x.device, x.dtype)
        self.std = self.std.to(x.device, x.dtype)
        mask = make_pad_mask(ilens, x, 1)

        # feat: (B, T, D)
        if norm_means:
            # Avoid in-place ops when autograd is tracking x
            if x.requires_grad:
                x = x - self.mean
            else:
                x -= self.mean

        # Zero out the padded frames so they stay exactly zero
        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if norm_vars:
            x /= self.std
        return x, ilens
    def inverse(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the normalization: x * std + mean on the valid frames."""
        if ilens is None:
            ilens = x.new_full([x.size(0)], x.size(1))
        norm_means = self.norm_means
        norm_vars = self.norm_vars
        self.mean = self.mean.to(x.device, x.dtype)
        self.std = self.std.to(x.device, x.dtype)
        mask = make_pad_mask(ilens, x, 1)

        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if norm_vars:
            x *= self.std
        # feat: (B, T, D)
        if norm_means:
            x += self.mean
            # Adding the mean also touched the padding; zero it again
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
        return x, ilens
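

# ---------------------------------------------------------------------------
# Usage sketch, not part of the upstream file: a minimal illustration of the
# two stats-file layouts the constructor accepts. The helper name, feature
# dimension D, and the demo data are all invented for this example.
# ---------------------------------------------------------------------------
def _demo_build_stats(tmpdir: Path) -> Tuple[Path, Path]:
    """Write equivalent Kaldi-style .npy and new-style .npz stats files."""
    rng = np.random.default_rng(0)
    D = 4  # hypothetical feature dimension
    feats = rng.normal(loc=3.0, scale=2.0, size=(1000, D))

    # Kaldi-like layout: (2, D + 1) ndarray, row 0 = [per-dim sums, count],
    # row 1 = [per-dim sums of squares, unused]
    kaldi = np.zeros((2, D + 1))
    kaldi[0, :D] = feats.sum(axis=0)
    kaldi[0, D] = len(feats)
    kaldi[1, :D] = (feats**2).sum(axis=0)
    npy_path = tmpdir / "stats.npy"
    np.save(npy_path, kaldi)

    # New-style layout: .npz with explicit count/sum/sum_square entries
    npz_path = tmpdir / "stats.npz"
    np.savez(
        npz_path,
        count=np.array(len(feats)),
        sum=feats.sum(axis=0),
        sum_square=(feats**2).sum(axis=0),
    )
    return npy_path, npz_path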
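

if __name__ == "__main__":
    # Hedged round-trip check on a padded batch: both stats formats should
    # yield the same statistics, forward() should normalize and zero the
    # padding, and inverse() should restore the original valid frames.
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        npy_path, npz_path = _demo_build_stats(Path(d))
        mvn_a = GlobalMVN(npy_path)
        mvn_b = GlobalMVN(npz_path)
        # Both formats encode the same mean/std
        assert torch.allclose(mvn_a.mean, mvn_b.mean)
        assert torch.allclose(mvn_a.std, mvn_b.std)

        # Two utterances of lengths 5 and 3, padded to T = 5
        x = torch.randn(2, 5, 4, dtype=torch.double)
        ilens = torch.tensor([5, 3])
        y, _ = mvn_b(x.clone(), ilens)
        assert torch.all(y[1, 3:] == 0.0)  # padded frames are zeroed

        x_rec, _ = mvn_b.inverse(y.clone(), ilens)
        assert torch.allclose(x_rec[0], x[0])
        assert torch.allclose(x_rec[1, :3], x[1, :3])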