from typing import Dict, List, Optional, Tuple, Type

import torch
from torch import nn
from torch.nn.modules import activation
from torch.utils.checkpoint import checkpoint_sequential

from .utils import (
    band_widths_from_specs,
    check_no_gap,
    check_no_overlap,
    check_nonzero_bandwidth,
)
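
# Mask estimation heads for band-split models: each band embedding is passed through
# LayerNorm, an MLP, and a GLU output, then reshaped into a (complex) time-frequency
# mask. Per-band masks are assembled into a full-spectrum mask, summed where bands
# overlap (OverlappingMaskEstimationModule) or concatenated when they do not
# (MaskEstimationModule).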


class BaseNormMLP(nn.Module):
    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ):
        super().__init__()
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        self.hidden_activation_kwargs = hidden_activation_kwargs
        self.norm = nn.LayerNorm(emb_dim)
        self.hidden = nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=mlp_dim),
            # Look up the activation class by name, e.g. "Tanh" -> nn.Tanh.
            activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
        )

        self.bandwidth = bandwidth
        self.in_channels = in_channels
        self.complex_mask = complex_mask
        # Two output values (real/imag) per T-F bin for a complex mask, one otherwise.
        self.reim = 2 if complex_mask else 1
        self.glu_mult = 2


class NormMLP(BaseNormMLP):
    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        # The GLU halves the feature dimension, hence the extra factor of 2.
        self.output = nn.Sequential(
            nn.Linear(
                in_features=mlp_dim,
                out_features=bandwidth * in_channels * self.reim * 2,
            ),
            nn.GLU(dim=-1),
        )

        try:
            self.combined = torch.compile(
                nn.Sequential(self.norm, self.hidden, self.output), disable=True
            )
        except Exception:
            self.combined = nn.Sequential(self.norm, self.hidden, self.output)

    def reshape_output(self, mb):
        # mb: (batch, n_time, bandwidth * in_channels * reim)
        batch, n_time, _ = mb.shape
        if self.complex_mask:
            mb = mb.reshape(
                batch, n_time, self.in_channels, self.bandwidth, self.reim
            ).contiguous()
            mb = torch.view_as_complex(mb)  # (batch, n_time, in_channels, bandwidth)
        else:
            mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)

        mb = torch.permute(mb, (0, 2, 3, 1))  # (batch, in_channels, bandwidth, n_time)

        return mb

    def forward(self, qb):
        # qb: (batch, n_time, emb_dim)
        # self.combined applies norm -> hidden -> output; checkpoint_sequential splits
        # it into 2 segments and recomputes activations during backward to save memory.
        mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
        mb = self.reshape_output(mb)  # (batch, in_channels, bandwidth, n_time)

        return mb
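
# Illustrative shape walk-through for NormMLP (comment-only sketch; emb_dim=128,
# mlp_dim=512, bandwidth=16, and in_channels=2 are arbitrary example values, not fixed
# by this module). With complex_mask=True the output linear produces
# 16 * 2 * 2 * 2 = 128 features, the GLU halves that to 64, and reshape_output yields
# a complex per-band mask:
#
#     norm_mlp = NormMLP(emb_dim=128, mlp_dim=512, bandwidth=16, in_channels=2)
#     qb = torch.randn(4, 100, 128)   # (batch=4, n_time=100, emb_dim=128)
#     mb = norm_mlp(qb)               # shape (4, 2, 16, 100), dtype torch.complex64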


class MaskEstimationModuleSuperBase(nn.Module):
    pass


class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
    ) -> None:
        super().__init__()
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        # One NormMLP head per band, each sized to that band's width.
        self.norm_mlp = nn.ModuleList(
            [
                norm_mlp_cls(
                    bandwidth=self.band_widths[b],
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    **norm_mlp_kwargs,
                )
                for b in range(self.n_bands)
            ]
        )

    def compute_masks(self, q):
        batch, n_bands, n_time, emb_dim = q.shape

        masks = []
        for b, nmlp in enumerate(self.norm_mlp):
            qb = q[:, b, :, :]
            mb = nmlp(qb)  # (batch, in_channels, band_widths[b], n_time)
            masks.append(mb)

        return masks

    def compute_mask(self, q, b):
        batch, n_bands, n_time, emb_dim = q.shape
        qb = q[:, b, :, :]
        mb = self.norm_mlp[b](qb)
        return mb


class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    def __init__(
        self,
        in_channels: int,
        band_specs: List[Tuple[float, float]],
        freq_weights: Optional[List[torch.Tensor]],
        n_freq: int,
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
        use_freq_weights: bool = False,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        if cond_dim > 0:
            raise NotImplementedError

        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,
            mlp_dim=mlp_dim,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channels = in_channels

        if freq_weights is not None and use_freq_weights:
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)
            self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

    def forward(self, q):
        # q: (batch, n_bands, n_time, emb_dim)
        batch, n_bands, n_time, emb_dim = q.shape

        masks = torch.zeros(
            (batch, self.in_channels, self.n_freq, n_time),
            device=q.device,
            dtype=torch.complex64,
        )

        # Accumulate per-band masks into the full-spectrum mask; overlapping bands sum.
        for im in range(n_bands):
            fstart, fend = self.band_specs[im]
            mask = self.compute_mask(q, im)  # (batch, in_channels, fend - fstart, n_time)

            if self.use_freq_weights:
                fw = self.get_buffer(f"freq_weights/{im}")[:, None]
                mask = mask * fw

            masks[:, :, fstart:fend, :] += mask

        return masks
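
# Illustrative use (comment-only sketch; the band layout and sizes below are example
# values, and band_widths_from_specs is assumed to return fend - fstart per band).
# Overlapping, gap-free bands are summed into one full-spectrum complex mask:
#
#     band_specs = [(0, 24), (16, 40), (32, 65)]
#     module = OverlappingMaskEstimationModule(
#         in_channels=2, band_specs=band_specs, freq_weights=None, n_freq=65,
#         emb_dim=128, mlp_dim=512,
#     )
#     q = torch.randn(4, len(band_specs), 100, 128)  # (batch, n_bands, n_time, emb_dim)
#     masks = module(q)                              # (4, 2, 65, 100), complex64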


class MaskEstimationModule(OverlappingMaskEstimationModule):
    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        check_no_overlap(band_specs)

        super().__init__(
            in_channels=in_channels,
            band_specs=band_specs,
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        # q = (batch, n_bands, n_time, emb_dim)
        masks = self.compute_masks(
            q
        )  # [n_bands * (batch, in_channels, bandwidth, n_time)]

        # TODO: currently this requires band specs to have no gap and no overlap
        masks = torch.concat(masks, dim=2)  # (batch, in_channels, n_freq, n_time)

        return masks
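
# Illustrative use (comment-only sketch; sizes are example values). With
# non-overlapping, gap-free bands, per-band masks are concatenated along the
# frequency axis:
#
#     band_specs = [(0, 16), (16, 32), (32, 65)]
#     module = MaskEstimationModule(
#         band_specs=band_specs, emb_dim=128, mlp_dim=512, in_channels=2,
#     )
#     q = torch.randn(4, len(band_specs), 100, 128)  # (batch, n_bands, n_time, emb_dim)
#     masks = module(q)                              # (4, 2, 65, 100), complex64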