import logging

import torch.nn as nn
from transformers import HubertModel, Wav2Vec2FeatureExtractor

# Silence numba's INFO/DEBUG logging.
logging.getLogger("numba").setLevel(logging.WARNING)


class CNHubert(nn.Module):
    """Pretrained (Chinese) HuBERT encoder paired with its Wav2Vec2 feature extractor."""

    def __init__(self, cnhubert_base_path):
        super().__init__()
        self.model = HubertModel.from_pretrained(cnhubert_base_path)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cnhubert_base_path)

    def forward(self, x):
        # Run the feature extractor on the raw 16 kHz waveform, then return
        # HuBERT's last hidden states as the feature sequence.
        input_values = self.feature_extractor(
            x, return_tensors="pt", sampling_rate=16000
        ).input_values.to(x.device)
        feats = self.model(input_values)["last_hidden_state"]
        return feats


def get_model(cnhubert_base_path):
    model = CNHubert(cnhubert_base_path)
    model.eval()
    return model
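

# Minimal usage sketch (illustrative only): load the wrapper and pull HuBERT features
# from one second of dummy 16 kHz audio. The checkpoint path below is a hypothetical
# placeholder, not something this module ships with.
if __name__ == "__main__":
    import torch

    model = get_model("path/to/chinese-hubert-base")  # hypothetical local checkpoint path
    wav = torch.randn(16000)  # dummy mono waveform sampled at 16 kHz
    with torch.no_grad():
        feats = model(wav)
    print(feats.shape)  # (batch, frames, hidden_size)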