# Spaces: Running on Zero
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel, AutoConfig | |
class MERTFeatureExtractor(nn.Module):
    """Layer-wise feature extractor built on the pretrained ``m-a-p/MERT-v1-95M`` model.

    The wrapped model is run with ``output_hidden_states=True``; the hidden
    state of every layer is averaged over the time axis and the per-layer
    vectors are returned together.

    Args:
        freeze_feature_extractor: When true (default), all backbone
            parameters are frozen right after loading.
    """

    def __init__(self, freeze_feature_extractor=True):
        super().__init__()
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # with the model repository on the Hub; only use with a trusted repo.
        config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
        # Presumably the model code reads this attribute and some config
        # versions do not define it -- TODO confirm. Default it to False.
        if not hasattr(config, "conv_pos_batch_norm"):
            config.conv_pos_batch_norm = False
        self.mert = AutoModel.from_pretrained(
            "m-a-p/MERT-v1-95M", config=config, trust_remote_code=True
        )
        if freeze_feature_extractor:
            self.freeze()

    def forward(self, input_values):
        """Return time-averaged hidden states of every backbone layer.

        Args:
            input_values: Input passed straight to the backbone; the original
                comments describe it as ``[batch, time]``.

        Returns:
            Tensor laid out as ``[batch, num_layers, feature_dim]``.
        """
        # Only build an autograd graph through the backbone when at least one
        # of its parameters is trainable AND the caller has gradients enabled.
        # Previously the backbone always ran under no_grad, so unfreeze() could
        # never actually train it.
        track_grad = torch.is_grad_enabled() and any(
            p.requires_grad for p in self.mert.parameters()
        )
        with torch.set_grad_enabled(track_grad):
            outputs = self.mert(input_values, output_hidden_states=True)

        # Tuple of per-layer states -> [num_layers, batch, time, feature_dim].
        hidden_states = torch.stack(outputs.hidden_states)
        if not track_grad:
            # Frozen path: keep the historical behaviour of returning features
            # derived from a fresh leaf that requires grad. torch.stack already
            # allocated a new tensor, so no extra clone is needed.
            hidden_states = hidden_states.detach().requires_grad_(True)

        time_reduced = hidden_states.mean(dim=2)  # [num_layers, batch, feature_dim]
        return time_reduced.permute(1, 0, 2)  # [batch, num_layers, feature_dim]

    def _set_trainable(self, trainable):
        """Set ``requires_grad`` on every backbone parameter."""
        for param in self.mert.parameters():
            param.requires_grad = trainable

    def freeze(self):
        """Stop gradient updates for all backbone parameters."""
        self._set_trainable(False)

    def unfreeze(self):
        """Allow gradient updates for all backbone parameters."""
        self._set_trainable(True)
class CrossAttentionLayer(nn.Module):
    """Post-norm attention block: multi-head attention followed by an MLP.

    Both sub-layers use a residual connection followed by ``LayerNorm``.

    Args:
        embed_dim: Width of the query/key/value embeddings.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        hidden_dim = embed_dim * 4  # feed-forward expansion width
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True
        )
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x, cross_input):
        """Attend from ``x`` (queries) to ``cross_input`` (keys and values).

        Returns a tensor with the same shape as ``x``.
        """
        # Attention sub-layer; the attention weights are not used.
        attended = self.multihead_attn(query=x, key=cross_input, value=cross_input)[0]
        normed = self.layer_norm1(x + attended)
        # Feed-forward sub-layer with its own residual + norm.
        return self.layer_norm2(normed + self.feed_forward(normed))
class CCV(nn.Module):
    """Classifier on top of MERT layer-wise features.

    Pipeline: MERT feature extractor -> stack of attention layers applied
    across the backbone-layer axis -> Transformer encoder -> MLP classifier.

    Args:
        embed_dim: Embedding width used by every sub-module; assumed to equal
            the backbone's feature dimension (768 by default) -- TODO confirm.
        num_heads: Attention heads for the attention layers and the encoder.
        num_layers: Depth of both the attention stack and the encoder.
        num_classes: Number of output logits.
        freeze_feature_extractor: Forwarded to ``MERTFeatureExtractor``.
    """

    def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
        super().__init__()

        # Pretrained MERT backbone that supplies the input features.
        self.feature_extractor = MERTFeatureExtractor(
            freeze_feature_extractor=freeze_feature_extractor
        )

        # Stack of attention blocks.
        attention_stack = [
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ]
        self.cross_attention_layers = nn.ModuleList(attention_stack)

        # Batch-first Transformer encoder.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification head.
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, input_values):
        """Classify a batch of inputs.

        Steps:
            1. Extract features with the backbone -> ``[batch, num_layers, feature_dim]``.
            2. Run every attention layer with the features as both query and
               key/value (i.e. self-attention over the layer axis).
            3. Average over the layer axis, keeping a length-1 sequence axis.
            4. Encode with the Transformer encoder and average over the
               sequence axis.
            5. Feed the pooled embedding to the classifier.

        Args:
            input_values: Tensor handed to the feature extractor
                (documented by the original author as ``[batch, time]``).

        Returns:
            Tuple ``(logits, embedding)`` with shapes ``[batch, num_classes]``
            and ``[batch, embed_dim]``.
        """
        tokens = self.feature_extractor(input_values)  # [batch, num_layers, feature_dim]

        # Each block attends from the features to themselves.
        for block in self.cross_attention_layers:
            tokens = block(tokens, tokens)

        # Pool across the layer axis; the encoder therefore operates on a
        # single-token sequence.
        pooled = tokens.mean(dim=1).unsqueeze(1)  # [batch, 1, embed_dim]
        encoded = self.transformer(pooled).mean(dim=1)  # [batch, embed_dim]
        logits = self.classifier(encoded)  # [batch, num_classes]
        return logits, encoded

    def unfreeze_feature_extractor(self):
        """Make the backbone parameters trainable again."""
        self.feature_extractor.unfreeze()