import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig

class MERTFeatureExtractor(nn.Module):
    def __init__(self, freeze_feature_extractor=True):
        super(MERTFeatureExtractor, self).__init__()
        config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
        # Some revisions of the MERT remote code expect a `conv_pos_batch_norm`
        # attribute on the config; default it to False if it is missing so loading succeeds.
        if not hasattr(config, "conv_pos_batch_norm"):
            setattr(config, "conv_pos_batch_norm", False)
        self.mert = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True)

        if freeze_feature_extractor:
            self.freeze()

    def forward(self, input_values):
        # Input: [batch, time] raw waveform.
        # Run the pretrained MERT model and collect the hidden states of every layer.
        # Gradients are only enabled when the backbone has been unfrozen, so that
        # unfreeze() actually allows fine-tuning.
        grad_enabled = torch.is_grad_enabled() and any(p.requires_grad for p in self.mert.parameters())
        with torch.set_grad_enabled(grad_enabled):
            outputs = self.mert(input_values, output_hidden_states=True)
        # hidden_states: tuple of [batch, time, feature_dim] tensors, one per layer.
        # Stack the layers and average over the time axis to get one feature vector per layer.
        hidden_states = torch.stack(outputs.hidden_states)  # [num_layers, batch, time, feature_dim]
        time_reduced = hidden_states.mean(dim=2)  # [num_layers, batch, feature_dim]
        time_reduced = time_reduced.permute(1, 0, 2)  # [batch, num_layers, feature_dim]
        return time_reduced

    def freeze(self):
        for param in self.mert.parameters():
            param.requires_grad = False

    def unfreeze(self):
        for param in self.mert.parameters():
            param.requires_grad = True
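
# A minimal shape sketch for the extractor (assuming MERT-v1-95M: 12 transformer
# layers plus the embedding output = 13 hidden states of dimension 768, and
# 24 kHz mono input; these values are illustrative, not prescribed by this file):
#
#     extractor = MERTFeatureExtractor()
#     waveform = torch.randn(2, 24000)   # [batch=2, one second of 24 kHz audio]
#     feats = extractor(waveform)        # -> [2, 13, 768]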


class CrossAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CrossAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )

    def forward(self, x, cross_input):
        # Cross-attention: x is the query, cross_input provides the keys and values
        attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input)
        x = self.layer_norm1(x + attn_output)
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + ff_output)
        return x


class CCV(nn.Module):
    def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True):
        super(CCV, self).__init__()
        # MERT-based feature extractor (extracts features from the pretrained weights)
        self.feature_extractor = MERTFeatureExtractor(freeze_feature_extractor=freeze_feature_extractor)
        # Stack of cross-attention layers
        self.cross_attention_layers = nn.ModuleList([
            CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])
        # Transformer encoder (batch_first=True, so inputs are [batch, seq, dim])
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )


    def forward(self, input_values):
        """
        input_values: Tensor [batch, time]
        1. MERT๋กœ๋ถ€ํ„ฐ feature ์ถ”์ถœ โ†’ [batch, num_layers, feature_dim]
        2. ์ž„๋ฒ ๋”ฉ ์ฐจ์› ๋งž์ถ”๊ธฐ ์œ„ํ•ด transpose โ†’ [batch, feature_dim, num_layers]
        3. Cross-Attention ์ ์šฉ
        4. Transformer Encoding ํ›„ ํ‰๊ท  ํ’€๋ง
        5. ๋ถ„๋ฅ˜๊ธฐ ํ†ต๊ณผํ•˜์—ฌ ์ตœ์ข… ์ถœ๋ ฅ(logits) ๋ฐ˜ํ™˜
        """
        features = self.feature_extractor(input_values)  # [batch, num_layers, feature_dim]
        # embed_dim๋Š” ๋ณดํ†ต feature_dim๊ณผ ๋™์ผํ•˜๊ฒŒ ๋งž์ถค (์˜ˆ์‹œ: 768)
        # features = features.permute(0, 2, 1)  # [batch, embed_dim, num_layers]
        
        # Apply cross-attention (both arguments are the same tensor, so this reduces to self-attention)
        for layer in self.cross_attention_layers:
            features = layer(features, features)
        
        # Average over the layer axis, leaving a single token per sample for the Transformer encoder
        features = features.mean(dim=1).unsqueeze(1)  # [batch, 1, embed_dim]
        encoded = self.transformer(features)  # [batch, 1, embed_dim]
        encoded = encoded.mean(dim=1)  # [batch, embed_dim]
        output = self.classifier(encoded)  # [batch, num_classes]
        return output, encoded

    def unfreeze_feature_extractor(self):
        self.feature_extractor.unfreeze()
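

# Minimal usage sketch (assumptions: 24 kHz mono input, one-second clips, and the
# default two-class head; these values are illustrative, not prescribed by this file).
# Note that instantiating CCV downloads m-a-p/MERT-v1-95M from the Hugging Face Hub.
if __name__ == "__main__":
    model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2)
    model.eval()
    dummy_audio = torch.randn(2, 24000)  # [batch, time]; MERT-v1-95M expects 24 kHz audio
    with torch.no_grad():
        logits, embedding = model(dummy_audio)
    print(logits.shape)     # torch.Size([2, 2])
    print(embedding.shape)  # torch.Size([2, 768])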