from transformers import BertTokenizer, BertModel
from transformers import PretrainedConfig, PreTrainedModel
import torch
import torch.nn as nn

class TypeBERTConfig(PretrainedConfig):
    model_type = "type_bert"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.id2label = {
            0: "agent",
            1: "event",
            2: "place",
            3: "item",
            4: "virtual",
            5: "concept"
        }
        self.label2id = {
            "agent": 0,
            "event": 1,
            "place": 2,
            "item": 3,
            "virtual": 4,
            "concept": 5
        }
        self.architectures = ['TypeBERTForSequenceClassification']
        # tokenizer_class expects a tokenizer class name, not a checkpoint id
        self.tokenizer_class = 'BertTokenizer'


class TypeBERTForSequenceClassification(PreTrainedModel):
    config_class = TypeBERTConfig

    def __init__(self, config):
        super(TypeBERTForSequenceClassification, self).__init__(config)
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        # Uncomment to freeze the BERT encoder and train only the head:
        # for param in self.bert.base_model.parameters():
        #     param.requires_grad = False
        #
        # self.bert.eval()
        self.tanh = nn.Tanh()
        # Classification head: 768-d pooled BERT embedding -> 6 type classes.
        # LogSoftmax means the returned "logits" are log-probabilities.
        self.dff = nn.Sequential(
            nn.Linear(768, 2048),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 6),
            nn.LogSoftmax(dim=1)
        )
        self.eval()

    def forward(self, **kwargs):
        # Mean-pool the last hidden states over non-padding tokens,
        # using the attention mask to zero out padding positions.
        a = kwargs['attention_mask']
        embs = self.bert(**kwargs)['last_hidden_state']
        embs = embs * a.unsqueeze(2)
        out = embs.sum(dim=1) / a.sum(dim=1, keepdim=True)
        return {'logits': self.dff(self.tanh(out))}
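

# A minimal usage sketch (not part of the original listing): it assumes the classes
# above are in scope and that the "bert-base-uncased" weights can be downloaded.
# The freshly initialized head is untrained, so the predictions here are arbitrary.
if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TypeBERTForSequenceClassification(TypeBERTConfig())

    # Padding produces the attention_mask that forward() uses for mean pooling.
    batch = tokenizer(["an example mention", "another example"],
                      padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        log_probs = model(**batch)['logits']

    # Map predicted class indices back to type names via the config.
    predictions = [model.config.id2label[i] for i in log_probs.argmax(dim=1).tolist()]
    print(predictions)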