import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
class MERTFeatureExtractor(nn.Module):
def __init__(self, freeze_feature_extractor=True):
super(MERTFeatureExtractor, self).__init__()
config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
        # Some builds of the MERT remote code expect `conv_pos_batch_norm` in the config;
        # default it to False when the attribute is missing so the checkpoint still loads.
        if not hasattr(config, "conv_pos_batch_norm"):
            setattr(config, "conv_pos_batch_norm", False)
self.mert = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True)
if freeze_feature_extractor:
self.freeze()
    def forward(self, input_values):
        # Input: [batch, time]
        # Extract hidden_states from the pretrained MERT (here the hidden states of all layers are used)
        with torch.no_grad():
            outputs = self.mert(input_values, output_hidden_states=True)
        # hidden_states: tuple of [batch, time, feature_dim]
        # Stack the hidden states of every layer, then average over the time axis to obtain the features
        hidden_states = torch.stack(outputs.hidden_states)  # [num_layers, batch, time, feature_dim]
        # Detach from the frozen backbone and re-enable gradients so downstream layers can train on these features
        hidden_states = hidden_states.detach().clone().requires_grad_(True)
        time_reduced = hidden_states.mean(dim=2)  # [num_layers, batch, feature_dim]
        time_reduced = time_reduced.permute(1, 0, 2)  # [batch, num_layers, feature_dim]
        return time_reduced
def freeze(self):
for param in self.mert.parameters():
param.requires_grad = False
def unfreeze(self):
for param in self.mert.parameters():
param.requires_grad = True
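

# Minimal usage sketch for MERTFeatureExtractor, assuming audio is prepared with the
# Wav2Vec2FeatureExtractor at 24 kHz as described on the MERT-v1-95M model card.
# The helper name, the dummy 5-second clip, and the printed shape are illustrative
# assumptions, not part of the model definition above.
def _demo_mert_features():
    from transformers import Wav2Vec2FeatureExtractor

    processor = Wav2Vec2FeatureExtractor.from_pretrained(
        "m-a-p/MERT-v1-95M", trust_remote_code=True
    )
    sample_rate = 24000  # MERT-v1-95M operates on 24 kHz audio
    waveform = torch.randn(sample_rate * 5)  # dummy 5-second mono clip

    inputs = processor(waveform.numpy(), sampling_rate=sample_rate, return_tensors="pt")
    extractor = MERTFeatureExtractor(freeze_feature_extractor=True)
    features = extractor(inputs["input_values"])  # [batch, num_layers, feature_dim]
    print(features.shape)  # expected: [1, 13, 768] for the 95M checkpoint
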
class CrossAttentionLayer(nn.Module):
def __init__(self, embed_dim, num_heads):
super(CrossAttentionLayer, self).__init__()
self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.layer_norm1 = nn.LayerNorm(embed_dim)
self.layer_norm2 = nn.LayerNorm(embed_dim)
self.feed_forward = nn.Sequential(
nn.Linear(embed_dim, embed_dim * 4),
nn.ReLU(),
nn.Linear(embed_dim * 4, embed_dim)
)
def forward(self, x, cross_input):
        # Perform attention between x (query) and cross_input (key/value)
attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input)
x = self.layer_norm1(x + attn_output)
ff_output = self.feed_forward(x)
x = self.layer_norm2(x + ff_output)
return x
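

# Minimal shape check for CrossAttentionLayer, provided as an illustrative sketch:
# it shows that the layer maps a [batch, seq_len, embed_dim] query and context back
# to the query's shape. The dummy sizes below are arbitrary assumptions.
def _demo_cross_attention():
    layer = CrossAttentionLayer(embed_dim=768, num_heads=8)
    query = torch.randn(2, 13, 768)    # e.g. 13 stacked per-layer MERT embeddings
    context = torch.randn(2, 13, 768)  # the sequence being attended to
    out = layer(query, context)
    print(out.shape)  # torch.Size([2, 13, 768])
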
class CCV(nn.Module):
def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
super(CCV, self).__init__()
        # MERT-based feature extractor (extracts meaningful features from the pretrained weights)
self.feature_extractor = MERTFeatureExtractor(freeze_feature_extractor=freeze_feature_extractor)
        # Stack of cross-attention layers
self.cross_attention_layers = nn.ModuleList([
CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
])
        # Transformer encoder (batch-first)
encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Classifier
self.classifier = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, num_classes)
)
def forward(self, input_values):
"""
input_values: Tensor [batch, time]
1. MERT๋ก๋ถํฐ feature ์ถ์ถ โ [batch, num_layers, feature_dim]
2. ์๋ฒ ๋ฉ ์ฐจ์ ๋ง์ถ๊ธฐ ์ํด transpose โ [batch, feature_dim, num_layers]
3. Cross-Attention ์ ์ฉ
4. Transformer Encoding ํ ํ๊ท ํ๋ง
5. ๋ถ๋ฅ๊ธฐ ํต๊ณผํ์ฌ ์ต์ข
์ถ๋ ฅ(logits) ๋ฐํ
"""
features = self.feature_extractor(input_values) # [batch, num_layers, feature_dim]
        # embed_dim is usually set equal to feature_dim (e.g. 768)
# features = features.permute(0, 2, 1) # [batch, embed_dim, num_layers]
        # Apply cross-attention (here the features attend to themselves, i.e. self-attention)
for layer in self.cross_attention_layers:
features = layer(features, features)
        # Average over the sequence axis (here the num_layers axis) before the transformer encoder
features = features.mean(dim=1).unsqueeze(1) # [batch, 1, embed_dim]
encoded = self.transformer(features) # [batch, 1, embed_dim]
encoded = encoded.mean(dim=1) # [batch, embed_dim]
output = self.classifier(encoded) # [batch, num_classes]
return output, encoded
def unfreeze_feature_extractor(self):
self.feature_extractor.unfreeze()
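

# Hedged end-to-end usage sketch for CCV. The batch size, clip length, and 24 kHz
# sample rate are assumptions for demonstration; in practice `input_values` would come
# from the MERT audio processor (see the _demo_mert_features sketch above) rather than
# random noise.
if __name__ == "__main__":
    model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2,
                freeze_feature_extractor=True)
    model.eval()
    dummy_audio = torch.randn(2, 24000 * 5)  # [batch, time]: two 5-second clips
    with torch.no_grad():
        logits, embedding = model(dummy_audio)
    print(logits.shape)     # torch.Size([2, 2])   -> class logits
    print(embedding.shape)  # torch.Size([2, 768]) -> pooled representation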