File size: 878 Bytes
1be121d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# modeling_simple_classifier.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from transformers import PreTrainedModel

class SimpleClassifierConfig(PretrainedConfig):
    """Configuration for ``SimpleClassifier``.

    Subclasses ``PretrainedConfig`` because ``PreTrainedModel.__init__``
    rejects any config that is not a ``PretrainedConfig`` instance; this also
    gives the config the standard save/load (JSON) machinery.

    Args:
        input_dim: In-features of the model's first linear layer, i.e. the
            size of the input's last dimension. Must be set to build a model.
        num_classes: Number of output logits. Must be set to build a model.
        p_dropout: Dropout probability used after each hidden block.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "simple_classifier"

    def __init__(self, input_dim=None, num_classes=None, p_dropout=0.1, **kwargs):
        # Every argument has a default so the no-argument constructor
        # ``SimpleClassifierConfig()`` keeps working as it did before.
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.p_dropout = p_dropout
        # Unknown keys (e.g. values read back from a saved config) are handled
        # by the base class.
        super().__init__(**kwargs)

class SimpleClassifier(PreTrainedModel):
    """Two-hidden-layer MLP classifier built on ``PreTrainedModel``.

    Layout: ``input_dim -> 256 -> 128 -> num_classes``. Each hidden block is
    Linear -> LayerNorm -> GELU -> Dropout; the output layer is a plain
    Linear that produces unnormalized logits.

    Reads ``config.input_dim``, ``config.num_classes`` and
    ``config.p_dropout``.
    """

    config_class = SimpleClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        first_width, second_width = 256, 128
        # Attribute names are the state_dict keys, and creation order fixes
        # both parameter order and the sequence of random initializations, so
        # neither is changed.
        self.linear1 = nn.Linear(config.input_dim, first_width)
        self.ln1 = nn.LayerNorm(first_width)
        self.dropout = nn.Dropout(config.p_dropout)
        self.linear2 = nn.Linear(first_width, second_width)
        self.ln2 = nn.LayerNorm(second_width)
        self.linear_out = nn.Linear(second_width, config.num_classes)
        self.post_init()

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: Tensor whose last dimension equals ``config.input_dim``;
               leading dimensions are carried through by the linear layers.

        Returns:
            Tensor whose last dimension is ``config.num_classes``. No softmax
            is applied.
        """
        hidden_blocks = ((self.linear1, self.ln1), (self.linear2, self.ln2))
        hidden = x
        for project, normalize in hidden_blocks:
            # One Dropout module is shared by both hidden blocks.
            hidden = self.dropout(F.gelu(normalize(project(hidden))))
        return self.linear_out(hidden)