# File: DSTK/evaluation/patch_unispeech.py
# gooorillax, commit cd8454d: "first push of codes and models for g2p, t2u, tokenizer and detokenizer"
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from patch_utils import MindSpeedPatchesManager as aspm
import os
import torch
import torch.nn as nn
import logging
import torchaudio.transforms as trans
from s3prl.upstream.wavlm.expert import UpstreamExpert as s3prl_UpstreamExpert
from models.ecapa_tdnn import Conv1dReluBn, SE_Res2Block, AttentiveStatsPool
from models.ecapa_tdnn import ECAPA_TDNN_SMALL, ECAPA_TDNN
def init_model_patched(model_name, checkpoint=None):
    """Build an ECAPA-TDNN verification model and optionally load its weights.

    Registered by ``patch_for_npu`` as the replacement for
    ``verification.init_model``. Compared with a fixed table, the
    ``wavlm_large`` upstream checkpoint path is read from the ``S3PRL_PATH``
    environment variable. When that variable is unset or empty,
    ``config_path`` is passed as None.

    Args:
        model_name: one of 'unispeech_sat', 'wavlm_base_plus', 'wavlm_large',
            'hubert_large', 'wav2vec2_xlsr'. Any other value builds a
            40-dimensional 'fbank' front end.
        checkpoint: optional path passed to ``torch.load``. The loaded object
            must have a 'model' entry holding a state dict. It is loaded with
            ``strict=False``; missing/unexpected keys are logged, not raised.

    Returns:
        The model built by ``ECAPA_TDNN_SMALL``.
    """
    # An exported-but-blank S3PRL_PATH would otherwise be forwarded as the
    # path "" instead of meaning "not configured".
    s3prl_path = os.environ.get("S3PRL_PATH") or None

    # model_name -> (feat_dim, feat_type, config_path)
    known_models = {
        'unispeech_sat': (1024, 'unispeech_sat', 'config/unispeech_sat.th'),
        'wavlm_base_plus': (768, 'wavlm_base_plus', None),
        'wavlm_large': (1024, 'wavlm_large', s3prl_path),
        'hubert_large': (1024, 'hubert_large_ll60k', None),
        'wav2vec2_xlsr': (1024, 'wav2vec2_xlsr', None),
    }
    if model_name in known_models:
        feat_dim, feat_type, config_path = known_models[model_name]
        model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=config_path)
    else:
        # Unknown names fall back to a plain filterbank front end.
        model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')

    if checkpoint is not None:
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(checkpoint, map_location="cpu")
        incompatible = model.load_state_dict(state_dict['model'], strict=False)
        # Loading stays non-strict, but surface what was skipped so a
        # mismatched checkpoint does not go unnoticed.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            logging.getLogger(__name__).warning(
                "Checkpoint %s loaded into %s with %d missing and %d unexpected keys",
                checkpoint, model_name,
                len(incompatible.missing_keys), len(incompatible.unexpected_keys))
    return model
class patched_ECAPA_TDNN(ECAPA_TDNN):
    """ECAPA-TDNN whose constructor builds every submodule itself.

    Front ends:
    - 'fbank' uses ``torchaudio.transforms.MelSpectrogram``.
    - 'mfcc' uses ``torchaudio.transforms.MFCC``.
    - Any other ``feat_type`` is treated as an s3prl upstream:
      - with no ``config_path``, it is loaded with
        ``torch.hub.load('s3prl/s3prl', feat_type)``;
      - otherwise it is built with ``s3prl.upstream.wavlm.expert.UpstreamExpert``.
    """

    def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
                 feat_type='fbank', sr=16000, feature_selection="hidden_states",
                 update_extract=False, config_path=None):
        """Create the front end, the frame-level blocks, pooling and embedding head.

        Args:
            feat_dim: feature width produced by the front end. It is used as
                ``n_mels`` / ``n_mfcc`` for spectral front ends and as the input
                width of ``instance_norm`` and ``layer1``.
            channels: width of ``layer1``..``layer4``.
            emb_dim: output size of the final ``linear`` layer.
            global_context_att: forwarded to ``AttentiveStatsPool``.
            feat_type: 'fbank', 'mfcc', or an s3prl upstream name.
            sr: sample rate; the spectral window/hop are 25 ms / 10 ms of it.
            feature_selection: stored on the instance (presumably read by the
                inherited ``get_feat_num``/``forward`` -- not visible here).
            update_extract: when False, every ``feature_extract`` parameter is
                frozen. It is forced to False for spectral front ends.
            config_path: path handed to the s3prl ``UpstreamExpert``. None
                selects the ``torch.hub`` loader instead.
        """
        # Skip ECAPA_TDNN.__init__ on purpose: run the initializer of the next
        # class in the MRO (presumably nn.Module -- confirm) and build all
        # layers below instead.
        super(ECAPA_TDNN, self).__init__()
        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        if feat_type in ("fbank", "mfcc"):
            self.update_extract = False
            win_len = int(sr * 0.025)  # 25 ms analysis window
            hop_len = int(sr * 0.01)   # 10 ms hop
            if feat_type == 'fbank':
                self.feature_extract = trans.MelSpectrogram(
                    sample_rate=sr, n_fft=512, win_length=win_len, hop_length=hop_len,
                    f_min=0.0, f_max=sr // 2, pad=0, n_mels=feat_dim)
            else:  # 'mfcc'
                mel_opts = dict(n_fft=512, win_length=win_len, hop_length=hop_len,
                                f_min=0.0, f_max=sr // 2, pad=0)
                self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim,
                                                  log_mels=False, melkwargs=mel_opts)
        else:
            if config_path is not None:
                self.feature_extract = s3prl_UpstreamExpert(config_path)
            else:
                self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)

            # For 24-layer encoders, turn fp32_attention off on layers 23 and 11
            # (in that order) wherever the attribute exists.
            encoder_layers = self.feature_extract.model.encoder.layers
            if len(encoder_layers) == 24:
                for layer_idx in (23, 11):
                    attn = encoder_layers[layer_idx].self_attn
                    if hasattr(attn, "fp32_attention"):
                        attn.fp32_attention = False

            # One learnable mixing weight per selected upstream feature.
            self.feat_num = self.get_feat_num()
            self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

            # Always freeze upstream parameters whose names contain these tags.
            frozen_tags = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
            for param_name, param in self.feature_extract.named_parameters():
                if any(tag in param_name for tag in frozen_tags):
                    param.requires_grad = False

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # The last width is pinned to 1536 rather than derived from `channels`.
        self.channels = [channels] * 4 + [1536]

        def se_res2(in_ch, out_ch, dilation):
            # Dilated SE-Res2 block; padding equals dilation for kernel_size=3.
            return SE_Res2Block(in_ch, out_ch, kernel_size=3, stride=1, padding=dilation,
                                dilation=dilation, scale=8, se_bottleneck_dim=128)

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = se_res2(self.channels[0], self.channels[1], 2)
        self.layer3 = se_res2(self.channels[1], self.channels[2], 3)
        self.layer4 = se_res2(self.channels[2], self.channels[3], 4)
        # Input width is channels * 3 (presumably the concatenated outputs of
        # layer2..layer4 in forward(), which is not visible here).
        self.conv = nn.Conv1d(channels * 3, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128,
                                          global_context_att=global_context_att)
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
def patched_ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000,
                             feature_selection="hidden_states", update_extract=False,
                             config_path=None):
    """Build a 512-channel ``patched_ECAPA_TDNN``.

    ``patch_for_npu`` registers this function in place of
    ``models.ecapa_tdnn.ECAPA_TDNN_SMALL``. All arguments are forwarded
    unchanged. ``global_context_att`` is left at the constructor default.
    """
    small_width = 512
    return patched_ECAPA_TDNN(
        feat_dim=feat_dim,
        channels=small_width,
        emb_dim=emb_dim,
        feat_type=feat_type,
        sr=sr,
        feature_selection=feature_selection,
        update_extract=update_extract,
        config_path=config_path,
    )
def patch_for_npu():
    """Register this module's replacements with the patch manager and apply them.

    - ``models.ecapa_tdnn.ECAPA_TDNN_SMALL`` -> ``patched_ECAPA_TDNN_SMALL``
    - ``verification.init_model``             -> ``init_model_patched``
    """
    replacements = (
        ('models.ecapa_tdnn.ECAPA_TDNN_SMALL', patched_ECAPA_TDNN_SMALL),
        ('verification.init_model', init_model_patched),
    )
    for target_path, replacement in replacements:
        aspm.register_patch(target_path, replacement)
    aspm.apply_patches()