Speaker_Verification_Demo / src /custom_losses.py
2pift's picture
Update dependencies
2792f07
import tensorflow as tf
import keras
import numpy as np
@keras.saving.register_keras_serializable()
class AdaCosLoss(tf.keras.losses.Loss):
    """
    Adaptive Cosine Loss (AdaCos).

    Implements the AdaCos loss function as described in:
    "AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning Deep Face Representations"
    (Zhang et al., 2019).

    The loss keeps a non-trainable ``scale`` variable that is re-estimated from
    every batch passed to ``call`` and then used to scale the cosine logits
    before a sparse softmax cross-entropy.

    Args:
        num_classes (int): Number of classes in the classification problem.
            Must be an integer (Python or NumPy) >= 2.
        name (str, optional): Name for the loss instance.
        **kwargs: Forwarded unchanged to ``tf.keras.losses.Loss``.

    Raises:
        TypeError: If ``num_classes`` is not an integer (including the default ``None``).
        ValueError: If ``num_classes`` is smaller than 2.
    """

    def __init__(self, num_classes=None, name="AdaCos", **kwargs):
        super().__init__(name=name, **kwargs)
        # Goes through the validating property setter below.
        self.num_classes = num_classes
        # Initial scale sqrt(2) * log(C - 1); it is overwritten on each call().
        self.scale = tf.Variable(
            np.sqrt(2) * np.log(self.num_classes - 1),
            dtype=tf.float32, trainable=False
        )

    def call(self, y_true, y_pred):
        """
        Args:
            y_true: (batch_size,) or (batch_size, 1) integer labels [0, num_classes-1].
            y_pred: (batch_size, num_classes) classification cosine similarities.

        Returns:
            Tensor scalar: Mean AdaCos loss over the batch.
        """
        # Flatten so labels with a trailing unit dimension work as well; the
        # sparse cross-entropy below needs rank-1 labels for rank-2 logits.
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        # Align with the float32 scale variable so the products below are valid.
        y_pred = tf.cast(y_pred, self.scale.dtype)
        # Keep cosines strictly inside (-1, 1) so acos stays finite.
        y_pred = tf.clip_by_value(
            y_pred,
            -1.0 + tf.keras.backend.epsilon(),
            1.0 - tf.keras.backend.epsilon()
        )

        # correct class mask, shape (batch_size, n_classes)
        target_mask = tf.one_hot(y_true, depth=self.num_classes) > 0
        self._update_scale(y_pred, target_mask)

        # The scale is read after the assignment, so the fresh value is used.
        logits = self.scale * y_pred
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=logits)
        return tf.reduce_mean(loss)

    def _update_scale(self, y_pred, target_mask):
        """Re-estimate ``self.scale`` from the current batch and assign it in place."""
        # angles of the corresponding (target) class, shape (batch_size,)
        theta_true = tf.math.acos(tf.boolean_mask(y_pred, target_mask))
        # median of the 'correct' angles
        theta_med = tf.keras.ops.median(theta_true)

        # cosines of the non-corresponding classes, regrouped per sample:
        # shape (batch_size, n_classes-1)
        cos_theta_neg = tf.boolean_mask(y_pred, tf.logical_not(target_mask))
        neg_y_pred = tf.reshape(cos_theta_neg, [-1, self.num_classes - 1])

        # Uses the scale from the previous update.
        B_avg = tf.reduce_mean(
            tf.reduce_sum(tf.math.exp(self.scale * neg_y_pred), axis=-1)
        )
        new_scale = (
            tf.math.log(B_avg) /
            tf.math.cos(tf.minimum(tf.constant(np.pi / 4, dtype=theta_med.dtype), theta_med))
        )

        # keep current scale if new_scale is invalid (NaN, inf or non-positive)
        safe_scale = tf.cond(
            tf.math.is_finite(new_scale) & (new_scale > 0),
            lambda: new_scale,
            lambda: self.scale
        )
        self.scale.assign(safe_scale)

    def get_config(self):
        """Return the base loss config extended with ``num_classes``.

        Only ``num_classes`` is added here; the current value of ``scale`` is
        not included.
        """
        base_config = super().get_config()
        return {**base_config, 'num_classes': self.num_classes}

    def __repr__(self):
        return (f"{self.__class__.__name__}(num_classes={self.num_classes}, "
                f"name='{self.name}')")

    def __str__(self):
        return self.__repr__()

    @property
    def num_classes(self):
        """int: Number of classes (validated on assignment)."""
        return self._num_classes

    @num_classes.setter
    def num_classes(self, value):
        # NumPy integers (e.g. from `labels.max() + 1`) are accepted too.
        if not isinstance(value, (int, np.integer)):
            raise TypeError(f"`num_classes` must be an int, got {type(value).__name__}")
        if value < 2:
            raise ValueError(f"`num_classes` must be >= 2, got {value}")
        # Store a plain int so get_config() stays JSON-serialisable.
        self._num_classes = int(value)
@keras.saving.register_keras_serializable()
class AdaCosLossMargin(tf.keras.losses.Loss):
    """
    Adaptive Cosine Loss with Margin (AdaCosMargin).

    Extends AdaCos by introducing a fixed margin penalty for the target class logits,
    encouraging greater separation between classes in angular (cosine) space.

    The adaptive ``scale`` is re-estimated from the un-margined cosines of every
    batch passed to ``call``; the margin is only applied when building the logits.

    Reference:
        - AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning Deep Face Representations (Zhang et al., 2019)
        - Large Margin Cosine Loss (CosFace): https://arxiv.org/abs/1801.09414

    Args:
        margin (float): Margin to subtract from the target class cosine similarity (0.0–1.0).
        num_classes (int): Number of classes. Must be an integer (Python or NumPy) >= 2.
        name (str, optional): Name for the loss.
        **kwargs: Forwarded unchanged to ``tf.keras.losses.Loss``.

    Raises:
        TypeError: If ``margin`` is not a real number or ``num_classes`` is not
            an integer (including the default ``None``).
        ValueError: If ``margin`` is outside [0.0, 1.0] or ``num_classes`` < 2.
    """

    def __init__(self, margin=0.1, num_classes=None, name="AdaCosLossMargin", **kwargs):
        super().__init__(name=name, **kwargs)
        # Both assignments go through the validating property setters below.
        self.margin = margin
        self.num_classes = num_classes
        # Initial scale sqrt(2) * log(C - 1); it is overwritten on each call().
        self.scale = tf.Variable(
            np.sqrt(2) * np.log(self.num_classes - 1),
            dtype=tf.float32, trainable=False
        )

    def call(self, y_true, y_pred):
        """
        Args:
            y_true: (batch_size,) or (batch_size, 1) integer labels [0, num_classes-1].
            y_pred: (batch_size, num_classes) cosine similarities.

        Returns:
            Tensor scalar: Mean AdaCosMargin loss over the batch.
        """
        # Flatten so labels with a trailing unit dimension work as well; the
        # sparse cross-entropy below needs rank-1 labels for rank-2 logits.
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        # Align with the float32 scale variable so the products below are valid.
        y_pred = tf.cast(y_pred, self.scale.dtype)
        # Keep cosines strictly inside (-1, 1) so acos stays finite.
        y_pred = tf.clip_by_value(
            y_pred,
            -1.0 + tf.keras.backend.epsilon(),
            1.0 - tf.keras.backend.epsilon()
        )

        # Float one-hot is needed for the margin arithmetic; the boolean view
        # is what tf.boolean_mask expects.
        one_hot = tf.one_hot(y_true, depth=self.num_classes, dtype=y_pred.dtype)
        self._update_scale(y_pred, one_hot > 0)

        # Subtract the margin from the target-class cosine only, then scale.
        logits = self.scale * (y_pred - self.margin * one_hot)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=logits)
        return tf.reduce_mean(loss)

    def _update_scale(self, y_pred, target_mask):
        """Re-estimate ``self.scale`` from the current batch and assign it in place."""
        # angles of the target class, shape (batch_size,), and their median
        theta_true = tf.math.acos(tf.boolean_mask(y_pred, target_mask))
        theta_med = tf.keras.ops.median(theta_true)

        # cosines of the non-target classes, regrouped per sample:
        # shape (batch_size, n_classes-1)
        cos_theta_neg = tf.boolean_mask(y_pred, tf.logical_not(target_mask))
        neg_y_pred = tf.reshape(cos_theta_neg, [-1, self.num_classes - 1])

        # Uses the scale from the previous update.
        B_avg = tf.reduce_mean(
            tf.reduce_sum(tf.math.exp(self.scale * neg_y_pred), axis=-1)
        )
        new_scale = (
            tf.math.log(B_avg) /
            tf.math.cos(tf.minimum(tf.constant(np.pi / 4, dtype=theta_med.dtype), theta_med))
        )

        # keep current scale if new_scale is invalid (NaN, inf or non-positive)
        safe_scale = tf.cond(
            tf.math.is_finite(new_scale) & (new_scale > 0),
            lambda: new_scale,
            lambda: self.scale
        )
        self.scale.assign(safe_scale)

    def get_config(self):
        """Return the base loss config extended with ``num_classes`` and ``margin``.

        Only these two entries are added here; the current value of ``scale``
        is not included.
        """
        base_config = super().get_config()
        return {
            **base_config,
            'num_classes': self.num_classes,
            'margin': self.margin
        }

    def __repr__(self):
        return (f"{self.__class__.__name__}(margin={self.margin}, num_classes={self.num_classes}, "
                f"name='{self.name}')")

    def __str__(self):
        return self.__repr__()

    @property
    def num_classes(self):
        """int: Number of classes (validated on assignment)."""
        return self._num_classes

    @num_classes.setter
    def num_classes(self, value):
        # NumPy integers (e.g. from `labels.max() + 1`) are accepted too.
        if not isinstance(value, (int, np.integer)):
            raise TypeError(f"`num_classes` must be an int, got {type(value).__name__}")
        if value < 2:
            raise ValueError(f"`num_classes` must be >= 2, got {value}")
        # Store a plain int so get_config() stays JSON-serialisable.
        self._num_classes = int(value)

    @property
    def margin(self):
        """float: Margin subtracted from the target-class cosine (validated on assignment)."""
        return self._margin

    @margin.setter
    def margin(self, value):
        # NumPy scalars are accepted alongside Python numbers.
        if not isinstance(value, (float, int, np.floating, np.integer)):
            raise TypeError(f"`margin` must be a float or int, got {type(value).__name__}")
        # Stored as a plain float so get_config() stays JSON-serialisable.
        value = float(value)
        if not (0.0 <= value <= 1.0):
            raise ValueError(f"`margin` must be between 0.0 and 1.0, got {value}")
        self._margin = value