import torch
import torch.nn as nn


class audiocnn(nn.Module):
    def __init__(self, num_classes=2):
        super(audiocnn, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4))  # final output -> (B, 32, 4, 4)
        )
        self.fc_block = nn.Sequential(
            nn.Linear(32 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.conv_block(x)
        # x.shape: (B, 32, new_freq, new_time)
        # 1) Flatten (shape resolved dynamically)
        B, C, H, W = x.shape
        x = x.view(B, -1)  # (B, 32 * H * W)
        # 2) FC
        x = self.fc_block(x)
        return x


class AudioCNN(nn.Module):
    def __init__(self, embed_dim=512):
        super(AudioCNN, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4))  # final output -> (B, 32, 4, 4)
        )
        self.projection = nn.Linear(32 * 4 * 4, embed_dim)

    def forward(self, x):
        x = self.conv_block(x)
        B, C, H, W = x.shape
        x = x.view(B, -1)  # Flatten to (B, C * H * W)
        x = self.projection(x)  # Project to embed_dim
        return x


class ViTDecoder(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(ViTDecoder, self).__init__()
        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        # The transformer expects input of shape (seq_len, batch, embed_dim)
        x = x.unsqueeze(0)  # Add a length-1 sequence dim -> (1, B, embed_dim)
        x = self.transformer(x)  # Pass through the transformer
        x = x.mean(dim=0)  # Mean over the sequence dimension -> (B, embed_dim)
        x = self.classifier(x)  # Classification head
        return x


class AudioCNNWithViTDecoder(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(AudioCNNWithViTDecoder, self).__init__()
        self.encoder = AudioCNN(embed_dim=embed_dim)
        self.decoder = ViTDecoder(embed_dim=embed_dim, num_heads=num_heads,
                                  num_layers=num_layers, num_classes=num_classes)

    def forward(self, x):
        x = self.encoder(x)  # Pass through the AudioCNN encoder
        x = self.decoder(x)  # Pass through the ViT decoder
        return x
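
# A minimal smoke-test sketch for the pipeline above. The input shape
# (8, 1, 64, 101) is an assumption standing in for a batch of single-channel
# mel spectrograms; AdaptiveAvgPool2d makes the encoder input-size agnostic,
# so any (B, 1, freq, time) tensor works. The function name is hypothetical.
def _smoke_test_audiocnn_vit():
    model = AudioCNNWithViTDecoder(embed_dim=512, num_heads=8, num_layers=6, num_classes=2)
    model.eval()
    dummy = torch.randn(8, 1, 64, 101)  # (batch, channel, n_mels, frames)
    with torch.no_grad():
        logits = model(dummy)
    assert logits.shape == (8, 2)  # (B, num_classes)
    return logits
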
class audio_crossattn(nn.Module):
    def __init__(self, embed_dim=512):
        super(audio_crossattn, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4))  # final output -> (B, 32, 4, 4)
        )
        self.projection = nn.Linear(32 * 4 * 4, embed_dim)

    def forward(self, x):
        x = self.conv_block(x)  # Convolutional feature extraction
        B, C, H, W = x.shape
        x = x.view(B, -1)  # Flatten to (B, C * H * W)
        x = self.projection(x)  # Linear projection to embed_dim
        return x


class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CrossAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        # Separate norms for the attention and feed-forward sublayers
        # (the original shared a single LayerNorm between both Add & Norm steps).
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )

    def forward(self, x, cross_input):
        # Cross-attention: x provides the queries, cross_input the keys/values
        attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input)
        x = self.norm1(x + attn_output)  # Add & Norm
        x = self.norm2(x + self.feed_forward(x))  # Add & Norm
        return x


class ViTDecoderWithCrossAttention(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
        super(ViTDecoderWithCrossAttention, self).__init__()
        # Cross-attention layers
        self.cross_attention_layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])
        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x, cross_attention_input):
        # nn.MultiheadAttention with batch_first=True expects (B, seq_len, embed_dim);
        # a raw (B, embed_dim) tensor would be misread as an unbatched sequence,
        # so give the encoder vectors a length-1 sequence dimension first.
        if x.dim() == 2:
            x = x.unsqueeze(1)  # (B, 1, embed_dim)
        if cross_attention_input.dim() == 2:
            cross_attention_input = cross_attention_input.unsqueeze(1)

        # Pass through the cross-attention layers
        for layer in self.cross_attention_layers:
            x = layer(x, cross_attention_input)

        # The transformer expects (seq_len, batch, embed_dim)
        x = x.permute(1, 0, 2)  # (1, B, embed_dim)
        x = self.transformer(x)  # Pass through the transformer
        embedding = x.mean(dim=0)  # Mean over the sequence dimension -> (B, embed_dim)

        # Classification head
        x = self.classifier(embedding)
        return x, embedding


# class AudioCNNWithViTDecoderAndCrossAttention(nn.Module):
#     def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
#         super(AudioCNNWithViTDecoderAndCrossAttention, self).__init__()
#         self.encoder = audio_crossattn(embed_dim=embed_dim)
#         self.decoder = ViTDecoderWithCrossAttention(embed_dim=embed_dim, num_heads=num_heads,
#                                                     num_layers=num_layers, num_classes=num_classes)
#
#     def forward(self, x, cross_attention_input):
#         # Pass through the audio_crossattn encoder
#         x = self.encoder(x)
#         # Pass through the ViT decoder with cross-attention
#         x = self.decoder(x, cross_attention_input)
#         return x


class CCV(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2,
                 freeze_feature_extractor=True):
        super(CCV, self).__init__()
        self.encoder = AudioCNN(embed_dim=embed_dim)
        self.decoder = ViTDecoderWithCrossAttention(embed_dim=embed_dim, num_heads=num_heads,
                                                    num_layers=num_layers, num_classes=num_classes)
        if freeze_feature_extractor:
            # Freeze both encoder and decoder; only parameters added on top
            # (e.g. LoRA adapters, see the note below) are meant to train.
            for param in self.encoder.parameters():
                param.requires_grad = False
            for param in self.decoder.parameters():
                param.requires_grad = False

    def forward(self, x, cross_attention_input=None):
        # Pass through the AudioCNN encoder
        x = self.encoder(x)
        # If cross_attention_input is not provided, use the encoder output
        if cross_attention_input is None:
            cross_attention_input = x
        # Pass through the ViT decoder with cross-attention
        x, embedding = self.decoder(x, cross_attention_input)
        return x, embedding


# ---------------------------------------------------------
'''
AudioCNN weights frozen; cross-attention decoder is LoRA-tuned.
'''
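
# The note above only records the plan (backbone frozen, decoder LoRA-tuned);
# no implementation is given. Below is a minimal, self-contained sketch of one
# way to realize it in plain PyTorch. `LoRALinear`, `apply_lora_to_decoder`,
# and the rank/alpha defaults are illustrative assumptions, not the author's
# settings; libraries such as HuggingFace peft offer the same idea off the shelf.
class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank update: Wx + (alpha/r) * B A x."""

    def __init__(self, base: nn.Linear, rank=8, alpha=16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # keep the wrapped weights frozen
            p.requires_grad = False
        self.scaling = alpha / rank
        # Standard LoRA init: A small random, B zero, so the update starts at 0.
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))

    def forward(self, x):
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling


def apply_lora_to_decoder(model, rank=8, alpha=16):
    """Swap plain nn.Linear layers inside model.decoder for LoRA-wrapped copies."""
    # Collect targets first, then replace, so the module tree is not mutated
    # while it is being traversed.
    targets = []
    for module in model.decoder.modules():
        # nn.MultiheadAttention reads its projection weights directly rather
        # than calling the submodule, so its internals cannot be wrapped this way;
        # the FFN and classifier linears are the targets here.
        if isinstance(module, nn.MultiheadAttention):
            continue
        for child_name, child in module.named_children():
            if isinstance(child, nn.Linear):
                targets.append((module, child_name, child))
    for parent, child_name, child in targets:
        setattr(parent, child_name, LoRALinear(child, rank=rank, alpha=alpha))
    return model

# Usage sketch: after `model = apply_lora_to_decoder(CCV())`, only the
# lora_A / lora_B parameters require gradients, so the optimizer can be built
# from `filter(lambda p: p.requires_grad, model.parameters())`.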