import tensorflow as tf
import keras

@keras.saving.register_keras_serializable()
class L2Normalization(tf.keras.layers.Layer):
    """
    Applies L2 normalization to the last axis of the input tensor.
    
    This is used as a top layer in speaker embedding models before
    cosine similarity computation.
    """
    def call(self, inputs):
        return tf.math.l2_normalize(inputs, axis=1)

    def compute_output_shape(self, input_shape):
        return input_shape
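
# Example (sketch, not part of the original module): L2Normalization maps each
# row of a (batch, dim) tensor to unit length. Shapes below are illustrative.
#
#   embeddings = tf.random.normal((4, 192))
#   unit_vectors = L2Normalization()(embeddings)
#   tf.debugging.assert_near(tf.norm(unit_vectors, axis=1), tf.ones(4))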

@keras.saving.register_keras_serializable()
class CosineLayer(tf.keras.layers.Layer):
    """
    Dense layer with L2-normalized weights, for cosine similarity-based classification.

    Args:
        out_features (int): Number of output features/classes.
        use_bias (bool): Whether to use bias term.
        name (str, optional): Layer name.
    """
    def __init__(self, out_features, use_bias=False, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.out_features = out_features
        self.use_bias = use_bias

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(int(input_shape[-1]), self.out_features),
            initializer='glorot_uniform',
            trainable=True,
            name='weights'
        )
        if self.use_bias:
            self.b = self.add_weight(
                shape=(self.out_features,),
                initializer='zeros',
                trainable=True,
                name='bias'
            )
        else:
            self.b = None
        super().build(input_shape)
        
    def call(self, inputs):
        # Normalize each column (one weight vector per output class) to unit length.
        w_normalized = tf.math.l2_normalize(self.w, axis=0)
        # With unit-length inputs, this matmul yields cosine similarities.
        logits = tf.linalg.matmul(inputs, w_normalized)
        if self.use_bias:
            logits = logits + self.b
        return logits

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.out_features)

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            'out_features': self.out_features,
            'use_bias': self.use_bias
        }
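
# Usage sketch (assumption: not part of the original module). Builds a small
# speaker-classification head from the two layers above and round-trips it
# through Keras saving to show that the @register_keras_serializable
# decorators make the custom layers reloadable. The embedding size (192),
# class count (10), and the file name "cosine_head.keras" are illustrative.
if __name__ == "__main__":
    inputs = tf.keras.Input(shape=(192,))        # e.g. speaker embeddings
    x = L2Normalization()(inputs)                # unit-length embeddings
    logits = CosineLayer(out_features=10)(x)     # similarity to 10 class weight vectors
    model = tf.keras.Model(inputs, logits)

    model.save("cosine_head.keras")
    reloaded = tf.keras.models.load_model("cosine_head.keras")

    batch = tf.random.normal((4, 192))
    tf.debugging.assert_near(model(batch), reloaded(batch), atol=1e-5)
    print("Round-trip OK, output shape:", reloaded(batch).shape)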