mugu5 commited on
Commit
cc7777e
·
verified ·
1 Parent(s): 215d271

Upload 2 files

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. All_Model.py +36 -0
  3. BERT_MODEL.pth +3 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ BERT_MODEL.pth filter=lfs diff=lfs merge=lfs -text
All_Model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import RobertaTokenizer, RobertaModel
2
+ from transformers import BertModel, BertTokenizer
3
+ from torch import nn
4
+
5
+ #=============
6
+ # RO-BERTA MODEL
7
+ #=============
8
+ RO_BERTA_MODEL_PATH = "./Models/roberta_model.pth"
9
+ roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
10
+ class RoBertaForMultiLabel(nn.Module):
11
+ def __init__(self, pretrained_model='roberta-base', num_labels=5):
12
+ super().__init__()
13
+ self.bert = RobertaModel.from_pretrained(pretrained_model)
14
+ self.dropout = nn.Dropout(0.3)
15
+ self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
16
+
17
+ def forward(self, input_ids, attention_mask):
18
+ pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
19
+ return self.classifier(self.dropout(pooled_output))
20
+
21
+
22
+ #================
23
+ # BERT MODEL
24
+ #================
25
+ BERT_MODEL_PATH = "./Models/BERT_MODEL.pth"
26
+ bert_tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')
27
+ class BertForMultiLabel(nn.Module):
28
+ def __init__(self, pretrained_model='bert-base-uncased', num_labels=5):
29
+ super().__init__()
30
+ self.bert = BertModel.from_pretrained(pretrained_model)
31
+ self.dropout = nn.Dropout(0.3)
32
+ self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
33
+
34
+ def forward(self, input_ids, attention_mask):
35
+ pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
36
+ return self.classifier(self.dropout(pooled_output))
BERT_MODEL.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57c9356e403f57cc872c6b1110530816b52b4ff08b95cf2cabad9b6d0d754ceb
3
+ size 438026970