File size: 3,394 Bytes
2792f07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import keras
import tensorflow as tf
from custom_layers import L2Normalization, CosineLayer

@keras.saving.register_keras_serializable()
class VerificationModel(tf.keras.Model):
    """
    Modular Speaker Verification Model.

    Forward pipeline: backbone (feature extractor) -> Dense embedding projection
    (tanh, no bias) -> BatchNormalization "neck" -> optional normalization layer
    -> classification head (e.g. CosineLayer).

    Args:
        base_model (tf.keras.Model): Backbone model (e.g., ResNet18). It is called
            with the same ``training`` flag as this model.
        number_of_classes (int): Number of speaker classes. Stored and written to
            the config; the classification head itself is supplied via
            ``cosine_layer``.
        normalization_layer (keras.layers.Layer or None): Layer applied to the
            embedding after the BatchNorm neck (e.g. ``L2Normalization``).
            Pass ``None`` to skip this step.
        cosine_layer (keras.layers.Layer or None): Classification head (e.g.
            ``CosineLayer``) applied when ``return_embedding`` is False. May be
            ``None`` only if the model is always used with
            ``return_embedding=True``.
        embedding_dim (int, optional): Size of the embedding vector. Default: 512.
        return_embedding (bool, optional): If True, ``call`` returns the
            embeddings (for verification); if False, it returns the output of
            ``cosine_layer`` (classification scores). Default: False.
        **kwargs: Forwarded to ``tf.keras.Model.__init__`` (e.g. ``name``).
    """
    def __init__(
        self,
        base_model,
        number_of_classes,
        normalization_layer,
        cosine_layer,
        embedding_dim: int = 512,
        return_embedding: bool = False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model = base_model
        self.embedding_dim = embedding_dim
        self.number_of_classes = number_of_classes
        self.return_embedding = return_embedding

        self.embedding_layer = tf.keras.layers.Dense(
            embedding_dim,
            activation='tanh',
            use_bias=False,
            name='embedding_dense'
        )
        self.bn_neck = tf.keras.layers.BatchNormalization(name="bn_neck")
        self.normalization_layer = normalization_layer
        self.cosine_layer = cosine_layer

    def call(self, inputs, training=None):
        """
        Forward pass.

        Args:
            inputs: Input tensor passed straight to ``base_model``
                (presumably spectrograms — depends on the backbone).
            training (bool, optional): Training mode (Keras convention);
                forwarded to the backbone and the BatchNorm neck.
        Returns:
            Embeddings if ``self.return_embedding`` is True, otherwise the
            output of ``self.cosine_layer``.
        Raises:
            ValueError: If classification output is requested but no
                ``cosine_layer`` was provided.
        """
        x = self.base_model(inputs, training=training)
        x = self.embedding_layer(x)
        # The neck's BatchNorm behaves differently in training vs. inference,
        # so the flag must be forwarded explicitly.
        x = self.bn_neck(x, training=training)
        # Normalization is optional: a None layer means "use the raw neck output".
        if self.normalization_layer is not None:
            x = self.normalization_layer(x)
        if self.return_embedding:
            return x
        if self.cosine_layer is None:
            raise ValueError(
                "VerificationModel has no cosine_layer; set "
                "return_embedding=True to obtain embeddings instead."
            )
        return self.cosine_layer(x)

    def get_config(self):
        """Return a serializable config; sub-layers are serialized as Keras objects."""
        base_config = super().get_config()
        return {
            **base_config,
            "base_model": keras.saving.serialize_keras_object(self.base_model),
            "normalization_layer": keras.saving.serialize_keras_object(
                self.normalization_layer
            ),
            "cosine_layer": keras.saving.serialize_keras_object(self.cosine_layer),
            "number_of_classes": self.number_of_classes,
            "embedding_dim": self.embedding_dim,
            "return_embedding": self.return_embedding
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild the model from ``get_config`` output without mutating ``config``."""
        # Shallow copy: popping from the caller's dict would make a second
        # deserialization of the same config fail with KeyError.
        config = dict(config)
        base_model = keras.saving.deserialize_keras_object(config.pop("base_model"))
        normalization_layer = keras.saving.deserialize_keras_object(
            config.pop("normalization_layer")
        )
        cosine_layer = keras.saving.deserialize_keras_object(config.pop("cosine_layer"))
        return cls(base_model=base_model,
                   normalization_layer=normalization_layer,
                   cosine_layer=cosine_layer,
                   **config)