mugu5 commited on
Commit
7152b2e
·
verified ·
1 Parent(s): e6c25d5

Update All_Model.py

Browse files
Files changed (1) hide show
  1. All_Model.py +19 -36
All_Model.py CHANGED
@@ -1,36 +1,19 @@
1
- from transformers import RobertaTokenizer, RobertaModel
2
- from transformers import BertModel, BertTokenizer
3
- from torch import nn
4
-
5
- #=============
6
- # RO-BERTA MODEL
7
- #=============
8
- RO_BERTA_MODEL_PATH = "./Models/roberta_model.pth"
9
- roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
10
- class RoBertaForMultiLabel(nn.Module):
11
- def __init__(self, pretrained_model='roberta-base', num_labels=5):
12
- super().__init__()
13
- self.bert = RobertaModel.from_pretrained(pretrained_model)
14
- self.dropout = nn.Dropout(0.3)
15
- self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
16
-
17
- def forward(self, input_ids, attention_mask):
18
- pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
19
- return self.classifier(self.dropout(pooled_output))
20
-
21
-
22
- #================
23
- # BERT MODEL
24
- #================
25
- BERT_MODEL_PATH = "./Models/BERT_MODEL.pth"
26
- bert_tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')
27
- class BertForMultiLabel(nn.Module):
28
- def __init__(self, pretrained_model='bert-base-uncased', num_labels=5):
29
- super().__init__()
30
- self.bert = BertModel.from_pretrained(pretrained_model)
31
- self.dropout = nn.Dropout(0.3)
32
- self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
33
-
34
- def forward(self, input_ids, attention_mask):
35
- pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
36
- return self.classifier(self.dropout(pooled_output))
 
1
+ from transformers import RobertaTokenizer, RobertaModel
2
+ from transformers import BertModel, BertTokenizer
3
+ from torch import nn
4
+
5
#================
# BERT MODEL
#================
# Path to the fine-tuned multi-label BERT checkpoint.
# NOTE(review): relative path — resolved against the current working
# directory at load time; confirm callers run from the repo root.
BERT_MODEL_PATH = "./BERT_MODEL.pth"
# Tokenizer matching the 'bert-base-uncased' weights used by BertForMultiLabel
# below. Instantiated at import time (downloads/caches on first use).
bert_tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')
10
class BertForMultiLabel(nn.Module):
    """BERT encoder topped with a dropout + linear head.

    Produces raw (unnormalized) logits for multi-label classification;
    apply a sigmoid downstream to obtain per-label probabilities.

    Args:
        pretrained_model: Hugging Face model id for the BERT backbone.
        num_labels: number of output labels (size of the logit vector).
    """

    def __init__(self, pretrained_model='bert-base-uncased', num_labels=5):
        super().__init__()
        # Pretrained backbone; its pooler output is the sentence embedding.
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return logits of shape (batch, num_labels)."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(encoded.pooler_output)
        return self.classifier(pooled)