File size: 7,543 Bytes
2792f07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import warnings

import tensorflow as tf
import keras
import numpy as np

@keras.saving.register_keras_serializable()
class AdaCosLoss(tf.keras.losses.Loss):
    """
    Adaptive Cosine Loss (AdaCos).

    Implements the AdaCos loss function as described in:
    "AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning Deep Face Representations"
    (Zhang et al., 2019).

    The logit scale ``s`` is not learned by gradient descent. It lives in a
    non-trainable ``tf.Variable`` (``self.scale``) and is re-estimated on every
    call from the current batch:

        B_avg = mean_i sum_{j != y_i} exp(s_old * cos(theta_ij))
        s_new = log(B_avg) / cos(min(pi/4, median_i theta_{i, y_i}))

    ``s_new`` is accepted only if it is finite and positive; otherwise the
    previous scale is kept. The loss is then softmax cross-entropy on
    ``s * cos(theta)``.

    Args:
        num_classes (int): Number of classes in the classification problem
            (>= 2). Python ints and NumPy integer scalars are accepted.
        name (str, optional): Name for the loss instance.
        initial_scale (float, optional): Keyword-only starting value of the
            scale. Defaults to the paper's fixed scale
            ``sqrt(2) * log(num_classes - 1)``. That default is 0 when
            ``num_classes == 2``: the logits are then all zero, and the scale
            update always yields 0, which is rejected. So the scale stays at 0
            and nothing trains. Pass a positive value for binary problems.
        **kwargs: Forwarded to ``tf.keras.losses.Loss``.
    """
    def __init__(self, num_classes=None, name="AdaCos", *, initial_scale=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes  # validated by the property setter
        self.initial_scale = initial_scale  # raw user value, echoed by get_config()
        if initial_scale is None:
            start_scale = np.sqrt(2) * np.log(self.num_classes - 1)
            if start_scale <= 0:
                # Only reachable for num_classes == 2 (log(1) == 0).
                warnings.warn(
                    f"{type(self).__name__}: the default initial scale "
                    f"sqrt(2) * log(num_classes - 1) is 0 for num_classes={self.num_classes}; "
                    "the adaptive scale can never move away from 0 and the loss will "
                    "not train. Pass a positive `initial_scale`.",
                    stacklevel=2,
                )
        else:
            start_scale = float(initial_scale)
            if not (np.isfinite(start_scale) and start_scale > 0):
                raise ValueError(
                    f"`initial_scale` must be a finite positive number, got {initial_scale}"
                )
        self.scale = tf.Variable(start_scale, dtype=tf.float32, trainable=False)

    def call(self, y_true, y_pred):
        """
        Args:
            y_true: integer labels in [0, num_classes-1], shape (batch_size,)
                or (batch_size, 1).
            y_pred: (batch_size, num_classes) classification cosine similarities.

        Returns:
            Tensor scalar: Mean AdaCos loss over the batch.

        Side effect:
            Updates ``self.scale`` in place with the new scale estimate.
        """
        # Flatten (batch_size, 1) labels to (batch_size,) so the one-hot mask
        # has the same shape as y_pred and the CE op sees rank-1 labels.
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        batch_size = tf.shape(y_pred)[0]
        eps = tf.keras.backend.epsilon()
        # Clip cosines to the open interval (-1, 1) before they reach acos.
        y_pred = tf.clip_by_value(y_pred, -1.0 + eps, 1.0 - eps)

        # Boolean mask, True at each sample's target class: (batch_size, num_classes).
        target_mask = tf.one_hot(y_true, depth=self.num_classes) > 0
        # Angle to the target class for each sample: (batch_size,).
        theta_true = tf.math.acos(tf.boolean_mask(y_pred, target_mask))
        theta_med = tf.keras.ops.median(theta_true)

        # Cosines to every non-target class: (batch_size, num_classes - 1).
        # Reshaping with the explicit batch size makes an out-of-range label
        # fail here instead of silently regrouping rows.
        cos_neg = tf.reshape(
            tf.boolean_mask(y_pred, tf.logical_not(target_mask)),
            [batch_size, self.num_classes - 1],
        )
        # B_avg uses the scale from the previous step (read before assign below).
        b_avg = tf.reduce_mean(
            tf.reduce_sum(tf.math.exp(self.scale * cos_neg), axis=-1)
        )
        new_scale = tf.math.log(b_avg) / tf.math.cos(
            tf.minimum(tf.constant(np.pi / 4), theta_med)
        )
        # Keep the current scale if the estimate is not finite or not positive.
        safe_scale = tf.cond(
            tf.math.is_finite(new_scale) & (new_scale > 0),
            lambda: new_scale,
            lambda: self.scale,
        )
        self.scale.assign(safe_scale)

        logits = self.scale * y_pred
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=logits)
        return tf.reduce_mean(loss)

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            'num_classes': self.num_classes,
            'initial_scale': self.initial_scale,
        }

    def __repr__(self):
        return (f"{self.__class__.__name__}(num_classes={self.num_classes}, "
                f"name='{self.name}')")

    def __str__(self):
        return self.__repr__()

    @property
    def num_classes(self):
        return self._num_classes

    @num_classes.setter
    def num_classes(self, value):
        # Accept NumPy integer scalars (e.g. from `y.max() + 1`) but store a
        # plain Python int so the config holds a built-in type.
        if not isinstance(value, (int, np.integer)):
            raise TypeError(f"`num_classes` must be an int, got {type(value).__name__}")
        if value < 2:
            raise ValueError(f"`num_classes` must be >= 2, got {value}")
        self._num_classes = int(value)

@keras.saving.register_keras_serializable()
class AdaCosLossMargin(tf.keras.losses.Loss):
    """
    Adaptive Cosine Loss with Margin (AdaCosMargin).

    Extends AdaCos by introducing a fixed margin penalty for the target class logits, 
    encouraging greater separation between classes in angular (cosine) space.

    Logits are ``s * (cos(theta_ij) - margin * [j == y_i])``. The scale ``s``
    is a non-trainable ``tf.Variable`` (``self.scale``). It is re-estimated on
    every call with the AdaCos rule from the un-penalized cosines, and kept
    unchanged whenever the estimate is not finite or not positive.

    Reference:
    - AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning Deep Face Representations (Zhang et al., 2019)
    - Large Margin Cosine Loss (CosFace): https://arxiv.org/abs/1801.09414

    Args:
        margin (float): Margin to subtract from the target class cosine similarity (0.0–1.0).
        num_classes (int): Number of classes (>= 2). Python ints and NumPy
            integer scalars are accepted.
        name (str, optional): Name for the loss.
        initial_scale (float, optional): Keyword-only starting value of the
            scale. Defaults to ``sqrt(2) * log(num_classes - 1)``, which is 0
            when ``num_classes == 2``. A zero scale gives all-zero logits, and
            the update then always yields 0, which is rejected, so the scale
            stays at 0. Pass a positive value for binary problems.
        **kwargs: Forwarded to ``tf.keras.losses.Loss``.
    """
    def __init__(self, margin=0.1, num_classes=None, name="AdaCosLossMargin", *,
                 initial_scale=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.margin = margin  # validated by the property setter
        self.num_classes = num_classes  # validated by the property setter
        self.initial_scale = initial_scale  # raw user value, echoed by get_config()
        if initial_scale is None:
            start_scale = np.sqrt(2) * np.log(self.num_classes - 1)
            if start_scale <= 0:
                # Only reachable for num_classes == 2 (log(1) == 0).
                warnings.warn(
                    f"{type(self).__name__}: the default initial scale "
                    f"sqrt(2) * log(num_classes - 1) is 0 for num_classes={self.num_classes}; "
                    "the adaptive scale can never move away from 0 and the loss will "
                    "not train. Pass a positive `initial_scale`.",
                    stacklevel=2,
                )
        else:
            start_scale = float(initial_scale)
            if not (np.isfinite(start_scale) and start_scale > 0):
                raise ValueError(
                    f"`initial_scale` must be a finite positive number, got {initial_scale}"
                )
        self.scale = tf.Variable(start_scale, dtype=tf.float32, trainable=False)

    def call(self, y_true, y_pred):
        """
        Args:
            y_true: integer labels in [0, num_classes-1], shape (batch_size,)
                or (batch_size, 1).
            y_pred: (batch_size, num_classes) cosine similarities.

        Returns:
            Tensor scalar: Mean AdaCosMargin loss over the batch.

        Side effect:
            Updates ``self.scale`` in place with the new scale estimate.
        """
        batch_size = tf.shape(y_pred)[0]
        # Flatten (batch_size, 1) labels to (batch_size,) so the one-hot mask
        # has the same shape as y_pred and the CE op sees rank-1 labels.
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        eps = tf.keras.backend.epsilon()
        # Clip cosines to the open interval (-1, 1) before they reach acos.
        y_pred = tf.clip_by_value(y_pred, -1.0 + eps, 1.0 - eps)

        # Float one-hot is used for the margin arithmetic; the bool view
        # drives tf.boolean_mask. Both are (batch_size, num_classes).
        target_onehot = tf.one_hot(y_true, depth=self.num_classes)
        target_mask = target_onehot > 0

        # Angle to the target class for each sample: (batch_size,).
        theta_true = tf.math.acos(tf.boolean_mask(y_pred, target_mask))
        theta_med = tf.keras.ops.median(theta_true)

        # Cosines to every non-target class: (batch_size, num_classes - 1).
        cos_neg = tf.reshape(
            tf.boolean_mask(y_pred, tf.logical_not(target_mask)),
            [batch_size, self.num_classes - 1],
        )
        # B_avg uses the scale from the previous step (read before assign below).
        b_avg = tf.reduce_mean(
            tf.reduce_sum(tf.math.exp(self.scale * cos_neg), axis=-1)
        )
        new_scale = tf.math.log(b_avg) / tf.math.cos(
            tf.minimum(tf.constant(np.pi / 4), theta_med)
        )
        # Keep the current scale if the estimate is not finite or not positive.
        safe_scale = tf.cond(
            tf.math.is_finite(new_scale) & (new_scale > 0),
            lambda: new_scale,
            lambda: self.scale,
        )
        self.scale.assign(safe_scale)

        # CosFace-style additive margin on the target-class cosine only.
        logits = self.scale * (y_pred - self.margin * target_onehot)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=logits)
        return tf.reduce_mean(loss)

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            'num_classes': self.num_classes,
            'margin': self.margin,
            'initial_scale': self.initial_scale,
        }

    def __repr__(self):
        return (f"{self.__class__.__name__}(margin={self.margin}, num_classes={self.num_classes}, "
                f"name='{self.name}')")

    def __str__(self):
        return self.__repr__()

    @property
    def num_classes(self):
        return self._num_classes

    @num_classes.setter
    def num_classes(self, value):
        # Accept NumPy integer scalars (e.g. from `y.max() + 1`) but store a
        # plain Python int so the config holds a built-in type.
        if not isinstance(value, (int, np.integer)):
            raise TypeError(f"`num_classes` must be an int, got {type(value).__name__}")
        if value < 2:
            raise ValueError(f"`num_classes` must be >= 2, got {value}")
        self._num_classes = int(value)

    @property
    def margin(self):
        return self._margin

    @margin.setter
    def margin(self, value):
        if not isinstance(value, (float, int)):
            raise TypeError(f"`margin` must be a float or int, got {type(value).__name__}")
        value = float(value)
        if not (0.0 <= value <= 1.0):
            raise ValueError(f"`margin` must be between 0.0 and 1.0, got {value}")
        self._margin = value