dungeon29 committed
Commit 211baff · verified · 1 Parent(s): 835dd8a

Update model.py

Files changed (1)
  1. model.py +32 -23
model.py CHANGED
@@ -6,44 +6,53 @@ from transformers import AutoModel
 class DeBERTaLSTMClassifier(nn.Module):
     def __init__(self, hidden_dim=128, num_labels=2):
         super().__init__()
-
         self.deberta = AutoModel.from_pretrained("microsoft/deberta-base")
+
+        # Freeze DeBERTa
         for param in self.deberta.parameters():
-            param.requires_grad = False  # freeze DeBERTa (as we don't have enough resources, we will not train DeBERTa in this model)
-
+            param.requires_grad = False
+
         self.lstm = nn.LSTM(
             input_size=self.deberta.config.hidden_size,
             hidden_size=hidden_dim,
             batch_first=True,
             bidirectional=True
         )
-
-        self.fc = nn.Linear(hidden_dim * 2, num_labels)
 
-        # Attention layer to compute token importance
+        # Attention layer: converts each hidden state into an importance score
         self.attention = nn.Linear(hidden_dim * 2, 1)
+
+        self.fc = nn.Linear(hidden_dim * 2, num_labels)
 
     def forward(self, input_ids, attention_mask, return_attention=False):
+        # 1. DeBERTa
         with torch.no_grad():
             outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True)
-
-        lstm_out, _ = self.lstm(outputs.last_hidden_state)  # shape: [batch, seq_len, hidden*2]
 
+        # 2. LSTM
+        lstm_out, _ = self.lstm(outputs.last_hidden_state)  # [batch, seq_len, hidden*2]
+
+        # 3. Compute attention (always performed)
+        # Raw scores, before softmax
+        attn_scores = self.attention(lstm_out).squeeze(-1)  # [batch, seq_len]
+
+        # Standard masking: assign a very negative value (-inf) to padding positions before softmax
+        # so that padding positions get an attention weight of exactly 0
+        mask = attention_mask.float()
+        attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
+
+        # Softmax to obtain the weights
+        attn_weights = F.softmax(attn_scores, dim=-1)  # [batch, seq_len]
+
+        # Context vector (weighted sum)
+        # [batch, seq_len, 1] * [batch, seq_len, hidden*2] -> sum -> [batch, hidden*2]
+        context_vector = torch.sum(attn_weights.unsqueeze(-1) * lstm_out, dim=1)
+
+        # 4. Classification
+        logits = self.fc(context_vector)
+
+        # 5. Return according to the caller's request
         if return_attention:
-            # Compute attention weights for each token
-            attention_weights = self.attention(lstm_out)  # [batch, seq_len, 1]
-            attention_weights = F.softmax(attention_weights.squeeze(-1), dim=-1)  # [batch, seq_len]
-
-            # Apply attention mask
-            attention_weights = attention_weights * attention_mask.float()
-            attention_weights = attention_weights / (attention_weights.sum(dim=-1, keepdim=True) + 1e-8)
-
-            # Weighted sum of LSTM outputs
-            attended_output = torch.sum(lstm_out * attention_weights.unsqueeze(-1), dim=1)
-            logits = self.fc(attended_output)
-
-            return logits, attention_weights, outputs.attentions
+            return logits, attn_weights, outputs.attentions
         else:
-            final_hidden = lstm_out[:, -1, :]  # last token output
-            logits = self.fc(final_hidden)
             return logits
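
A minimal usage sketch of the updated class (not part of the commit): it assumes the matching "microsoft/deberta-base" tokenizer and that the class is importable from this repo's model.py; the batch contents below are illustrative only.

import torch
from transformers import AutoTokenizer

from model import DeBERTaLSTMClassifier  # hypothetical import path: this repo's model.py

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
model = DeBERTaLSTMClassifier(hidden_dim=128, num_labels=2)
model.eval()

batch = tokenizer(
    ["a short example", "another, slightly longer example"],
    padding=True, truncation=True, return_tensors="pt",
)

with torch.no_grad():
    # With return_attention=True the new forward() returns (logits, attn_weights, outputs.attentions)
    logits, attn_weights, deberta_attentions = model(
        batch["input_ids"], batch["attention_mask"], return_attention=True
    )

print(logits.shape)        # [batch, num_labels] -> here [2, 2]
print(attn_weights.shape)  # [batch, seq_len]; padded positions receive ~0 weight after the -1e9 masking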