# DL_Project / All_Model.py
# Uploaded by mugu5 ("Upload 2 files"), commit cc7777e, 1.56 kB.
from transformers import RobertaTokenizer, RobertaModel
from transformers import BertModel, BertTokenizer
from torch import nn
# ---------------------------------------------------------------------------
# RoBERTa model
# ---------------------------------------------------------------------------

# Path to the saved RoBERTa model file (presumably a state_dict for
# RoBertaForMultiLabel that is loaded elsewhere -- confirm against callers).
# Being relative, it resolves against the current working directory, not
# against this file's location.
RO_BERTA_MODEL_PATH = "./Models/roberta_model.pth"

# Tokenizer for the "roberta-base" checkpoint, which is also the default
# backbone of RoBertaForMultiLabel. Created once, at import time.
roberta_tokenizer = RobertaTokenizer.from_pretrained(
    "roberta-base",
)
class RoBertaForMultiLabel(nn.Module):
    """RoBERTa encoder followed by dropout and a linear layer, one logit per label.

    ``forward`` returns raw, unnormalized logits; no sigmoid/softmax is applied
    here. Given the class name, a per-label sigmoid (or ``BCEWithLogitsLoss``
    during training) is presumably expected downstream -- confirm with the
    training code.

    Args:
        pretrained_model: Hugging Face model id or local path handed to
            ``RobertaModel.from_pretrained``.
        num_labels: Number of output logits (size of the classifier output).
        dropout: Dropout probability applied to the pooled representation
            before the classifier. Defaults to 0.3, the previously hard-coded
            value.
    """

    def __init__(self, pretrained_model='roberta-base', num_labels=5,
                 dropout=0.3):
        super().__init__()
        # NOTE: named ``bert`` even though it holds a RoBERTa encoder. Renaming
        # the attribute would change the state_dict keys ("bert.*") and break
        # loading weights saved from this class.
        self.bert = RobertaModel.from_pretrained(pretrained_model)
        # Dropout is only active in train() mode; call .eval() for inference.
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Compute classifier logits for a batch of token ids.

        Args:
            input_ids: Token ids, presumably shaped (batch, seq_len) as produced
                by the RoBERTa tokenizer -- confirm against callers.
            attention_mask: Optional mask of the same shape as ``input_ids``
                (1 = attend, 0 = padding). ``None`` lets the encoder attend to
                every position.

        Returns:
            Logits of shape (batch, num_labels).
        """
        # ``pooler_output`` is the first-token hidden state passed through the
        # encoder's dense + tanh pooler.
        # NOTE(review): transformers commonly reports the roberta-base pooler
        # weights as newly initialized, i.e. learned only during fine-tuning --
        # confirm. The pooler is kept here so saved checkpoints stay compatible.
        pooled_output = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        return self.classifier(self.dropout(pooled_output))
# ---------------------------------------------------------------------------
# BERT model
# ---------------------------------------------------------------------------

# Path to the saved BERT model file (presumably a state_dict for
# BertForMultiLabel that is loaded elsewhere -- confirm against callers).
# Being relative, it resolves against the current working directory.
BERT_MODEL_PATH = "./Models/BERT_MODEL.pth"

# Tokenizer for the 'bert-base-uncased' checkpoint, which is also the default
# backbone of BertForMultiLabel. Created once, at import time.
bert_tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
)
class BertForMultiLabel(nn.Module):
    """BERT encoder followed by dropout and a linear layer, one logit per label.

    ``forward`` returns raw, unnormalized logits; no sigmoid/softmax is applied
    here. Given the class name, a per-label sigmoid (or ``BCEWithLogitsLoss``
    during training) is presumably expected downstream -- confirm with the
    training code.

    Args:
        pretrained_model: Hugging Face model id or local path handed to
            ``BertModel.from_pretrained``.
        num_labels: Number of output logits (size of the classifier output).
        dropout: Dropout probability applied to the pooled representation
            before the classifier. Defaults to 0.3, the previously hard-coded
            value.
    """

    def __init__(self, pretrained_model='bert-base-uncased', num_labels=5,
                 dropout=0.3):
        super().__init__()
        # The attribute name ``bert`` determines the state_dict keys ("bert.*");
        # keep it so previously saved weights still load.
        self.bert = BertModel.from_pretrained(pretrained_model)
        # Dropout is only active in train() mode; call .eval() for inference.
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Compute classifier logits for a batch of token ids.

        Args:
            input_ids: Token ids, presumably shaped (batch, seq_len) as produced
                by the BERT tokenizer -- confirm against callers.
            attention_mask: Optional mask of the same shape as ``input_ids``
                (1 = attend, 0 = padding). ``None`` lets the encoder attend to
                every position.

        Returns:
            Logits of shape (batch, num_labels).
        """
        # token_type_ids is not passed, so the encoder treats every token as
        # segment 0 (single-segment input). ``pooler_output`` is the [CLS]
        # hidden state passed through BERT's dense + tanh pooler.
        pooled_output = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        return self.classifier(self.dropout(pooled_output))