From 77eca76cb4742ce5b3cb3df6a236ff63cdebbd30 Mon Sep 17 00:00:00 2001 From: ma7555 <7144929+ma7555@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:16:11 +0200 Subject: [PATCH 1/6] update keras/src/losses/__init__.py, losses.py, losses_test.py and numerical_utils.py --- keras/src/losses/__init__.py | 6 + keras/src/losses/losses.py | 174 +++++++++++++++++++++++++++++ keras/src/losses/losses_test.py | 102 +++++++++++++++++ keras/src/utils/numerical_utils.py | 30 +++++ 4 files changed, 312 insertions(+) diff --git a/keras/src/losses/__init__.py b/keras/src/losses/__init__.py index 3163f43d98d4..7edada501d09 100644 --- a/keras/src/losses/__init__.py +++ b/keras/src/losses/__init__.py @@ -23,6 +23,7 @@ from keras.src.losses.losses import SparseCategoricalCrossentropy from keras.src.losses.losses import SquaredHinge from keras.src.losses.losses import Tversky +from keras.src.losses.losses import Circle from keras.src.losses.losses import binary_crossentropy from keras.src.losses.losses import binary_focal_crossentropy from keras.src.losses.losses import categorical_crossentropy @@ -43,6 +44,7 @@ from keras.src.losses.losses import sparse_categorical_crossentropy from keras.src.losses.losses import squared_hinge from keras.src.losses.losses import tversky +from keras.src.losses.losses import circle from keras.src.saving import serialization_lib ALL_OBJECTS = { @@ -72,6 +74,8 @@ # Image segmentation Dice, Tversky, + # Similarity + Circle, # Sequence CTC, # Probabilistic @@ -97,6 +101,8 @@ # Image segmentation dice, tversky, + # Similarity + circle, # Sequence ctc, } diff --git a/keras/src/losses/losses.py b/keras/src/losses/losses.py index 409c8284ff10..a82b63541e5b 100644 --- a/keras/src/losses/losses.py +++ b/keras/src/losses/losses.py @@ -7,6 +7,7 @@ from keras.src.losses.loss import Loss from keras.src.losses.loss import squeeze_or_expand_to_same_rank from keras.src.saving import serialization_lib +from keras.src.utils.numerical_utils import build_pos_neg_masks from 
keras.src.utils.numerical_utils import normalize @@ -1403,6 +1404,91 @@ def get_config(self): return config +@keras_export("keras.losses.Circle") +class Circle(LossFunctionWrapper): + """Computes Circle Loss, a metric learning loss designed to minimize + within-class distance and maximize between-class distance in a flexible + manner by dynamically adjusting the penalty strength based on optimization + status of each similarity score. + + To use Circle Loss effectively, the model should output embeddings without + an activation function (such as a `Dense` layer with `activation=None`) + followed by UnitNormalization layer to ensure unit-norm embeddings. + + Args: + gamma: Scaling factor that determines the largest scale of each similarity score. Defaults to `80`. + margin: The relaxation factor, below this distance, negatives are + up weighted and positives are down weighted. Similarly, above this + distance negatives are down weighted and positive are up weighted. Defaults to `0.4`. + remove_diagonal: Boolean indicating whether to remove self-similarities from the positive mask. Defaults to `True`. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. 
+ + Examples: + Usage with the `compile()` API: + + ```python + model = models.Sequential([ + keras.layers.Input(shape=(224, 224, 3)), + keras.layers.Conv2D(16, (3, 3), activation='relu'), + keras.layers.Flatten(), + keras.layers.Dense(64, activation=None), # Dense layer with no activation + keras.layers.UnitNormalization() # L2 normalization + ]) + + model.compile(optimizer="adam", loss=losses.Circle()) + ``` + + Reference: + [Yifan Sun et al., 2020](https://arxiv.org/abs/2002.10857) + """ + + def __init__( + self, + gamma=80.0, + margin=0.4, + remove_diagonal=True, + reduction="sum_over_batch_size", + name="circle", + dtype=None, + ): + super().__init__( + circle, + name=name, + reduction=reduction, + dtype=dtype, + gamma=gamma, + margin=margin, + remove_diagonal=remove_diagonal, + ) + self.gamma = gamma + self.margin = margin + self.remove_diagonal = remove_diagonal + + def get_config(self): + config = Loss.get_config(self) + config.update( + { + "gamma": self.gamma, + "margin": self.margin, + "remove_diagonal": self.remove_diagonal, + } + ) + return config + + def convert_binary_labels_to_hinge(y_true): """Converts binary labels into -1/1 for hinge loss/metric calculation.""" are_zeros = ops.equal(y_true, 0) @@ -2406,3 +2492,91 @@ def tversky(y_true, y_pred, alpha=0.5, beta=0.5): ) return 1 - tversky + + +@keras_export("keras.losses.circle") +def circle( + y_true, + y_pred, + ref_labels=None, + ref_embeddings=None, + remove_diagonal=True, + gamma=80, + margin=0.4, +): + """Computes the Circle loss between `y_true` and `y_pred`. + + It is designed to minimize within-class distances and maximize between-class distances in embedding + space. + + Args: + y_true: Tensor of shape `[batch_size]` with ground truth labels in integer format. Can also be treated as query labels. + y_pred: Tensor of shape `[batch_size, embedding_dim]` with predicted L2 normalized embeddings. 
Can also be treated as query embeddings + ref_labels: Optional integer tensor with labels for reference embeddings. + If `None`, defaults to `y_true`. + ref_embeddings: Optional tensor with L2 normalized reference embeddings. + If `None`, defaults to `y_pred`. + remove_diagonal: Boolean, whether to remove self-similarities from positive mask. + Defaults to `True`. + gamma: Float, scaling factor for the loss. Defaults to `80`. + margin: Float, relaxation factor for the loss. Defaults to `0.4`. + + Returns: + Circle loss value. + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, "int32") + ref_embeddings = ( + y_pred + if ref_embeddings is None + else ops.convert_to_tensor(ref_embeddings) + ) + ref_labels = y_true if ref_labels is None else ops.cast(ref_labels, "int32") + + optim_pos = margin + optim_neg = 1 + margin + delta_pos = margin + delta_neg = 1 - margin + + pairwise_cosine_distances = 1 - ops.matmul( + y_pred, ops.transpose(ref_embeddings) + ) + + pairwise_cosine_distances = ops.maximum(pairwise_cosine_distances, 0.0) + positive_mask, negative_mask = build_pos_neg_masks( + y_true, + ref_labels, + remove_diagonal=remove_diagonal, + ) + positive_mask = ops.cast( + positive_mask, dtype=pairwise_cosine_distances.dtype + ) + negative_mask = ops.cast( + negative_mask, dtype=pairwise_cosine_distances.dtype + ) + + pos_weights = optim_pos + pairwise_cosine_distances + pos_weights = pos_weights * positive_mask + pos_weights = ops.maximum(pos_weights, 0.0) + neg_weights = optim_neg - pairwise_cosine_distances + neg_weights = neg_weights * negative_mask + neg_weights = ops.maximum(neg_weights, 0.0) + + pos_dists = delta_pos - pairwise_cosine_distances + neg_dists = delta_neg - pairwise_cosine_distances + + pos_wdists = -1 * gamma * pos_weights * pos_dists + neg_wdists = gamma * neg_weights * neg_dists + + p_loss = ops.logsumexp( + ops.where(positive_mask, pos_wdists, float("-inf")), + axis=1, + ) + n_loss = ops.logsumexp( + 
ops.where(negative_mask, neg_wdists, float("-inf")), + axis=1, + ) + + circle_loss = ops.softplus(p_loss + n_loss) + backend.set_keras_mask(circle_loss, circle_loss > 0) + return circle_loss diff --git a/keras/src/losses/losses_test.py b/keras/src/losses/losses_test.py index 2ed7920a018e..65879123ff44 100644 --- a/keras/src/losses/losses_test.py +++ b/keras/src/losses/losses_test.py @@ -1645,3 +1645,105 @@ def test_dtype_arg(self): y_pred = np.array(([[4, 1], [6, 1]])) output = losses.Tversky(dtype="bfloat16")(y_true, y_pred) self.assertDType(output, "bfloat16") + + +class CircleLossTest(testing.TestCase): + def setup(self): + self.y_true = np.array([1, 1, 2, 2, 3]) + self.y_pred = np.array( + [ + [0.70014004, -0.42008403, 0.14002801, 0.56011203], + [0.17609018, 0.70436073, -0.52827054, 0.44022545], + [-0.34050261, 0.25537696, -0.68100522, 0.59587957], + [0.32163376, -0.75047877, 0.53605627, -0.21442251], + [0.51261459, -0.34174306, 0.17087153, 0.76892189], + ] + ) + self.ref_labels = np.array([1, 1, 2, 2, 3, 4]) + self.ref_embeddings = np.array( + [ + [0.40824829, -0.54433105, 0.27216553, 0.68041382], + [0.76376261, 0.10910895, -0.54554473, 0.32732684], + [-0.74420841, 0.24806947, 0.49613894, -0.3721042], + [0.52981294, -0.13245324, 0.79471941, -0.26490647], + [0.54554473, -0.32732684, 0.10910895, 0.76376261], + [-0.27216553, 0.68041382, 0.40824829, -0.54433105], + ] + ) + + def test_config(self): + self.run_class_serialization_test( + losses.Circle(name="mycircle", gamma=80.0, margin=0.4) + ) + + def test_correctness(self): + self.setup() + circle_loss = losses.Circle(gamma=80.0, margin=0.4) + loss = circle_loss(self.y_true, self.y_pred) + self.assertAlmostEqual(loss, 188.3883) + + circle_loss = losses.Circle(gamma=256, margin=0.25) + loss = circle_loss(self.y_true, self.y_pred) + self.assertAlmostEqual(loss, 652.7617) + + loss = losses.circle( + self.y_true, + self.y_pred, + ref_labels=self.ref_labels, + ref_embeddings=self.ref_embeddings, + gamma=80.0, + 
margin=0.4, + remove_diagonal=False, + ) + + self.assertAllClose( + loss, (61.5844, 94.3465, 276.9344, 90.9873, 48.8963) + ) + + def test_correctness_weighted(self): + self.setup() + sample_weight = np.array([2.0, 2.0, 1.0, 1.0, 0.5]) + circle_loss = losses.Circle(gamma=80.0, margin=0.4) + loss = circle_loss( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + self.assertAlmostEqual(loss, 244.91918) + + def test_zero_weighted(self): + self.setup() + circle_loss = losses.Circle(gamma=80.0, margin=0.4) + loss = circle_loss(self.y_true, self.y_pred, sample_weight=0.0) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_no_reduction(self): + self.setup() + circle_loss = losses.Circle(gamma=80.0, margin=0.4, reduction=None) + loss = circle_loss(self.ref_labels, self.ref_embeddings) + + self.assertAllClose( + loss, [82.9116, 36.7942, 92.4590, 52.6798, 0.0, 0.0] + ) + + def test_sum_reduction(self): + self.setup() + circle_loss = losses.Circle(gamma=80.0, margin=0.4, reduction="sum") + loss = circle_loss(self.ref_labels, self.ref_embeddings) + + self.assertAlmostEqual(loss, 264.845) + + def test_mean_with_sample_weight_reduction(self): + self.setup() + sample_weight = np.array([2.0, 2.0, 1.0, 1.0, 0.5]) + circle_loss = losses.Circle( + gamma=80.0, margin=0.4, reduction="mean_with_sample_weight" + ) + loss = circle_loss( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + self.assertAlmostEqual(loss, 163.27948) + + def test_dtype_arg(self): + self.setup() + circle_loss = losses.Circle(dtype="bfloat16") + loss = circle_loss(self.y_true, self.y_pred) + self.assertDType(loss, "bfloat16") diff --git a/keras/src/utils/numerical_utils.py b/keras/src/utils/numerical_utils.py index 0b8427551337..dcd2cc422d6a 100644 --- a/keras/src/utils/numerical_utils.py +++ b/keras/src/utils/numerical_utils.py @@ -193,3 +193,33 @@ def encode_categorical_inputs( axis=reduction_axis, ) return outputs + + +def build_pos_neg_masks( + query_labels, + key_labels, + 
remove_diagonal=True, +): + from keras.src import ops + + if ops.ndim(query_labels) == 1: + query_labels = ops.reshape(query_labels, (-1, 1)) + + if ops.ndim(key_labels) == 1: + key_labels = ops.reshape(key_labels, (-1, 1)) + + positive_mask = ops.equal(query_labels, ops.transpose(key_labels)) + negative_mask = ops.logical_not(positive_mask) + + if remove_diagonal: + positive_mask = ops.logical_and( + positive_mask, + ~ops.eye( + ops.size(query_labels), + ops.size(key_labels), + k=0, + dtype="bool", + ), + ) + + return positive_mask, negative_mask From d064342c9a9797637800b46ef3764d98da5d533e Mon Sep 17 00:00:00 2001 From: ma7555 <7144929+ma7555@users.noreply.github.com> Date: Tue, 5 Nov 2024 13:32:12 +0200 Subject: [PATCH 2/6] ruff fixes --- keras/src/losses/__init__.py | 4 ++-- keras/src/losses/losses.py | 29 ++++++++++++++++------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/keras/src/losses/__init__.py b/keras/src/losses/__init__.py index 7edada501d09..7afeb55a01d1 100644 --- a/keras/src/losses/__init__.py +++ b/keras/src/losses/__init__.py @@ -8,6 +8,7 @@ from keras.src.losses.losses import CategoricalCrossentropy from keras.src.losses.losses import CategoricalFocalCrossentropy from keras.src.losses.losses import CategoricalHinge +from keras.src.losses.losses import Circle from keras.src.losses.losses import CosineSimilarity from keras.src.losses.losses import Dice from keras.src.losses.losses import Hinge @@ -23,12 +24,12 @@ from keras.src.losses.losses import SparseCategoricalCrossentropy from keras.src.losses.losses import SquaredHinge from keras.src.losses.losses import Tversky -from keras.src.losses.losses import Circle from keras.src.losses.losses import binary_crossentropy from keras.src.losses.losses import binary_focal_crossentropy from keras.src.losses.losses import categorical_crossentropy from keras.src.losses.losses import categorical_focal_crossentropy from keras.src.losses.losses import categorical_hinge +from 
keras.src.losses.losses import circle from keras.src.losses.losses import cosine_similarity from keras.src.losses.losses import ctc from keras.src.losses.losses import dice @@ -44,7 +45,6 @@ from keras.src.losses.losses import sparse_categorical_crossentropy from keras.src.losses.losses import squared_hinge from keras.src.losses.losses import tversky -from keras.src.losses.losses import circle from keras.src.saving import serialization_lib ALL_OBJECTS = { diff --git a/keras/src/losses/losses.py b/keras/src/losses/losses.py index a82b63541e5b..5a29c15feead 100644 --- a/keras/src/losses/losses.py +++ b/keras/src/losses/losses.py @@ -1416,11 +1416,14 @@ class Circle(LossFunctionWrapper): followed by UnitNormalization layer to ensure unit-norm embeddings. Args: - gamma: Scaling factor that determines the largest scale of each similarity score. Defaults to `80`. + gamma: Scaling factor that determines the largest scale of each + similarity score. Defaults to `80`. margin: The relaxation factor, below this distance, negatives are up weighted and positives are down weighted. Similarly, above this - distance negatives are down weighted and positive are up weighted. Defaults to `0.4`. - remove_diagonal: Boolean indicating whether to remove self-similarities from the positive mask. Defaults to `True`. + distance negatives are down weighted and positive are up weighted. + Defaults to `0.4`. + remove_diagonal: Boolean, whether to remove self-similarities from the + positive mask. Defaults to `True`. reduction: Type of reduction to apply to the loss. In almost all cases this should be `"sum_over_batch_size"`. 
Supported options are `"sum"`, `"sum_over_batch_size"`, `"mean"`, @@ -1444,7 +1447,7 @@ class Circle(LossFunctionWrapper): keras.layers.Input(shape=(224, 224, 3)), keras.layers.Conv2D(16, (3, 3), activation='relu'), keras.layers.Flatten(), - keras.layers.Dense(64, activation=None), # Dense layer with no activation + keras.layers.Dense(64, activation=None), # No activation keras.layers.UnitNormalization() # L2 normalization ]) @@ -2504,20 +2507,20 @@ def circle( gamma=80, margin=0.4, ): - """Computes the Circle loss between `y_true` and `y_pred`. + """Computes the Circle loss - It is designed to minimize within-class distances and maximize between-class distances in embedding - space. + It is designed to minimize within-class distances and maximize between-class + distances in embedding space. Args: - y_true: Tensor of shape `[batch_size]` with ground truth labels in integer format. Can also be treated as query labels. - y_pred: Tensor of shape `[batch_size, embedding_dim]` with predicted L2 normalized embeddings. Can also be treated as query embeddings - ref_labels: Optional integer tensor with labels for reference embeddings. - If `None`, defaults to `y_true`. + y_true: Tensor with ground truth labels in integer format. + y_pred: Tensor with predicted L2 normalized embeddings. + ref_labels: Optional integer tensor with labels for reference + embeddings. If `None`, defaults to `y_true`. ref_embeddings: Optional tensor with L2 normalized reference embeddings. If `None`, defaults to `y_pred`. - remove_diagonal: Boolean, whether to remove self-similarities from positive mask. - Defaults to `True`. + remove_diagonal: Boolean, whether to remove self-similarities from + positive mask. Defaults to `True`. gamma: Float, scaling factor for the loss. Defaults to `80`. margin: Float, relaxation factor for the loss. Defaults to `0.4`. 
From a4d4fadb7527670735e4986a22e54863a30b312b Mon Sep 17 00:00:00 2001 From: ma7555 <7144929+ma7555@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:25:58 +0200 Subject: [PATCH 3/6] hotfix for logsumexp numerical instability with -inf values --- keras/src/losses/losses.py | 2 ++ keras/src/losses/losses_test.py | 8 +------- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/keras/src/losses/losses.py b/keras/src/losses/losses.py index 5a29c15feead..427a7fab679a 100644 --- a/keras/src/losses/losses.py +++ b/keras/src/losses/losses.py @@ -2579,6 +2579,8 @@ def circle( ops.where(negative_mask, neg_wdists, float("-inf")), axis=1, ) + p_loss = ops.where(ops.isnan(p_loss), float("-inf"), p_loss) + n_loss = ops.where(ops.isnan(n_loss), float("-inf"), n_loss) circle_loss = ops.softplus(p_loss + n_loss) backend.set_keras_mask(circle_loss, circle_loss > 0) diff --git a/keras/src/losses/losses_test.py b/keras/src/losses/losses_test.py index 65879123ff44..bbecbc06d085 100644 --- a/keras/src/losses/losses_test.py +++ b/keras/src/losses/losses_test.py @@ -1647,7 +1647,7 @@ def test_dtype_arg(self): self.assertDType(output, "bfloat16") -class CircleLossTest(testing.TestCase): +class CircleTest(testing.TestCase): def setup(self): self.y_true = np.array([1, 1, 2, 2, 3]) self.y_pred = np.array( @@ -1709,12 +1709,6 @@ def test_correctness_weighted(self): ) self.assertAlmostEqual(loss, 244.91918) - def test_zero_weighted(self): - self.setup() - circle_loss = losses.Circle(gamma=80.0, margin=0.4) - loss = circle_loss(self.y_true, self.y_pred, sample_weight=0.0) - self.assertAlmostEqual(loss, 0.0, 3) - def test_no_reduction(self): self.setup() circle_loss = losses.Circle(gamma=80.0, margin=0.4, reduction=None) From e3fd53cc0f5bb495693406d2bea0a716a9fc472b Mon Sep 17 00:00:00 2001 From: ma7555 <7144929+ma7555@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:27:04 +0200 Subject: [PATCH 4/6] actual fix for logsumexp -inf instability --- keras/src/backend/jax/math.py | 6 
+----- keras/src/backend/torch/math.py | 12 ++---------- keras/src/losses/losses.py | 2 -- 3 files changed, 3 insertions(+), 17 deletions(-) diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index f359f7cfd7aa..6b04f58a4303 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -52,11 +52,7 @@ def in_top_k(targets, predictions, k): def logsumexp(x, axis=None, keepdims=False): - max_x = jnp.max(x, axis=axis, keepdims=True) - result = ( - jnp.log(jnp.sum(jnp.exp(x - max_x), axis=axis, keepdims=True)) + max_x - ) - return jnp.squeeze(result) if not keepdims else result + return jax.scipy.special.logsumexp(x, axis=axis, keepdims=keepdims) def qr(x, mode="reduced"): diff --git a/keras/src/backend/torch/math.py b/keras/src/backend/torch/math.py index 2ddb26165ac4..20e06b2717a9 100644 --- a/keras/src/backend/torch/math.py +++ b/keras/src/backend/torch/math.py @@ -81,16 +81,8 @@ def in_top_k(targets, predictions, k): def logsumexp(x, axis=None, keepdims=False): x = convert_to_tensor(x) - if axis is None: - max_x = torch.max(x) - return torch.log(torch.sum(torch.exp(x - max_x))) + max_x - - max_x = torch.amax(x, dim=axis, keepdim=True) - result = ( - torch.log(torch.sum(torch.exp(x - max_x), dim=axis, keepdim=True)) - + max_x - ) - return torch.squeeze(result, dim=axis) if not keepdims else result + axis = tuple(range(x.dim())) if axis is None else axis + return torch.logsumexp(x, dim=axis, keepdim=keepdims) def qr(x, mode="reduced"): diff --git a/keras/src/losses/losses.py b/keras/src/losses/losses.py index 427a7fab679a..5a29c15feead 100644 --- a/keras/src/losses/losses.py +++ b/keras/src/losses/losses.py @@ -2579,8 +2579,6 @@ def circle( ops.where(negative_mask, neg_wdists, float("-inf")), axis=1, ) - p_loss = ops.where(ops.isnan(p_loss), float("-inf"), p_loss) - n_loss = ops.where(ops.isnan(n_loss), float("-inf"), n_loss) circle_loss = ops.softplus(p_loss + n_loss) backend.set_keras_mask(circle_loss, circle_loss > 0) 
From 8011ce007a48449bc976209fdbaf87a6faf557f9 Mon Sep 17 00:00:00 2001 From: ma7555 <7144929+ma7555@users.noreply.github.com> Date: Wed, 6 Nov 2024 01:46:55 +0200 Subject: [PATCH 5/6] Add tests, fix numpy logsumexp, and update Circle Loss docstrings. --- keras/src/backend/numpy/math.py | 4 +- keras/src/losses/losses.py | 17 +++--- keras/src/utils/numerical_utils_test.py | 77 +++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 10 deletions(-) diff --git a/keras/src/backend/numpy/math.py b/keras/src/backend/numpy/math.py index b96fbece8532..bec628f915ad 100644 --- a/keras/src/backend/numpy/math.py +++ b/keras/src/backend/numpy/math.py @@ -76,9 +76,7 @@ def in_top_k(targets, predictions, k): def logsumexp(x, axis=None, keepdims=False): - max_x = np.max(x, axis=axis, keepdims=True) - result = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + max_x - return np.squeeze(result) if not keepdims else result + return scipy.special.logsumexp(x, axis=axis, keepdims=keepdims) def qr(x, mode="reduced"): diff --git a/keras/src/losses/losses.py b/keras/src/losses/losses.py index 5a29c15feead..b65ab4f05f05 100644 --- a/keras/src/losses/losses.py +++ b/keras/src/losses/losses.py @@ -1406,10 +1406,12 @@ def get_config(self): @keras_export("keras.losses.Circle") class Circle(LossFunctionWrapper): - """Computes Circle Loss, a metric learning loss designed to minimize - within-class distance and maximize between-class distance in a flexible - manner by dynamically adjusting the penalty strength based on optimization - status of each similarity score. + """Computes Circle Loss between integer labels and L2-normalized embeddings. + + This is a metric learning loss designed to minimize within-class distance + and maximize between-class distance in a flexible manner by dynamically + adjusting the penalty strength based on optimization status of each + similarity score. 
To use Circle Loss effectively, the model should output embeddings without an activation function (such as a `Dense` layer with `activation=None`) @@ -1455,7 +1457,8 @@ class Circle(LossFunctionWrapper): ``` Reference: - [Yifan Sun et al., 2020](https://arxiv.org/abs/2002.10857) + - [Yifan Sun et al., 2020](https://arxiv.org/abs/2002.10857) + """ def __init__( @@ -2507,10 +2510,10 @@ def circle( gamma=80, margin=0.4, ): - """Computes the Circle loss + """Computes the Circle loss. It is designed to minimize within-class distances and maximize between-class - distances in embedding space. + distances in L2 normalized embedding space. Args: y_true: Tensor with ground truth labels in integer format. diff --git a/keras/src/utils/numerical_utils_test.py b/keras/src/utils/numerical_utils_test.py index 41e2f1b3b94d..9b9520abc90e 100644 --- a/keras/src/utils/numerical_utils_test.py +++ b/keras/src/utils/numerical_utils_test.py @@ -72,3 +72,80 @@ def test_normalize(self, order): out = numerical_utils.normalize(xb, axis=-1, order=order) self.assertTrue(backend.is_tensor(out)) self.assertAllClose(backend.convert_to_numpy(out), expected) + + def test_build_pos_neg_masks(self): + query_labels = np.array([0, 1, 2, 2, 0]) + key_labels = np.array([0, 1, 2, 0, 2]) + expected_shape = (len(query_labels), len(key_labels)) + + positive_mask, negative_mask = numerical_utils.build_pos_neg_masks( + query_labels, key_labels, remove_diagonal=False + ) + + positive_mask = backend.convert_to_numpy(positive_mask) + negative_mask = backend.convert_to_numpy(negative_mask) + self.assertEqual(positive_mask.shape, expected_shape) + self.assertEqual(negative_mask.shape, expected_shape) + self.assertTrue( + np.all(np.logical_not(np.logical_and(positive_mask, negative_mask))) + ) + + expected_positive_mask_keep_diag = np.array( + [ + [1, 0, 0, 1, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 1], + [0, 0, 1, 0, 1], + [1, 0, 0, 1, 0], + ], + dtype="bool", + ) + self.assertTrue( + np.all(positive_mask == 
expected_positive_mask_keep_diag) + ) + self.assertTrue( + np.all( + negative_mask + == np.logical_not(expected_positive_mask_keep_diag) + ) + ) + + positive_mask, negative_mask = numerical_utils.build_pos_neg_masks( + query_labels, key_labels, remove_diagonal=True + ) + positive_mask = backend.convert_to_numpy(positive_mask) + negative_mask = backend.convert_to_numpy(negative_mask) + self.assertEqual(positive_mask.shape, expected_shape) + self.assertEqual(negative_mask.shape, expected_shape) + self.assertTrue( + np.all(np.logical_not(np.logical_and(positive_mask, negative_mask))) + ) + + expected_positive_mask_with_remove_diag = np.array( + [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 1], + [0, 0, 1, 0, 1], + [1, 0, 0, 1, 0], + ], + dtype="bool", + ) + self.assertTrue( + np.all(positive_mask == expected_positive_mask_with_remove_diag) + ) + + query_labels = np.array([1, 2, 3]) + key_labels = np.array([1, 2, 3, 1]) + + positive_mask, negative_mask = numerical_utils.build_pos_neg_masks( + query_labels, key_labels, remove_diagonal=True + ) + positive_mask = backend.convert_to_numpy(positive_mask) + negative_mask = backend.convert_to_numpy(negative_mask) + expected_shape_diff_sizes = (len(query_labels), len(key_labels)) + self.assertEqual(positive_mask.shape, expected_shape_diff_sizes) + self.assertEqual(negative_mask.shape, expected_shape_diff_sizes) + self.assertTrue( + np.all(np.logical_not(np.logical_and(positive_mask, negative_mask))) + ) From 192e6f7e5bfa11251193806023b2a835079038c1 Mon Sep 17 00:00:00 2001 From: ma7555 <7144929+ma7555@users.noreply.github.com> Date: Wed, 6 Nov 2024 10:50:13 +0200 Subject: [PATCH 6/6] run api_gen.sh --- keras/api/_tf_keras/keras/losses/__init__.py | 2 ++ keras/api/losses/__init__.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/keras/api/_tf_keras/keras/losses/__init__.py b/keras/api/_tf_keras/keras/losses/__init__.py index 832d78f5fda0..e64b91b308ab 100644 --- a/keras/api/_tf_keras/keras/losses/__init__.py +++ 
b/keras/api/_tf_keras/keras/losses/__init__.py @@ -15,6 +15,7 @@ from keras.src.losses.losses import CategoricalCrossentropy from keras.src.losses.losses import CategoricalFocalCrossentropy from keras.src.losses.losses import CategoricalHinge +from keras.src.losses.losses import Circle from keras.src.losses.losses import CosineSimilarity from keras.src.losses.losses import Dice from keras.src.losses.losses import Hinge @@ -34,6 +35,7 @@ from keras.src.losses.losses import categorical_crossentropy from keras.src.losses.losses import categorical_focal_crossentropy from keras.src.losses.losses import categorical_hinge +from keras.src.losses.losses import circle from keras.src.losses.losses import cosine_similarity from keras.src.losses.losses import ctc from keras.src.losses.losses import dice diff --git a/keras/api/losses/__init__.py b/keras/api/losses/__init__.py index ecaadddf6b7e..af88a721cc4f 100644 --- a/keras/api/losses/__init__.py +++ b/keras/api/losses/__init__.py @@ -14,6 +14,7 @@ from keras.src.losses.losses import CategoricalCrossentropy from keras.src.losses.losses import CategoricalFocalCrossentropy from keras.src.losses.losses import CategoricalHinge +from keras.src.losses.losses import Circle from keras.src.losses.losses import CosineSimilarity from keras.src.losses.losses import Dice from keras.src.losses.losses import Hinge @@ -33,6 +34,7 @@ from keras.src.losses.losses import categorical_crossentropy from keras.src.losses.losses import categorical_focal_crossentropy from keras.src.losses.losses import categorical_hinge +from keras.src.losses.losses import circle from keras.src.losses.losses import cosine_similarity from keras.src.losses.losses import ctc from keras.src.losses.losses import dice