File size: 2,289 Bytes
adbfcbd
 
 
 
 
 
 
 
 
211baff
 
adbfcbd
211baff
 
adbfcbd
 
 
 
 
 
 
211baff
adbfcbd
211baff
 
adbfcbd
 
211baff
adbfcbd
 
 
211baff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adbfcbd
211baff
adbfcbd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

class DeBERTaLSTMClassifier(nn.Module):
    def __init__(self, hidden_dim=128, num_labels=2):
        super().__init__()
        self.deberta = AutoModel.from_pretrained("microsoft/deberta-base")
        
        # Đóng băng DeBERTa
        for param in self.deberta.parameters():
            param.requires_grad = False
            
        self.lstm = nn.LSTM(
            input_size=self.deberta.config.hidden_size,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True
        )
        
        # Lớp Attention: chuyển đổi hidden state thành điểm số quan trọng (score)
        self.attention = nn.Linear(hidden_dim * 2, 1)
        
        self.fc = nn.Linear(hidden_dim * 2, num_labels)

    def forward(self, input_ids, attention_mask, return_attention=False):
        # 1. DeBERTa
        with torch.no_grad():
            outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True)
        
        # 2. LSTM
        lstm_out, _ = self.lstm(outputs.last_hidden_state)  # [batch, seq_len, hidden*2]
        
        # 3. Tính Attention (Luôn luôn thực hiện)
        # Tính score chưa qua softmax
        attn_scores = self.attention(lstm_out).squeeze(-1)  # [batch, seq_len]
        
        # Masking chuẩn: Gán giá trị rất nhỏ (-inf) cho các vị trí padding trước khi Softmax
        # Để đảm bảo padding có attention weight = 0 tuyệt đối
        mask = attention_mask.float()
        attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        # Softmax để ra weights
        attn_weights = F.softmax(attn_scores, dim=-1)  # [batch, seq_len]
        
        # Tính Context Vector (Weighted Sum)
        # [batch, seq_len, 1] * [batch, seq_len, hidden*2] -> sum -> [batch, hidden*2]
        context_vector = torch.sum(attn_weights.unsqueeze(-1) * lstm_out, dim=1)
        
        # 4. Classification
        logits = self.fc(context_vector)
        
        # 5. Return tùy theo yêu cầu
        if return_attention:
            return logits, attn_weights, outputs.attentions
        else:
            return logits