from transformers import RobertaTokenizer, RobertaModel
from transformers import BertModel, BertTokenizer
from torch import nn

# ================
# BERT MODEL
# ================
BERT_MODEL_PATH = "./BERT_MODEL.pth"
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


class BertForMultiLabel(nn.Module):
    """BERT encoder with a linear head for multi-label classification."""

    def __init__(self, pretrained_model='bert-base-uncased', num_labels=5):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.dropout = nn.Dropout(0.3)
        # One logit per label; apply a sigmoid (not a softmax) downstream,
        # since labels in a multi-label setup are not mutually exclusive.
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        # pooler_output is the [CLS] hidden state passed through BERT's
        # pooling dense layer + tanh, one vector per input sequence.
        pooled_output = self.bert(input_ids=input_ids,
                                  attention_mask=attention_mask).pooler_output
        return self.classifier(self.dropout(pooled_output))
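
# ----------------------------------------------------------------
# Usage sketch. Assumption: BERT_MODEL_PATH holds a state_dict saved
# via torch.save(model.state_dict(), BERT_MODEL_PATH); adjust the
# loading call if the checkpoint was saved differently. Shows loading
# the fine-tuned weights and scoring one sentence, with per-label
# sigmoid probabilities as fits the multi-label head above.
# ----------------------------------------------------------------
import torch

bert_model = BertForMultiLabel()
bert_model.load_state_dict(torch.load(BERT_MODEL_PATH, map_location='cpu'))
bert_model.eval()

encoded = bert_tokenizer("An example sentence to classify.",
                         return_tensors='pt',
                         padding=True,
                         truncation=True,
                         max_length=128)
with torch.no_grad():
    logits = bert_model(encoded['input_ids'], encoded['attention_mask'])
probs = torch.sigmoid(logits)  # shape: (1, num_labels)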