# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NPU patches for the ECAPA-TDNN speaker-verification model.

Call :func:`patch_for_npu` to register and apply the replacements defined here
through ``MindSpeedPatchesManager``.
"""
from patch_utils import MindSpeedPatchesManager as aspm
import os
import torch
import torch.nn as nn
import logging
import torchaudio.transforms as trans
from s3prl.upstream.wavlm.expert import UpstreamExpert as s3prl_UpstreamExpert
from models.ecapa_tdnn import Conv1dReluBn, SE_Res2Block, AttentiveStatsPool
from models.ecapa_tdnn import ECAPA_TDNN_SMALL, ECAPA_TDNN


def init_model_patched(model_name, checkpoint=None):
    """Build a small ECAPA-TDNN model for ``model_name`` and optionally load weights.

    Replacement for ``verification.init_model`` (registered in :func:`patch_for_npu`).

    Args:
        model_name: One of ``'unispeech_sat'``, ``'wavlm_base_plus'``,
            ``'wavlm_large'``, ``'hubert_large'`` or ``'wav2vec2_xlsr'``, which
            select an upstream ``feat_type`` with a matching ``feat_dim``. Any
            other value falls back to a 40-dim ``'fbank'`` front-end.
        checkpoint: Optional path given to ``torch.load``. The loaded object's
            ``'model'`` entry is loaded into the model with ``strict=False``.

    Returns:
        The constructed model.

    Note:
        ``ECAPA_TDNN_SMALL`` is looked up in this module's globals at call time.
        NOTE(review): confirm that ``aspm.apply_patches`` rebinds that name to
        :func:`patched_ECAPA_TDNN_SMALL`; otherwise the unpatched factory is used.
    """
    # model_name -> (feat_dim, feat_type, config_path).
    # 'wavlm_large' takes its config path from the S3PRL_PATH environment
    # variable (None when unset).
    presets = {
        'unispeech_sat': (1024, 'unispeech_sat', 'config/unispeech_sat.th'),
        'wavlm_base_plus': (768, 'wavlm_base_plus', None),
        'wavlm_large': (1024, 'wavlm_large', os.environ.get("S3PRL_PATH")),
        'hubert_large': (1024, 'hubert_large_ll60k', None),
        'wav2vec2_xlsr': (1024, 'wav2vec2_xlsr', None),
    }
    if model_name in presets:
        feat_dim, feat_type, config_path = presets[model_name]
        model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=config_path)
    else:
        # Unknown names silently fall back to a filterbank front-end.
        model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')

    if checkpoint is not None:
        # SECURITY: torch.load unpickles the file, so only pass trusted checkpoints.
        # The map_location callable returns each storage unchanged, i.e. tensors
        # are loaded onto the CPU.
        state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
        # strict=False: missing / unexpected keys are tolerated rather than raised.
        model.load_state_dict(state_dict['model'], strict=False)
    return model


class patched_ECAPA_TDNN(ECAPA_TDNN):
    """ECAPA-TDNN whose constructor is re-implemented for the NPU port.

    Only ``__init__`` is overridden here; all other behaviour (``forward``,
    ``get_feat_num``, ...) is inherited from ``ECAPA_TDNN``.
    """

    def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
                 feat_type='fbank', sr=16000, feature_selection="hidden_states",
                 update_extract=False, config_path=None):
        """Build the front-end feature extractor and the TDNN backbone.

        Args:
            feat_dim: Channel count of the features entering the backbone.
                Also ``n_mels`` for ``'fbank'`` and ``n_mfcc`` for ``'mfcc'``.
            channels: Width of ``layer1`` .. ``layer4``.
            emb_dim: Output size of the final linear layer.
            global_context_att: Forwarded to ``AttentiveStatsPool``.
            feat_type: ``'fbank'``, ``'mfcc'``, or an upstream name. For an
                upstream name, the extractor comes from
                ``torch.hub.load('s3prl/s3prl', feat_type)``, or from
                ``s3prl_UpstreamExpert(config_path)`` when ``config_path`` is given.
            sr: Sample rate in Hz; sets a 25 ms window and 10 ms hop.
            feature_selection: Stored on the instance for the inherited code.
            update_extract: If False, every extractor parameter is frozen. It is
                forced to False for ``'fbank'`` / ``'mfcc'``.
            config_path: Optional upstream config, see ``feat_type``.
        """
        # Deliberately skip ECAPA_TDNN.__init__: super(ECAPA_TDNN, self) starts the
        # MRO lookup *after* ECAPA_TDNN, so only the next base initialiser runs and
        # every submodule is built below instead.
        super(ECAPA_TDNN, self).__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        if feat_type in ('fbank', 'mfcc'):
            # The filterbank / MFCC front-ends are never fine-tuned.
            self.update_extract = False

        # 25 ms analysis window and 10 ms hop, expressed in samples.
        win_len = int(sr * 0.025)
        hop_len = int(sr * 0.01)

        if feat_type == 'fbank':
            self.feature_extract = trans.MelSpectrogram(
                sample_rate=sr, n_fft=512, win_length=win_len, hop_length=hop_len,
                f_min=0.0, f_max=sr // 2, pad=0, n_mels=feat_dim)
        elif feat_type == 'mfcc':
            melkwargs = {
                'n_fft': 512,
                'win_length': win_len,
                'hop_length': hop_len,
                'f_min': 0.0,
                'f_max': sr // 2,
                'pad': 0,
            }
            self.feature_extract = trans.MFCC(
                sample_rate=sr, n_mfcc=feat_dim, log_mels=False, melkwargs=melkwargs)
        else:
            if config_path is None:
                self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
            else:
                self.feature_extract = s3prl_UpstreamExpert(config_path)

            # For 24-layer encoders, turn off `fp32_attention` on layers 23 and 11
            # wherever the attribute exists.
            # NOTE(review): the motivation is not visible in this file.
            layers = self.feature_extract.model.encoder.layers
            if len(layers) == 24:
                for idx in (23, 11):
                    if hasattr(layers[idx].self_attn, "fp32_attention"):
                        layers[idx].self_attn.fp32_attention = False

            # One learnable weight per upstream feature; the count comes from the
            # inherited get_feat_num().
            self.feat_num = self.get_feat_num()
            self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

            # These upstream parameters (matched by substring) are frozen
            # unconditionally, even when update_extract is True.
            freeze_list = ('final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer')
            for name, param in self.feature_extract.named_parameters():
                if any(key in name for key in freeze_list):
                    param.requires_grad = False

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # layer1-layer4 are `channels` wide; the aggregated width is fixed at 1536
        # (equal to channels * 3 for the channels=512 used by the SMALL factory).
        self.channels = [channels] * 4 + [1536]
        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        # kernel_size=3 with padding == dilation for dilations 2, 3 and 4.
        self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1,
                                   padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
        self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1,
                                   padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
        self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1,
                                   padding=4, dilation=4, scale=8, se_bottleneck_dim=128)

        # 1x1 conv taking three concatenated `channels`-wide maps (presumably the
        # layer2-layer4 outputs in the inherited forward -- confirm) down to 1536.
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128,
                                          global_context_att=global_context_att)
        # Sized for a pooled vector of 2 * 1536 (presumably concatenated statistics
        # -- confirm against AttentiveStatsPool).
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)


def patched_ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000,
                             feature_selection="hidden_states", update_extract=False,
                             config_path=None):
    """Replacement for ``models.ecapa_tdnn.ECAPA_TDNN_SMALL``.

    Builds a :class:`patched_ECAPA_TDNN` with ``channels=512``. Every other argument
    is forwarded unchanged; ``global_context_att`` keeps its default (False).
    """
    return patched_ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
                              feat_type=feat_type, sr=sr, feature_selection=feature_selection,
                              update_extract=update_extract, config_path=config_path)


def patch_for_npu():
    """Register this module's replacements with MindSpeedPatchesManager and apply them.

    - ``models.ecapa_tdnn.ECAPA_TDNN_SMALL`` -> :func:`patched_ECAPA_TDNN_SMALL`
    - ``verification.init_model`` -> :func:`init_model_patched`
    """
    aspm.register_patch('models.ecapa_tdnn.ECAPA_TDNN_SMALL', patched_ECAPA_TDNN_SMALL)
    aspm.register_patch('verification.init_model', init_model_patched)
    aspm.apply_patches()