import tensorflow as tf
import keras
import numpy as np
class AdaCosLoss(tf.keras.losses.Loss):
    """
    Adaptive Cosine Loss (AdaCos).

    Implements the AdaCos loss function as described in:
    "AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning Deep Face Representations"
    (Zhang et al., 2019).

    Args:
        num_classes (int): Number of classes in the classification problem.
        name (str, optional): Name for the loss instance.
    """

    def __init__(self, num_classes=None, name="AdaCos", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.scale = tf.Variable(
            np.sqrt(2) * np.log(num_classes - 1),
            dtype=tf.float32, trainable=False
        )
    def call(self, y_true, y_pred):
        """
        Args:
            y_true: (batch_size,) integer labels [0, num_classes-1].
            y_pred: (batch_size, num_classes) classification cosine similarities.

        Returns:
            Tensor scalar: Mean AdaCos loss over the batch.
        """
        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.clip_by_value(
            y_pred,
            -1.0 + tf.keras.backend.epsilon(),
            1.0 - tf.keras.backend.epsilon()
        )

        # boolean mask selecting the target-class entry of each row
        mask = tf.one_hot(y_true, depth=self.num_classes) > 0  # shape (batch_size, n_classes)

        # angles theta for the target class of each sample
        theta_true = tf.math.acos(tf.boolean_mask(y_pred, mask))  # shape (batch_size,)

        # median of the target-class angles (tf.keras.ops requires Keras 3)
        theta_med = tf.keras.ops.median(theta_true)

        # cosine values of the non-target classes, cos(theta_j) for j != y_i
        neg_mask = tf.logical_not(mask)  # shape (batch_size, n_classes)
        cos_theta_neg = tf.boolean_mask(y_pred, neg_mask)  # shape (batch_size*(n_classes-1),)
        neg_y_pred = tf.reshape(cos_theta_neg, [-1, self.num_classes - 1])  # shape (batch_size, n_classes-1)
        B_avg = tf.reduce_mean(tf.reduce_sum(tf.math.exp(self.scale * neg_y_pred), axis=-1))

        new_scale = (
            tf.math.log(B_avg) /
            tf.math.cos(tf.minimum(tf.constant(np.pi / 4), theta_med))
        )
        # keep the current scale if new_scale is invalid (NaN, inf or non-positive)
        safe_scale = tf.cond(
            tf.math.is_finite(new_scale) & (new_scale > 0),
            lambda: new_scale,
            lambda: self.scale
        )
        self.scale.assign(safe_scale)

        logits = self.scale * y_pred
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=logits)
        return tf.reduce_mean(loss)
    def get_config(self):
        base_config = super().get_config()
        return {**base_config, 'num_classes': self.num_classes}

    def __repr__(self):
        return (f"{self.__class__.__name__}(num_classes={self.num_classes}, "
                f"name='{self.name}')")

    def __str__(self):
        return self.__repr__()

    @property
    def num_classes(self):
        return self._num_classes

    @num_classes.setter
    def num_classes(self, value):
        if not isinstance(value, int):
            raise TypeError(f"`num_classes` must be an int, got {type(value).__name__}")
        if value < 2:
            raise ValueError(f"`num_classes` must be >= 2, got {value}")
        self._num_classes = value
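

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original API): AdaCosLoss
# expects `y_pred` to already be cosine similarities, i.e. dot products of
# L2-normalised embeddings with L2-normalised class-weight vectors. The helper
# below is a hypothetical smoke test; its name, the toy shapes and the random
# tensors are assumptions made purely for demonstration.
def _adacos_smoke_test(num_classes=10, batch_size=8, embedding_dim=32):
    """Evaluate AdaCosLoss once on random, properly normalised inputs."""
    loss_fn = AdaCosLoss(num_classes=num_classes)
    # random embeddings and class weights, both L2-normalised so that their
    # dot products are genuine cosine similarities in [-1, 1]
    embeddings = tf.math.l2_normalize(tf.random.normal([batch_size, embedding_dim]), axis=-1)
    class_weights = tf.math.l2_normalize(tf.random.normal([embedding_dim, num_classes]), axis=0)
    cosines = tf.matmul(embeddings, class_weights)  # (batch_size, num_classes)
    labels = tf.random.uniform([batch_size], maxval=num_classes, dtype=tf.int32)
    return loss_fn(labels, cosines)  # scalar loss tensor
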
class AdaCosLossMargin(tf.keras.losses.Loss):
    """
    Adaptive Cosine Loss with Margin (AdaCosMargin).

    Extends AdaCos by introducing a fixed margin penalty on the target class logits,
    encouraging greater separation between classes in angular (cosine) space.

    References:
        - AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning Deep Face Representations (Zhang et al., 2019)
        - Large Margin Cosine Loss (CosFace): https://arxiv.org/abs/1801.09414

    Args:
        margin (float): Margin to subtract from the target class cosine similarity (0.0 to 1.0).
        num_classes (int): Number of classes.
        name (str, optional): Name for the loss.
    """

    def __init__(self, margin=0.1, num_classes=None, name="AdaCosLossMargin", **kwargs):
        super().__init__(name=name, **kwargs)
        self.margin = margin
        self.num_classes = num_classes
        self.scale = tf.Variable(
            np.sqrt(2) * np.log(num_classes - 1),
            dtype=tf.float32, trainable=False
        )
    def call(self, y_true, y_pred):
        """
        Args:
            y_true: (batch_size,) integer labels [0, num_classes-1].
            y_pred: (batch_size, num_classes) cosine similarities.

        Returns:
            Tensor scalar: Mean AdaCosMargin loss over the batch.
        """
        batch_size = tf.shape(y_pred)[0]
        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.clip_by_value(
            y_pred,
            -1.0 + tf.keras.backend.epsilon(),
            1.0 - tf.keras.backend.epsilon()
        )

        # float one-hot mask (used for the margin term) and its boolean counterpart
        mask = tf.one_hot(y_true, depth=self.num_classes)  # shape (batch_size, n_classes)
        bool_mask = mask > 0

        # angles theta for the target class, and their median over the batch
        theta_true = tf.math.acos(tf.boolean_mask(y_pred, bool_mask))
        theta_med = tf.keras.ops.median(theta_true)

        # cosine values of the non-target classes, cos(theta_j) for j != y_i
        neg_mask = tf.logical_not(bool_mask)
        cos_theta_neg = tf.boolean_mask(y_pred, neg_mask)
        neg_y_pred = tf.reshape(cos_theta_neg, [batch_size, self.num_classes - 1])
        B_avg = tf.reduce_mean(tf.reduce_sum(tf.math.exp(self.scale * neg_y_pred), axis=-1))
        B_avg = tf.cast(B_avg, tf.float32)

        with tf.control_dependencies([theta_med, B_avg]):
            new_scale = (
                tf.math.log(B_avg) /
                tf.math.cos(tf.minimum(tf.constant(np.pi / 4), theta_med))
            )
        # keep the current scale if new_scale is invalid (NaN, inf or non-positive)
        safe_scale = tf.cond(
            tf.math.is_finite(new_scale) & (new_scale > 0),
            lambda: new_scale,
            lambda: self.scale
        )
        self.scale.assign(safe_scale)

        # subtract the fixed margin from the target-class cosine before scaling
        logits = self.scale * (y_pred - self.margin * mask)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=logits)
        return tf.reduce_mean(loss)
    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            'num_classes': self.num_classes,
            'margin': self.margin
        }

    def __repr__(self):
        return (f"{self.__class__.__name__}(margin={self.margin}, num_classes={self.num_classes}, "
                f"name='{self.name}')")

    def __str__(self):
        return self.__repr__()

    @property
    def num_classes(self):
        return self._num_classes

    @num_classes.setter
    def num_classes(self, value):
        if not isinstance(value, int):
            raise TypeError(f"`num_classes` must be an int, got {type(value).__name__}")
        if value < 2:
            raise ValueError(f"`num_classes` must be >= 2, got {value}")
        self._num_classes = value

    @property
    def margin(self):
        return self._margin

    @margin.setter
    def margin(self, value):
        if not isinstance(value, (float, int)):
            raise TypeError(f"`margin` must be a float or int, got {type(value).__name__}")
        value = float(value)
        if not (0.0 <= value <= 1.0):
            raise ValueError(f"`margin` must be between 0.0 and 1.0, got {value}")
        self._margin = value
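

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal cosine-similarity head feeding
# AdaCosLossMargin in a manual, eager training step. The layer name
# `CosineHead`, the toy dimensions and the single gradient step below are
# assumptions made for demonstration; they are not part of the original API.
class CosineHead(tf.keras.layers.Layer):
    """Outputs cos(theta) between L2-normalised inputs and class weights."""

    def __init__(self, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(int(input_shape[-1]), self.num_classes),
            initializer="glorot_uniform",
            trainable=True,
            name="class_weights",
        )

    def call(self, inputs):
        x = tf.math.l2_normalize(inputs, axis=-1)  # unit-length embeddings
        w = tf.math.l2_normalize(self.w, axis=0)   # unit-length class centres
        return tf.matmul(x, w)                     # cosine similarities in [-1, 1]


if __name__ == "__main__":
    num_classes, embedding_dim, batch_size = 10, 32, 16
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(embedding_dim, activation="relu"),
        CosineHead(num_classes),
    ])
    loss_fn = AdaCosLossMargin(margin=0.1, num_classes=num_classes)
    optimizer = tf.keras.optimizers.Adam()

    x = tf.random.normal([batch_size, 64])
    y = tf.random.uniform([batch_size], maxval=num_classes, dtype=tf.int32)

    # one manual training step: forward pass, loss, gradients, update
    with tf.GradientTape() as tape:
        cosines = model(x, training=True)
        loss = loss_fn(y, cosines)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    print("AdaCosLossMargin after one step:", float(loss))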