Spaces:
Sleeping
Sleeping
File size: 3,394 Bytes
2792f07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import keras
import tensorflow as tf
from custom_layers import L2Normalization, CosineLayer
@keras.saving.register_keras_serializable()
class VerificationModel(tf.keras.Model):
    """
    Modular Speaker Verification Model.

    Combines a backbone (feature extractor), an embedding projection with a
    BatchNorm neck, a normalization layer (e.g., L2 normalization), and a
    cosine classification head (CosineLayer).

    Args:
        base_model (tf.keras.Model): Backbone model (e.g., ResNet18).
        number_of_classes (int): Number of speaker classes for classification.
        normalization_layer (tf.keras.layers.Layer): Layer applied to the
            embedding before the classification head (e.g., L2Normalization).
        cosine_layer (tf.keras.layers.Layer): Cosine classification head that
            maps normalized embeddings to per-class logits.
        embedding_dim (int, optional): Size of embedding vector. Default: 512.
        return_embedding (bool, optional): If True, ``call`` returns only
            embeddings (for verification); if False, returns logits for
            classification. Default: False.
    """

    def __init__(
        self,
        base_model,
        number_of_classes,
        normalization_layer,
        cosine_layer,
        embedding_dim: int = 512,
        return_embedding: bool = False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model = base_model
        self.embedding_dim = embedding_dim
        self.number_of_classes = number_of_classes
        self.return_embedding = return_embedding
        # Projection from backbone features to the fixed-size embedding.
        # No bias: the BatchNorm neck that follows supplies a learnable shift.
        self.embedding_layer = tf.keras.layers.Dense(
            embedding_dim,
            activation='tanh',
            use_bias=False,
            name='embedding_dense'
        )
        self.bn_neck = tf.keras.layers.BatchNormalization(name="bn_neck")
        self.normalization_layer = normalization_layer
        self.cosine_layer = cosine_layer

    def call(self, inputs, training=None):
        """
        Forward pass.

        Args:
            inputs: Input tensor (e.g., spectrograms).
            training (bool, optional): Training mode (Keras convention);
                forwarded to the backbone and the BatchNorm neck.

        Returns:
            Embeddings (if return_embedding=True) or logits for classification.
        """
        x = self.base_model(inputs, training=training)
        x = self.embedding_layer(x)
        x = self.bn_neck(x, training=training)
        x = self.normalization_layer(x)
        if self.return_embedding:
            return x
        return self.cosine_layer(x)

    def get_config(self):
        """Serialize constructor arguments for Keras save/load."""
        base_config = super().get_config()
        return {
            **base_config,
            "base_model": keras.saving.serialize_keras_object(self.base_model),
            "normalization_layer": keras.saving.serialize_keras_object(
                self.normalization_layer
            ),
            "cosine_layer": keras.saving.serialize_keras_object(self.cosine_layer),
            "number_of_classes": self.number_of_classes,
            "embedding_dim": self.embedding_dim,
            "return_embedding": self.return_embedding
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild the model from a config produced by ``get_config``.

        Works on a shallow copy so the caller's dict is not mutated.
        """
        config = dict(config)
        base_model = keras.saving.deserialize_keras_object(config.pop("base_model"))
        normalization_layer = keras.saving.deserialize_keras_object(
            config.pop("normalization_layer")
        )
        cosine_layer = keras.saving.deserialize_keras_object(config.pop("cosine_layer"))
        return cls(base_model=base_model,
                   normalization_layer=normalization_layer,
                   cosine_layer=cosine_layer,
                   **config)
|