import torch
from torch import nn


class ClassifierHead(nn.Module):
    """Basically a fancy MLP: 3-layer classifier head with GELU, LayerNorm, and skip connections."""

    def __init__(self, hidden_size, num_labels, dropout_prob):
        super().__init__()
        # Layer 1
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)
        # Layer 2
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)
        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Layer 1
        identity1 = features
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        x = x + identity1  # skip connection
        # Layer 2
        identity2 = x
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)
        x = x + identity2  # skip connection
        # Output Layer
        logits = self.out_proj(x)
        return logits
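

# Hypothetical usage sketch (not part of the original module): it exercises
# ClassifierHead with assumed encoder-style dimensions (hidden_size=768,
# num_labels=2, dropout_prob=0.1) and a dummy batch of pooled features to
# confirm the expected output shape.
def _example_classifier_head():
    head = ClassifierHead(hidden_size=768, num_labels=2, dropout_prob=0.1)
    features = torch.randn(4, 768)  # (batch_size, hidden_size), e.g. pooled [CLS] states
    logits = head(features)         # (batch_size, num_labels)
    assert logits.shape == (4, 2)
    return logits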


class ConcatClassifierHead(nn.Module):
    """
    An enhanced classifier head designed for concatenated CLS + Mean Pooling input.
    Includes an initial projection layer before the standard enhanced block.
    """

    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        # Initial projection from concatenated size (2*hidden) down to hidden_size
        self.initial_projection = nn.Linear(input_size, hidden_size)
        self.initial_norm = nn.LayerNorm(hidden_size)  # Norm after projection
        self.initial_activation = nn.GELU()
        self.initial_dropout = nn.Dropout(dropout_prob)
        # Layer 1
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)
        # Layer 2
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)
        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Initial Projection Step
        x = self.initial_projection(features)
        x = self.initial_norm(x)
        x = self.initial_activation(x)
        x = self.initial_dropout(x)
        # x should now be of shape (batch_size, hidden_size)
        # Layer 1 + Skip
        identity1 = x  # Skip connection starts after initial projection
        x_res = self.norm1(x)
        x_res = self.dense1(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout1(x_res)
        x = identity1 + x_res  # skip connection
        # Layer 2 + Skip
        identity2 = x
        x_res = self.norm2(x)
        x_res = self.dense2(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout2(x_res)
        x = identity2 + x_res  # skip connection
        # Output Layer
        logits = self.out_proj(x)
        return logits
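

# Hypothetical usage sketch (not part of the original module): it assumes the
# concatenated input is a [CLS] embedding plus a masked mean pooling over the
# token embeddings, so input_size = 2 * hidden_size; all dimensions are example values.
def _example_concat_classifier_head():
    batch_size, seq_len, hidden_size = 4, 16, 768
    token_states = torch.randn(batch_size, seq_len, hidden_size)  # dummy encoder outputs
    attention_mask = torch.ones(batch_size, seq_len)              # dummy attention mask
    cls_embedding = token_states[:, 0]                            # (batch, hidden)
    mask = attention_mask.unsqueeze(-1)                           # (batch, seq, 1)
    mean_pooled = (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    features = torch.cat([cls_embedding, mean_pooled], dim=-1)    # (batch, 2 * hidden)
    head = ConcatClassifierHead(input_size=2 * hidden_size, hidden_size=hidden_size,
                                num_labels=3, dropout_prob=0.1)
    logits = head(features)                                       # (batch, num_labels)
    assert logits.shape == (batch_size, 3)
    return logits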


# ExpansionClassifierHead currently not used
class ExpansionClassifierHead(nn.Module):
    """
    A classifier head using FFN-style expansion (input -> 4*hidden -> hidden -> labels).
    Takes concatenated CLS + Mean Pooled features as input.
    """

    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        intermediate_size = hidden_size * 4  # FFN expansion factor
        # Layer 1 (Expansion)
        self.norm1 = nn.LayerNorm(input_size)
        self.dense1 = nn.Linear(input_size, intermediate_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)
        # Layer 2 (Projection back down)
        self.norm2 = nn.LayerNorm(intermediate_size)
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        # Activation and Dropout applied after projection
        self.dropout2 = nn.Dropout(dropout_prob)
        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Layer 1
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        # Layer 2
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)
        # Output Layer
        logits = self.out_proj(x)
        return logits
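

# Hypothetical shape check (not part of the original module): the class is
# currently unused, but per its docstring it takes concatenated CLS + mean
# pooled features (input_size = 2 * hidden_size) and expands to 4 * hidden_size
# internally; dimensions below are example values.
def _example_expansion_classifier_head():
    hidden_size = 768
    head = ExpansionClassifierHead(input_size=2 * hidden_size, hidden_size=hidden_size,
                                   num_labels=5, dropout_prob=0.1)
    features = torch.randn(4, 2 * hidden_size)  # (batch_size, 2 * hidden_size)
    logits = head(features)                     # (batch_size, num_labels)
    assert logits.shape == (4, 5)
    return logits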