# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from typing import Callable, Dict

from funasr_detach.models.emotion2vec.fairseq_modules import (
    LayerNorm,
    SamePad,
    TransposeLast,
    ConvFeatureExtractionModel,
)
from funasr_detach.models.emotion2vec.modules import Modality, BlockEncoder, Decoder1d
from funasr_detach.models.emotion2vec.base import (
    ModalitySpecificEncoder,
    get_alibi_bias,
)


class AudioEncoder(ModalitySpecificEncoder):
    def __init__(
        self,
        modality_cfg,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
    ):
        # feature_encoder_spec is a string such as "[(512, 10, 5), ...]" of
        # (dim, kernel, stride) tuples; eval parses it into a list.
        self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
        feature_embed_dim = self.feature_enc_layers[-1][0]

        # Strided 1-D conv stack that turns the raw waveform into
        # frame-level features.
        local_encoder = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=modality_cfg.extractor_mode,
            conv_bias=False,
        )

        # Project the conv features (channels-first) to the transformer width.
        project_features = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(feature_embed_dim),
            nn.Linear(feature_embed_dim, embed_dim),
        )

        # Relative positional encoding: a stack of grouped 1-D convolutions,
        # each followed by LayerNorm and GELU.
        num_pos_layers = modality_cfg.conv_pos_depth
        k = max(3, modality_cfg.conv_pos_width // num_pos_layers)

        positional_encoder = nn.Sequential(
            TransposeLast(),
            *[
                nn.Sequential(
                    nn.Conv1d(
                        embed_dim,
                        embed_dim,
                        kernel_size=k,
                        padding=k // 2,
                        groups=modality_cfg.conv_pos_groups,
                    ),
                    SamePad(k),  # trims the extra frame when k is even
                    TransposeLast(),
                    LayerNorm(embed_dim, elementwise_affine=False),
                    TransposeLast(),
                    nn.GELU(),
                )
                for _ in range(num_pos_layers)
            ],
            TransposeLast(),
        )

        if modality_cfg.conv_pos_pre_ln:
            positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)

        # Per-block drop-path rates, linearly interpolated across the prenet.
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
        )

        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=None,
            relative_positional_encoder=positional_encoder,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )
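
    # NOTE: a shape sketch of the ModalitySpecificEncoder pipeline (see
    # base.py) that the modules built above plug into, assuming a batch-first
    # raw-waveform input:
    #   (B, T) waveform -> local_encoder -> (B, C, T') conv features
    #   -> project_features -> (B, T', embed_dim)
    #   -> x + relative_positional_encoder(x)
    #   -> context_encoder prenet blocks, with attention biases supplied by
    #      get_alibi_bias when ALiBi is enabled.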

    def convert_padding_mask(self, x, padding_mask):
        def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            """
            Computes the output length of the convolutional layers.
            """

            def _conv_out_length(input_length, kernel_size, stride):
                # standard conv output-length formula: floor((L - k) / s + 1)
                return torch.floor((input_length - kernel_size) / stride + 1)

            for i in range(len(self.feature_enc_layers)):
                input_lengths = _conv_out_length(
                    input_lengths,
                    self.feature_enc_layers[i][1],  # kernel size
                    self.feature_enc_layers[i][2],  # stride
                )

            return input_lengths.to(torch.long)

        if padding_mask is not None:
            # count non-padded samples per utterance
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply the conv formula to get the real output lengths
            output_lengths = get_feat_extract_output_lengths(input_lengths)

            if padding_mask.any():
                padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)

                # these two operations make sure that all values
                # before the output length indices are attended to
                padding_mask[
                    (
                        torch.arange(padding_mask.shape[0], device=padding_mask.device),
                        output_lengths - 1,
                    )
                ] = 1
                padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
            else:
                padding_mask = torch.zeros(x.shape[:2], dtype=torch.bool, device=x.device)

        return padding_mask
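
    # Worked example of the flip/cumsum/flip trick above, with hypothetical
    # numbers: for 5 output frames and output_lengths = [3], the one-hot write
    # produces [0, 0, 1, 0, 0]; flip -> cumsum -> flip gives [1, 1, 1, 0, 0];
    # and 1 minus that, cast to bool, is [False, False, False, True, True],
    # i.e. frames 0..2 stay attended and True marks padding.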

    def reset_parameters(self):
        super().reset_parameters()
        for mod in self.project_features.children():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()
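

# A minimal, self-contained sketch of the downsampling arithmetic that
# get_feat_extract_output_lengths implements. The spec below is a hypothetical
# example in the same (dim, kernel, stride) layout as feature_encoder_spec,
# using the common wav2vec 2.0-style stack; the real spec comes from
# modality_cfg.feature_encoder_spec.
if __name__ == "__main__":
    spec = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    lengths = torch.tensor([16000])  # one second of 16 kHz audio
    for _dim, kernel, stride in spec:
        # floor((L - kernel) / stride + 1), applied layer by layer
        lengths = torch.div(lengths - kernel, stride, rounding_mode="floor") + 1
    print(lengths)  # tensor([49]): roughly one frame per 20 ms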