Skip to content

Commit 29233b8

Browse files
ma7555wang-xianghao
authored andcommitted
Add Circle Loss Function for Similarity/Metric Learning Tasks. (keras-team#20452)
* update keras/src/losses/__init__.py, losses.py, losses_test.py and numerical_utils.py * ruff fixes * hotfix for logsumexp numerical unstability with -inf values * actual fix for logsumexp -inf unstability * Add tests, fix numpy logsumexp, and update Circle Loss docstrings. * run api_gen.sh
1 parent 4986616 commit 29233b8

File tree

10 files changed

+397
-18
lines changed

10 files changed

+397
-18
lines changed

keras/api/_tf_keras/keras/losses/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from keras.src.losses.losses import CategoricalCrossentropy
1616
from keras.src.losses.losses import CategoricalFocalCrossentropy
1717
from keras.src.losses.losses import CategoricalHinge
18+
from keras.src.losses.losses import Circle
1819
from keras.src.losses.losses import CosineSimilarity
1920
from keras.src.losses.losses import Dice
2021
from keras.src.losses.losses import Hinge
@@ -34,6 +35,7 @@
3435
from keras.src.losses.losses import categorical_crossentropy
3536
from keras.src.losses.losses import categorical_focal_crossentropy
3637
from keras.src.losses.losses import categorical_hinge
38+
from keras.src.losses.losses import circle
3739
from keras.src.losses.losses import cosine_similarity
3840
from keras.src.losses.losses import ctc
3941
from keras.src.losses.losses import dice

keras/api/losses/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from keras.src.losses.losses import CategoricalCrossentropy
1515
from keras.src.losses.losses import CategoricalFocalCrossentropy
1616
from keras.src.losses.losses import CategoricalHinge
17+
from keras.src.losses.losses import Circle
1718
from keras.src.losses.losses import CosineSimilarity
1819
from keras.src.losses.losses import Dice
1920
from keras.src.losses.losses import Hinge
@@ -33,6 +34,7 @@
3334
from keras.src.losses.losses import categorical_crossentropy
3435
from keras.src.losses.losses import categorical_focal_crossentropy
3536
from keras.src.losses.losses import categorical_hinge
37+
from keras.src.losses.losses import circle
3638
from keras.src.losses.losses import cosine_similarity
3739
from keras.src.losses.losses import ctc
3840
from keras.src.losses.losses import dice

keras/src/backend/jax/math.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@ def in_top_k(targets, predictions, k):
5252

5353

5454
def logsumexp(x, axis=None, keepdims=False):
55-
max_x = jnp.max(x, axis=axis, keepdims=True)
56-
result = (
57-
jnp.log(jnp.sum(jnp.exp(x - max_x), axis=axis, keepdims=True)) + max_x
58-
)
59-
return jnp.squeeze(result) if not keepdims else result
55+
return jax.scipy.special.logsumexp(x, axis=axis, keepdims=keepdims)
6056

6157

6258
def qr(x, mode="reduced"):

keras/src/backend/numpy/math.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ def in_top_k(targets, predictions, k):
7676

7777

7878
def logsumexp(x, axis=None, keepdims=False):
79-
max_x = np.max(x, axis=axis, keepdims=True)
80-
result = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + max_x
81-
return np.squeeze(result) if not keepdims else result
79+
return scipy.special.logsumexp(x, axis=axis, keepdims=keepdims)
8280

8381

8482
def qr(x, mode="reduced"):

keras/src/backend/torch/math.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,8 @@ def in_top_k(targets, predictions, k):
8181

8282
def logsumexp(x, axis=None, keepdims=False):
8383
x = convert_to_tensor(x)
84-
if axis is None:
85-
max_x = torch.max(x)
86-
return torch.log(torch.sum(torch.exp(x - max_x))) + max_x
87-
88-
max_x = torch.amax(x, dim=axis, keepdim=True)
89-
result = (
90-
torch.log(torch.sum(torch.exp(x - max_x), dim=axis, keepdim=True))
91-
+ max_x
92-
)
93-
return torch.squeeze(result, dim=axis) if not keepdims else result
84+
axis = tuple(range(x.dim())) if axis is None else axis
85+
return torch.logsumexp(x, dim=axis, keepdim=keepdims)
9486

9587

9688
def qr(x, mode="reduced"):

keras/src/losses/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from keras.src.losses.losses import CategoricalCrossentropy
99
from keras.src.losses.losses import CategoricalFocalCrossentropy
1010
from keras.src.losses.losses import CategoricalHinge
11+
from keras.src.losses.losses import Circle
1112
from keras.src.losses.losses import CosineSimilarity
1213
from keras.src.losses.losses import Dice
1314
from keras.src.losses.losses import Hinge
@@ -28,6 +29,7 @@
2829
from keras.src.losses.losses import categorical_crossentropy
2930
from keras.src.losses.losses import categorical_focal_crossentropy
3031
from keras.src.losses.losses import categorical_hinge
32+
from keras.src.losses.losses import circle
3133
from keras.src.losses.losses import cosine_similarity
3234
from keras.src.losses.losses import ctc
3335
from keras.src.losses.losses import dice
@@ -72,6 +74,8 @@
7274
# Image segmentation
7375
Dice,
7476
Tversky,
77+
# Similarity
78+
Circle,
7579
# Sequence
7680
CTC,
7781
# Probabilistic
@@ -97,6 +101,8 @@
97101
# Image segmentation
98102
dice,
99103
tversky,
104+
# Similarity
105+
circle,
100106
# Sequence
101107
ctc,
102108
}

keras/src/losses/losses.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from keras.src.losses.loss import Loss
88
from keras.src.losses.loss import squeeze_or_expand_to_same_rank
99
from keras.src.saving import serialization_lib
10+
from keras.src.utils.numerical_utils import build_pos_neg_masks
1011
from keras.src.utils.numerical_utils import normalize
1112

1213

@@ -1403,6 +1404,97 @@ def get_config(self):
14031404
return config
14041405

14051406

1407+
@keras_export("keras.losses.Circle")
1408+
class Circle(LossFunctionWrapper):
1409+
"""Computes Circle Loss between integer labels and L2-normalized embeddings.
1410+
1411+
This is a metric learning loss designed to minimize within-class distance
1412+
and maximize between-class distance in a flexible manner by dynamically
1413+
adjusting the penalty strength based on optimization status of each
1414+
similarity score.
1415+
1416+
To use Circle Loss effectively, the model should output embeddings without
1417+
an activation function (such as a `Dense` layer with `activation=None`)
1418+
followed by UnitNormalization layer to ensure unit-norm embeddings.
1419+
1420+
Args:
1421+
gamma: Scaling factor that determines the largest scale of each
1422+
similarity score. Defaults to `80`.
1423+
margin: The relaxation factor, below this distance, negatives are
1424+
up weighted and positives are down weighted. Similarly, above this
1425+
distance negatives are down weighted and positive are up weighted.
1426+
Defaults to `0.4`.
1427+
remove_diagonal: Boolean, whether to remove self-similarities from the
1428+
positive mask. Defaults to `True`.
1429+
reduction: Type of reduction to apply to the loss. In almost all cases
1430+
this should be `"sum_over_batch_size"`. Supported options are
1431+
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
1432+
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
1433+
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
1434+
sample size, and `"mean_with_sample_weight"` sums the loss and
1435+
divides by the sum of the sample weights. `"none"` and `None`
1436+
perform no aggregation. Defaults to `"sum_over_batch_size"`.
1437+
name: Optional name for the loss instance.
1438+
dtype: The dtype of the loss's computations. Defaults to `None`, which
1439+
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
1440+
`"float32"` unless set to different value
1441+
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
1442+
provided, then the `compute_dtype` will be utilized.
1443+
1444+
Examples:
1445+
Usage with the `compile()` API:
1446+
1447+
```python
1448+
model = models.Sequential([
1449+
keras.layers.Input(shape=(224, 224, 3)),
1450+
keras.layers.Conv2D(16, (3, 3), activation='relu'),
1451+
keras.layers.Flatten(),
1452+
keras.layers.Dense(64, activation=None), # No activation
1453+
keras.layers.UnitNormalization() # L2 normalization
1454+
])
1455+
1456+
model.compile(optimizer="adam", loss=losses.Circle()
1457+
```
1458+
1459+
Reference:
1460+
- [Yifan Sun et al., 2020](https://arxiv.org/abs/2002.10857)
1461+
1462+
"""
1463+
1464+
def __init__(
1465+
self,
1466+
gamma=80.0,
1467+
margin=0.4,
1468+
remove_diagonal=True,
1469+
reduction="sum_over_batch_size",
1470+
name="circle",
1471+
dtype=None,
1472+
):
1473+
super().__init__(
1474+
circle,
1475+
name=name,
1476+
reduction=reduction,
1477+
dtype=dtype,
1478+
gamma=gamma,
1479+
margin=margin,
1480+
remove_diagonal=remove_diagonal,
1481+
)
1482+
self.gamma = gamma
1483+
self.margin = margin
1484+
self.remove_diagonal = remove_diagonal
1485+
1486+
def get_config(self):
1487+
config = Loss.get_config(self)
1488+
config.update(
1489+
{
1490+
"gamma": self.gamma,
1491+
"margin": self.margin,
1492+
"remove_diagonal": self.remove_diagonal,
1493+
}
1494+
)
1495+
return config
1496+
1497+
14061498
def convert_binary_labels_to_hinge(y_true):
14071499
"""Converts binary labels into -1/1 for hinge loss/metric calculation."""
14081500
are_zeros = ops.equal(y_true, 0)
@@ -2406,3 +2498,91 @@ def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
24062498
)
24072499

24082500
return 1 - tversky
2501+
2502+
2503+
@keras_export("keras.losses.circle")
2504+
def circle(
2505+
y_true,
2506+
y_pred,
2507+
ref_labels=None,
2508+
ref_embeddings=None,
2509+
remove_diagonal=True,
2510+
gamma=80,
2511+
margin=0.4,
2512+
):
2513+
"""Computes the Circle loss.
2514+
2515+
It is designed to minimize within-class distances and maximize between-class
2516+
distances in L2 normalized embedding space.
2517+
2518+
Args:
2519+
y_true: Tensor with ground truth labels in integer format.
2520+
y_pred: Tensor with predicted L2 normalized embeddings.
2521+
ref_labels: Optional integer tensor with labels for reference
2522+
embeddings. If `None`, defaults to `y_true`.
2523+
ref_embeddings: Optional tensor with L2 normalized reference embeddings.
2524+
If `None`, defaults to `y_pred`.
2525+
remove_diagonal: Boolean, whether to remove self-similarities from
2526+
positive mask. Defaults to `True`.
2527+
gamma: Float, scaling factor for the loss. Defaults to `80`.
2528+
margin: Float, relaxation factor for the loss. Defaults to `0.4`.
2529+
2530+
Returns:
2531+
Circle loss value.
2532+
"""
2533+
y_pred = ops.convert_to_tensor(y_pred)
2534+
y_true = ops.cast(y_true, "int32")
2535+
ref_embeddings = (
2536+
y_pred
2537+
if ref_embeddings is None
2538+
else ops.convert_to_tensor(ref_embeddings)
2539+
)
2540+
ref_labels = y_true if ref_labels is None else ops.cast(ref_labels, "int32")
2541+
2542+
optim_pos = margin
2543+
optim_neg = 1 + margin
2544+
delta_pos = margin
2545+
delta_neg = 1 - margin
2546+
2547+
pairwise_cosine_distances = 1 - ops.matmul(
2548+
y_pred, ops.transpose(ref_embeddings)
2549+
)
2550+
2551+
pairwise_cosine_distances = ops.maximum(pairwise_cosine_distances, 0.0)
2552+
positive_mask, negative_mask = build_pos_neg_masks(
2553+
y_true,
2554+
ref_labels,
2555+
remove_diagonal=remove_diagonal,
2556+
)
2557+
positive_mask = ops.cast(
2558+
positive_mask, dtype=pairwise_cosine_distances.dtype
2559+
)
2560+
negative_mask = ops.cast(
2561+
negative_mask, dtype=pairwise_cosine_distances.dtype
2562+
)
2563+
2564+
pos_weights = optim_pos + pairwise_cosine_distances
2565+
pos_weights = pos_weights * positive_mask
2566+
pos_weights = ops.maximum(pos_weights, 0.0)
2567+
neg_weights = optim_neg - pairwise_cosine_distances
2568+
neg_weights = neg_weights * negative_mask
2569+
neg_weights = ops.maximum(neg_weights, 0.0)
2570+
2571+
pos_dists = delta_pos - pairwise_cosine_distances
2572+
neg_dists = delta_neg - pairwise_cosine_distances
2573+
2574+
pos_wdists = -1 * gamma * pos_weights * pos_dists
2575+
neg_wdists = gamma * neg_weights * neg_dists
2576+
2577+
p_loss = ops.logsumexp(
2578+
ops.where(positive_mask, pos_wdists, float("-inf")),
2579+
axis=1,
2580+
)
2581+
n_loss = ops.logsumexp(
2582+
ops.where(negative_mask, neg_wdists, float("-inf")),
2583+
axis=1,
2584+
)
2585+
2586+
circle_loss = ops.softplus(p_loss + n_loss)
2587+
backend.set_keras_mask(circle_loss, circle_loss > 0)
2588+
return circle_loss

0 commit comments

Comments
 (0)