|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from patch_utils import MindSpeedPatchesManager as aspm |
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import logging |
|
|
import torchaudio.transforms as trans |
|
|
from s3prl.upstream.wavlm.expert import UpstreamExpert as s3prl_UpstreamExpert |
|
|
from models.ecapa_tdnn import Conv1dReluBn, SE_Res2Block, AttentiveStatsPool |
|
|
from models.ecapa_tdnn import ECAPA_TDNN_SMALL, ECAPA_TDNN |
|
|
|
|
|
def init_model_patched(model_name, checkpoint=None):
    """Build an ECAPA-TDNN speaker-verification model, optionally loading weights.

    Registered by ``patch_for_npu`` as the replacement for
    ``verification.init_model``.

    Args:
        model_name: Selects the front-end. Recognised values are
            ``'unispeech_sat'``, ``'wavlm_base_plus'``, ``'wavlm_large'``,
            ``'hubert_large'`` and ``'wav2vec2_xlsr'``. Any other value builds a
            40-dim ``'fbank'`` front-end.
        checkpoint: Optional path to a saved checkpoint. Its ``'model'`` entry is
            loaded with ``strict=False``.

    Returns:
        The model returned by ``ECAPA_TDNN_SMALL``.
    """
    # model_name -> (feat_dim, feat_type, config_path)
    # For 'wavlm_large' the config path is taken from the S3PRL_PATH environment
    # variable. When that variable is unset it is None, so the builder takes its
    # config_path-is-None route.
    ssl_frontends = {
        'unispeech_sat': (1024, 'unispeech_sat', 'config/unispeech_sat.th'),
        'wavlm_base_plus': (768, 'wavlm_base_plus', None),
        'wavlm_large': (1024, 'wavlm_large', os.environ.get("S3PRL_PATH")),
        'hubert_large': (1024, 'hubert_large_ll60k', None),
        'wav2vec2_xlsr': (1024, 'wav2vec2_xlsr', None),
    }

    # First entry whose name compares equal, in declaration order.
    frontend = next(
        (spec for name, spec in ssl_frontends.items() if model_name == name),
        None,
    )

    if frontend is None:
        model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
    else:
        feat_dim, feat_type, config_path = frontend
        model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=config_path)

    if checkpoint is not None:
        # The identity map_location keeps tensors on the storage they were
        # deserialised to (CPU), regardless of the device they were saved from.
        # NOTE(review): torch.load unpickles the file; only pass trusted checkpoints.
        saved = torch.load(checkpoint, map_location=lambda storage, loc: storage)
        model.load_state_dict(saved['model'], strict=False)

    return model
|
|
|
|
|
|
|
|
class patched_ECAPA_TDNN(ECAPA_TDNN):
    """ECAPA-TDNN with a self-contained constructor used for NPU runs.

    The constructor builds every submodule itself. Front-end selection:

    * ``'fbank'`` uses a torchaudio MelSpectrogram.
    * ``'mfcc'`` uses a torchaudio MFCC.
    * Any other ``feat_type`` is an s3prl upstream. It is fetched through
      ``torch.hub`` when ``config_path`` is None, and otherwise loaded from
      ``config_path`` with s3prl's WavLM ``UpstreamExpert``.
    """

    def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
                 feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
        # Deliberately skip ECAPA_TDNN.__init__ and initialise the next class in
        # the MRO, so the parent's own construction logic does not run. This
        # constructor creates every submodule below itself.
        super(ECAPA_TDNN, self).__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        is_spectral = feat_type in ('fbank', 'mfcc')
        if is_spectral:
            # Spectral front-ends are never fine-tuned.
            self.update_extract = False

        # 25 ms analysis window and 10 ms hop, expressed in samples.
        win_len = int(sr * 0.025)
        hop_len = int(sr * 0.01)

        if feat_type == 'fbank':
            self.feature_extract = trans.MelSpectrogram(
                sample_rate=sr, n_fft=512, win_length=win_len, hop_length=hop_len,
                f_min=0.0, f_max=sr // 2, pad=0, n_mels=feat_dim)
        elif feat_type == 'mfcc':
            mel_options = dict(n_fft=512, win_length=win_len, hop_length=hop_len,
                               f_min=0.0, f_max=sr // 2, pad=0)
            self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
                                              melkwargs=mel_options)
        else:
            if config_path is not None:
                # Load the upstream from a local config/checkpoint path.
                self.feature_extract = s3prl_UpstreamExpert(config_path)
            else:
                self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)

            # On 24-layer encoders, switch off `fp32_attention` on layers 23 and
            # 11 (in that order) wherever the attention module exposes it.
            encoder_layers = self.feature_extract.model.encoder.layers
            if len(encoder_layers) == 24:
                for layer_idx in (23, 11):
                    attention = encoder_layers[layer_idx].self_attn
                    if hasattr(attention, "fp32_attention"):
                        attention.fp32_attention = False

            # get_feat_num is inherited (not visible here). feature_weight is a
            # zero-initialised learnable vector with one entry per upstream
            # feature; presumably used to mix upstream layers in forward. Confirm.
            self.feat_num = self.get_feat_num()
            self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if not is_spectral:
            # Freeze upstream parameters whose name contains any of these substrings.
            frozen_substrings = ('final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer')
            for param_name, param in self.feature_extract.named_parameters():
                if any(key in param_name for key in frozen_substrings):
                    param.requires_grad = False

        if not self.update_extract:
            # The whole front-end stays frozen unless update_extract was requested.
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)

        # Four blocks of width `channels`; the last entry (1536) is the width of
        # the 1x1 conv ahead of pooling.
        self.channels = [channels] * 4 + [1536]
        widths = self.channels

        # Submodules are created in the same order as before, so parameter
        # registration order and random initialisation are unchanged.
        res2_options = dict(kernel_size=3, stride=1, scale=8, se_bottleneck_dim=128)
        self.layer1 = Conv1dReluBn(feat_dim, widths[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(widths[0], widths[1], padding=2, dilation=2, **res2_options)
        self.layer3 = SE_Res2Block(widths[1], widths[2], padding=3, dilation=3, **res2_options)
        self.layer4 = SE_Res2Block(widths[2], widths[3], padding=4, dilation=4, **res2_options)

        # Input width is 3 * channels; presumably the concatenated outputs of
        # layer2..layer4 in forward. Confirm against the parent's forward.
        self.conv = nn.Conv1d(channels * 3, widths[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(widths[-1], attention_channels=128, global_context_att=global_context_att)
        self.bn = nn.BatchNorm1d(widths[-1] * 2)
        self.linear = nn.Linear(widths[-1] * 2, emb_dim)
|
|
|
|
|
|
|
|
def patched_ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
    """Build the 512-channel ("small") ECAPA-TDNN on top of ``patched_ECAPA_TDNN``.

    Registered by ``patch_for_npu`` as the replacement for
    ``models.ecapa_tdnn.ECAPA_TDNN_SMALL``.

    All arguments are forwarded unchanged. Two things differ from the class
    defaults:

    * ``emb_dim`` defaults to 256 here, whereas the class default is 192.
    * ``global_context_att`` is not exposed, so the class default (False) applies.
    """
    forwarded = {
        'feat_type': feat_type,
        'sr': sr,
        'feature_selection': feature_selection,
        'update_extract': update_extract,
        'config_path': config_path,
    }
    return patched_ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim, **forwarded)
|
|
|
|
|
def patch_for_npu():
    """Install this module's NPU replacements and apply them.

    Registers two replacements, in this order, then calls ``apply_patches``:

    * ``patched_ECAPA_TDNN_SMALL`` for ``models.ecapa_tdnn.ECAPA_TDNN_SMALL``
    * ``init_model_patched`` for ``verification.init_model``
    """
    replacements = (
        ('models.ecapa_tdnn.ECAPA_TDNN_SMALL', patched_ECAPA_TDNN_SMALL),
        ('verification.init_model', init_model_patched),
    )
    for target_path, replacement in replacements:
        aspm.register_patch(target_path, replacement)
    aspm.apply_patches()