Neweret commited on
Commit
1be121d
·
verified ·
1 Parent(s): dd06c8e

Create modeling_simple_classifier

Browse files
Files changed (1) hide show
  1. modeling_simple_classifier +30 -0
modeling_simple_classifier ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modeling_simple_classifier.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import PreTrainedModel
6
+
7
+
8
+ class SimpleClassifierConfig:
9
+ model_type = "simple_classifier"
10
+
11
+
12
+ class SimpleClassifier(PreTrainedModel):
13
+ config_class = SimpleClassifierConfig
14
+
15
+ def __init__(self, config):
16
+ super().__init__(config)
17
+ self.linear1 = nn.Linear(config.input_dim, 256)
18
+ self.ln1 = nn.LayerNorm(256)
19
+ self.dropout = nn.Dropout(config.p_dropout)
20
+ self.linear2 = nn.Linear(256, 128)
21
+ self.ln2 = nn.LayerNorm(128)
22
+ self.linear_out = nn.Linear(128, config.num_classes)
23
+ self.post_init()
24
+
25
+ def forward(self, x):
26
+ x = F.gelu(self.ln1(self.linear1(x)))
27
+ x = self.dropout(x)
28
+ x = F.gelu(self.ln2(self.linear2(x)))
29
+ x = self.dropout(x)
30
+ return self.linear_out(x)