import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
class DeBERTaLSTMClassifier(nn.Module):
def __init__(self, hidden_dim=128, num_labels=2):
super().__init__()
self.deberta = AutoModel.from_pretrained("microsoft/deberta-base")
# Đóng băng DeBERTa
for param in self.deberta.parameters():
param.requires_grad = False
self.lstm = nn.LSTM(
input_size=self.deberta.config.hidden_size,
hidden_size=hidden_dim,
batch_first=True,
bidirectional=True
)
# Lớp Attention: chuyển đổi hidden state thành điểm số quan trọng (score)
self.attention = nn.Linear(hidden_dim * 2, 1)
self.fc = nn.Linear(hidden_dim * 2, num_labels)
    def forward(self, input_ids, attention_mask, return_attention=False):
        # 1. DeBERTa (frozen, so no gradients are needed)
        with torch.no_grad():
            outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True)
        # 2. LSTM
        lstm_out, _ = self.lstm(outputs.last_hidden_state)  # [batch, seq_len, hidden*2]
        # 3. Compute attention (always performed)
        # Raw scores, before softmax
        attn_scores = self.attention(lstm_out).squeeze(-1)  # [batch, seq_len]
        # Standard masking: fill padding positions with a very large negative value (-1e9)
        # before softmax so their attention weights are effectively zero
        mask = attention_mask.float()
        attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        # Softmax to obtain the weights
        attn_weights = F.softmax(attn_scores, dim=-1)  # [batch, seq_len]
        # Context vector (weighted sum)
        # [batch, seq_len, 1] * [batch, seq_len, hidden*2] -> sum -> [batch, hidden*2]
        context_vector = torch.sum(attn_weights.unsqueeze(-1) * lstm_out, dim=1)
        # 4. Classification
        logits = self.fc(context_vector)
        # 5. Return what the caller asked for
        if return_attention:
            return logits, attn_weights, outputs.attentions
        else:
            return logits
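

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, illustrative run of
# the classifier. The tokenizer checkpoint matches the frozen backbone above;
# the example sentences and optimizer settings are assumptions for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
    model = DeBERTaLSTMClassifier(hidden_dim=128, num_labels=2)
    model.eval()

    # Tokenize a toy batch with padding so the attention mask is non-trivial
    batch = tokenizer(
        ["an example sentence", "another, slightly longer example sentence"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits, attn_weights, deberta_attentions = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            return_attention=True,
        )
    print(logits.shape)        # [batch, num_labels]
    print(attn_weights.shape)  # [batch, seq_len]; padding positions are ~0

    # Because DeBERTa is frozen, only the LSTM, attention and fc layers have
    # requires_grad=True, so an optimizer would typically be built over them:
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=1e-3
    )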