"""
With freeze_feature_extractor=True the Wav2Vec2 feature extractor is frozen (pretraining);
calling unfreeze_feature_extractor() re-enables its gradients for fine-tuning.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import Wav2Vec2Model
class cnn(nn.Module):
    """Small CNN encoder mapping a (B, 1, H, W) spectrogram to a (B, embed_dim) vector."""
    def __init__(self, embed_dim=512):
        super(cnn, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4))  # fixed 4x4 spatial output regardless of input size
        )
        self.projection = nn.Linear(32 * 4 * 4, embed_dim)

    def forward(self, x):
        x = self.conv_block(x)   # (B, 32, 4, 4)
        B, C, H, W = x.shape
        x = x.view(B, -1)        # flatten to (B, 512)
        x = self.projection(x)   # (B, embed_dim)
        return x
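
# Shape sanity check for the cnn encoder (a hypothetical helper, not part of the
# original model; the (2, 1, 80, 200) log-Mel input shape is an assumption):
def _demo_cnn_encoder():
    model = cnn(embed_dim=512)
    mel = torch.randn(2, 1, 80, 200)  # (B, channels=1, mel bins, time frames)
    out = model(mel)
    print(out.shape)                  # torch.Size([2, 512])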
class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CrossAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        # Separate norms for the attention and feed-forward residual branches
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )
        self.attention_weights = None

    def forward(self, x, cross_input):
        # Cross-attention: x queries cross_input. average_attn_weights=False keeps
        # per-head weights of shape (B, num_heads, query_len, key_len) for visualization.
        attn_output, attn_weights = self.multihead_attn(
            query=x, key=cross_input, value=cross_input, average_attn_weights=False
        )
        self.attention_weights = attn_weights
        x = self.norm1(x + attn_output)
        feed_forward_output = self.feed_forward(x)
        x = self.norm2(x + feed_forward_output)
        return x
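
# Sketch of a standalone CrossAttentionLayer call (hypothetical helper; the toy
# shapes are assumptions): a 4-token query sequence attends over 10 context tokens.
def _demo_cross_attention_layer():
    layer = CrossAttentionLayer(embed_dim=512, num_heads=8)
    queries = torch.randn(2, 4, 512)      # (B, query_len, E)
    context = torch.randn(2, 10, 512)     # (B, key_len, E)
    out = layer(queries, context)
    print(out.shape)                      # torch.Size([2, 4, 512])
    print(layer.attention_weights.shape)  # torch.Size([2, 8, 4, 10])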
class CrossAttentionViT(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(CrossAttentionViT, self).__init__()
        self.cross_attention_layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        self.attention_maps = []
    def forward(self, x, cross_attention_input):
        # Promote (B, E) inputs to length-1 sequences, since nn.MultiheadAttention
        # with batch_first=True expects (B, seq_len, E).
        if x.dim() == 2:
            x = x.unsqueeze(1)
        if cross_attention_input.dim() == 2:
            cross_attention_input = cross_attention_input.unsqueeze(1)
        self.attention_maps = []
        for layer in self.cross_attention_layers:
            x = layer(x, cross_attention_input)
            self.attention_maps.append(layer.attention_weights)
        x = x.permute(1, 0, 2)   # (B, S, E) -> (S, B, E) for the seq-first TransformerEncoder
        x = self.transformer(x)
        x = x.mean(dim=0)        # mean-pool over the sequence -> (B, E)
        x = self.classifier(x)
        return x
class CCV(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(CCV, self).__init__()
        self.encoder = cnn(embed_dim=embed_dim)
        self.decoder = CrossAttentionViT(embed_dim=embed_dim, num_heads=num_heads,
                                         num_layers=num_layers, num_classes=num_classes)

    def forward(self, x, cross_attention_input=None):
        x = self.encoder(x)            # (B, embed_dim)
        if cross_attention_input is None:
            cross_attention_input = x  # fall back to self-attention over the encoding
        x = self.decoder(x, cross_attention_input)
        # Store the attention maps collected by the decoder
        self.attention_maps = self.decoder.attention_maps
        return x

    def get_attention_maps(self):
        return self.attention_maps
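
# End-to-end sketch for CCV (hypothetical helper): the spectrogram becomes one
# 512-d token that cross-attends over an assumed external feature sequence
# (e.g. audio-model frames already projected to embed_dim).
def _demo_ccv():
    model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2)
    mel = torch.randn(2, 1, 80, 200)        # (B, 1, mel bins, time frames)
    context = torch.randn(2, 50, 512)       # (B, seq_len, embed_dim)
    logits = model(mel, cross_attention_input=context)
    print(logits.shape)                     # torch.Size([2, 2])
    print(len(model.get_attention_maps()))  # 6 maps, each (2, 8, 1, 50)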
class Wav2Vec2ForFakeMusic(nn.Module):
    def __init__(self, num_classes=2, freeze_feature_extractor=True):
        super(Wav2Vec2ForFakeMusic, self).__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        if freeze_feature_extractor:
            for param in self.wav2vec.parameters():
                param.requires_grad = False
        self.classifier = nn.Sequential(
            nn.Linear(self.wav2vec.config.hidden_size, 256),  # 768 -> 256
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)  # 256 -> 2 (binary classification)
        )
    def unfreeze_feature_extractor(self):
        """Re-enable gradients on the Wav2Vec2 backbone for fine-tuning."""
        for param in self.wav2vec.parameters():
            param.requires_grad = True

    def forward(self, x):
        x = x.squeeze(1)                           # (B, 1, T) -> (B, T)
        output = self.wav2vec(x)
        features = output["last_hidden_state"]     # (batch_size, seq_len, feature_dim)
        pooled_features = features.mean(dim=1)     # mean pooling -> (batch_size, feature_dim)
        logits = self.classifier(pooled_features)  # (batch_size, num_classes)
        return logits, pooled_features
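
# Usage sketch for Wav2Vec2ForFakeMusic (hypothetical helper; downloads
# facebook/wav2vec2-base on first use, and the 1 s / 16 kHz random waveform is
# a placeholder):
def _demo_wav2vec2_classifier():
    model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True)
    model.eval()
    waveform = torch.randn(2, 1, 16000)  # (B, 1, num_samples)
    with torch.no_grad():
        logits, pooled = model(waveform)
    print(logits.shape, pooled.shape)    # torch.Size([2, 2]) torch.Size([2, 768])
    model.unfreeze_feature_extractor()   # switch the backbone to fine-tuning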
def visualize_attention_map(attn_map, mel_spec, layer_idx):
    # Average over attention heads: (B, num_heads, L, S) -> (L, S) for batch size 1
    attn_map = attn_map.detach().mean(dim=1).squeeze(0).cpu().numpy()
    mel_spec = mel_spec.squeeze().cpu().numpy()
    fig, axs = plt.subplots(2, 1, figsize=(10, 8))

    # Visualize the log-Mel spectrogram
    sns.heatmap(mel_spec, cmap='inferno', ax=axs[0])
    axs[0].set_title("Log-Mel Spectrogram")
    axs[0].set_xlabel("Time Frames")
    axs[0].set_ylabel("Mel Frequency Bins")

    # Visualize the attention map
    sns.heatmap(attn_map, cmap='viridis', ax=axs[1])
    axs[1].set_title(f"Attention Map (Layer {layer_idx})")
    axs[1].set_xlabel("Time Frames")
    axs[1].set_ylabel("Query Positions")

    plt.tight_layout()
    # Save before plt.show(): show() can leave an empty canvas for savefig
    plt.savefig("/data/kym/AI_Music_Detection/Code/model/attention_map/crossattn.png")
    plt.show()
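
# Tying it together (a sketch; all shapes and the choice to plot only the first
# layer are assumptions): run CCV on a dummy spectrogram and visualize one map.
if __name__ == "__main__":
    model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2)
    mel = torch.randn(1, 1, 80, 200)   # batch of one: (1, 1, mel bins, frames)
    context = torch.randn(1, 50, 512)  # assumed external feature sequence
    logits = model(mel, cross_attention_input=context)
    visualize_attention_map(model.get_attention_maps()[0], mel, layer_idx=0)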