import torch.nn as nn
from transformers import BertConfig, BertModel


class EBertConfig(BertConfig):
    """BERT config extended with an optional bottleneck adapter size."""

    model_type = "ebert"

    def __init__(self, adapter_size=None, **kwargs):
        super().__init__(**kwargs)
        # Expose adapter_size as an explicit argument so it is stored on the
        # config and round-trips through save_pretrained / from_pretrained.
        self.adapter_size = adapter_size


class EBertModel(BertModel):
    config_class = EBertConfig

    def __init__(self, config: EBertConfig):
        super().__init__(config)
        if config.adapter_size:
            # One bottleneck adapter (down-project -> ReLU -> up-project) per
            # encoder layer; each maps hidden_size -> adapter_size -> hidden_size.
            self.adapters = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(config.hidden_size, config.adapter_size),
                    nn.ReLU(),
                    nn.Linear(config.adapter_size, config.hidden_size),
                )
                for _ in range(config.num_hidden_layers)
            ])
        else:
            self.adapters = None

    def forward(self, *args, **kwargs):
        outputs = super().forward(*args, **kwargs)
        sequence_output = outputs.last_hidden_state

        if self.adapters is not None:
            # Apply each adapter as a residual refinement of the final hidden states.
            for adapter in self.adapters:
                sequence_output = sequence_output + adapter(sequence_output)

        # Rebuild the same output class as the base model, swapping in the
        # adapted hidden states.
        return outputs.__class__(
            last_hidden_state=sequence_output,
            pooler_output=outputs.pooler_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
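

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module above): load a stock BERT
# checkpoint into EBertModel with adapters enabled. The checkpoint name and
# adapter_size=64 are illustrative assumptions; the adapter weights are
# freshly initialized and would still need fine-tuning.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    config = EBertConfig.from_pretrained("bert-base-uncased", adapter_size=64)
    model = EBertModel.from_pretrained("bert-base-uncased", config=config)

    inputs = tokenizer(
        "Adapters add a small residual bottleneck on top of BERT.",
        return_tensors="pt",
    )
    outputs = model(**inputs)
    print(outputs.last_hidden_state.shape)  # torch.Size([1, seq_len, hidden_size])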