import types
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

class WhisperWrappedEncoder:

    @classmethod
    def load(cls, model_config):

        def extract_variable_length_features(self, x: torch.Tensor):
            """
            x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
                the mel spectrogram of the audio
            """
            x = F.gelu(self.conv1(x))
            x = F.gelu(self.conv2(x))
            x = x.permute(0, 2, 1)

            # The stock Whisper encoder asserts a fixed 30 s input; slicing the
            # positional embedding instead lets shorter inputs through.
            # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
            # x = (x + self.positional_embedding).to(x.dtype)
            x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)

            for block in self.blocks:
                x = block(x)

            x = self.ln_post(x)
            return x

        import whisper
        encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder
        # Patch the variable-length forward onto the loaded encoder instance.
        encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder)
        return encoder

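# Usage sketch (hedged: the config field value and model name below are
# assumptions, not part of this module). The positional-embedding slice above
# is what lets the encoder accept mel inputs shorter than Whisper's 30 s window:
#
#   from types import SimpleNamespace
#   model_config = SimpleNamespace(encoder_path="base.en")  # hypothetical config
#   encoder = WhisperWrappedEncoder.load(model_config)
#   mel = torch.randn(1, 80, 1000)  # (batch, n_mels, n_frames), fewer than 3000 frames
#   feats = encoder.extract_variable_length_features(mel)  # (1, 500, n_state); conv2 has stride 2
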
class BEATsEncoder:

    @classmethod
    def load(cls, model_config):
        from .BEATs.BEATs import BEATs, BEATsConfig
        # BEATs checkpoints bundle both the model config and the weights.
        checkpoint = torch.load(model_config.encoder_path)
        cfg = BEATsConfig(checkpoint['cfg'])
        BEATs_model = BEATs(cfg)
        BEATs_model.load_state_dict(checkpoint['model'])
        return BEATs_model

@dataclass
class UserDirModule:
    # Mimics the argparse namespace fairseq.utils.import_user_module expects:
    # any object exposing a `user_dir` attribute.
    user_dir: str

class EATEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        # Register the EAT user directory so fairseq can resolve its custom task/model.
        model_path = UserDirModule(model_config.encoder_fairseq_dir)
        fairseq.utils.import_user_module(model_path)
        EATEncoder, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        EATEncoder = EATEncoder[0]
        return EATEncoder

    def extract_features(self, source, padding_mask):
        return self.model.extract_features(source, padding_mask=padding_mask, mask=False, remove_extra_tokens=False)['x']

class SpatialASTEncoder:

    @classmethod
    def load(cls, model_config):
        from functools import partial
        from .SpatialAST import SpatialAST
        # ViT-Base style backbone for binaural audio (patch 16, 12 layers, 12 heads).
        binaural_encoder = SpatialAST.BinauralEncoder(
            num_classes=355, drop_path_rate=0.1, num_cls_tokens=3,
            patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
        )

        checkpoint = torch.load(model_config.encoder_ckpt, map_location='cpu')
        binaural_encoder.load_state_dict(checkpoint['model'], strict=False)
        return binaural_encoder

class WavLMEncoder(nn.Module):
    def __init__(self, config, model):
        super().__init__()
        self.config = config
        self.model = model

    @classmethod
    def load(cls, model_config):
        from .wavlm.WavLM import WavLM, WavLMConfig
        checkpoint = torch.load(model_config.encoder_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        WavLM_model = WavLM(cfg)
        WavLM_model.load_state_dict(checkpoint['model'])
        assert model_config.normalize == cfg.normalize, "normalize flag in config and model checkpoint do not match"
        return cls(cfg, WavLM_model)

    def extract_features(self, source, padding_mask):
        return self.model.extract_features(source, padding_mask)[0]

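# Usage sketch (hedged: toy tensors, 16 kHz raw-waveform input assumed). The
# padding mask is True at padded sample positions:
#
#   wav = torch.randn(2, 16000)                     # two 1 s clips at 16 kHz
#   mask = torch.zeros(2, 16000, dtype=torch.bool)  # no padding in this toy batch
#   feats = wavlm_encoder.extract_features(wav, mask)
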
class AVHubertEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        # Importing the avhubert package registers its tasks/models with fairseq.
        from .avhubert import hubert_pretraining, hubert, hubert_asr
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        model = models[0]
        return model

class HubertEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        model = models[0]
        if model_config.encoder_type == "pretrain":
            pass
        elif model_config.encoder_type == "finetune":
            # Drop the CTC projection and disable masking so the encoder
            # returns hidden states rather than vocabulary logits.
            model.w2v_encoder.proj = None
            model.w2v_encoder.apply_mask = False
        else:
            assert model_config.encoder_type in ["pretrain", "finetune"], "encoder_type must be one of [pretrain, finetune]"
        return model

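# Usage sketch (hedged: the path is a placeholder; config fields mirror those
# read above). For a fine-tuned CTC checkpoint, load() strips the final
# projection so downstream code receives hidden states:
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(encoder_path="/path/to/hubert.pt", encoder_type="finetune")
#   hubert = HubertEncoder.load(cfg)
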
class HfTextEncoder:

    @classmethod
    def load(cls, model_config):
        from transformers import AutoModel
        model = AutoModel.from_pretrained(model_config.encoder_path)
        return model

class MusicFMEncoder(nn.Module):
    def __init__(self, config, model):
        super().__init__()
        self.config = config
        self.model = model

    @classmethod
    def load(cls, model_config):
        from .musicfm.model.musicfm_25hz import MusicFM25Hz
        model = MusicFM25Hz(
            stat_path=model_config.encoder_stat_path,
            model_path=model_config.encoder_path,
            w2v2_config_path=model_config.get('encoder_config_path', "facebook/wav2vec2-conformer-rope-large-960h-ft")
        )
        return cls(model_config, model)

    def extract_features(self, source, padding_mask=None):
        # get_predictions returns (logits, per-layer hidden states); keep the
        # hidden states of the configured layer.
        _, hidden_states = self.model.get_predictions(source)
        out = hidden_states[self.config.encoder_layer_idx]
        return out
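
# Usage sketch (hedged: load() needs a config supporting both attribute access
# and .get(), e.g. an OmegaConf DictConfig; the paths, layer index, and 24 kHz
# sample rate below are assumptions). encoder_layer_idx selects which hidden
# layer feeds the downstream model:
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create({
#       "encoder_stat_path": "/path/to/msd_stats.json",
#       "encoder_path": "/path/to/pretrained_msd.pt",
#       "encoder_layer_idx": 12,
#   })
#   music_encoder = MusicFMEncoder.load(cfg)
#   wav = torch.randn(1, 24000 * 10)  # 10 s of audio
#   feats = music_encoder.extract_features(wav)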