
Add Circle Loss Function for Similarity/Metric Learning Tasks. #20452


Merged · 7 commits · Nov 6, 2024
2 changes: 2 additions & 0 deletions keras/api/_tf_keras/keras/losses/__init__.py
@@ -15,6 +15,7 @@
from keras.src.losses.losses import CategoricalCrossentropy
from keras.src.losses.losses import CategoricalFocalCrossentropy
from keras.src.losses.losses import CategoricalHinge
from keras.src.losses.losses import Circle
from keras.src.losses.losses import CosineSimilarity
from keras.src.losses.losses import Dice
from keras.src.losses.losses import Hinge
@@ -34,6 +35,7 @@
from keras.src.losses.losses import categorical_crossentropy
from keras.src.losses.losses import categorical_focal_crossentropy
from keras.src.losses.losses import categorical_hinge
from keras.src.losses.losses import circle
from keras.src.losses.losses import cosine_similarity
from keras.src.losses.losses import ctc
from keras.src.losses.losses import dice
2 changes: 2 additions & 0 deletions keras/api/losses/__init__.py
@@ -14,6 +14,7 @@
from keras.src.losses.losses import CategoricalCrossentropy
from keras.src.losses.losses import CategoricalFocalCrossentropy
from keras.src.losses.losses import CategoricalHinge
from keras.src.losses.losses import Circle
from keras.src.losses.losses import CosineSimilarity
from keras.src.losses.losses import Dice
from keras.src.losses.losses import Hinge
@@ -33,6 +34,7 @@
from keras.src.losses.losses import categorical_crossentropy
from keras.src.losses.losses import categorical_focal_crossentropy
from keras.src.losses.losses import categorical_hinge
from keras.src.losses.losses import circle
from keras.src.losses.losses import cosine_similarity
from keras.src.losses.losses import ctc
from keras.src.losses.losses import dice
6 changes: 1 addition & 5 deletions keras/src/backend/jax/math.py
@@ -52,11 +52,7 @@ def in_top_k(targets, predictions, k):


def logsumexp(x, axis=None, keepdims=False):
max_x = jnp.max(x, axis=axis, keepdims=True)
result = (
jnp.log(jnp.sum(jnp.exp(x - max_x), axis=axis, keepdims=True)) + max_x
)
return jnp.squeeze(result) if not keepdims else result
return jax.scipy.special.logsumexp(x, axis=axis, keepdims=keepdims)


def qr(x, mode="reduced"):
4 changes: 1 addition & 3 deletions keras/src/backend/numpy/math.py
@@ -76,9 +76,7 @@ def in_top_k(targets, predictions, k):


def logsumexp(x, axis=None, keepdims=False):
max_x = np.max(x, axis=axis, keepdims=True)
result = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + max_x
return np.squeeze(result) if not keepdims else result
return scipy.special.logsumexp(x, axis=axis, keepdims=keepdims)


def qr(x, mode="reduced"):
12 changes: 2 additions & 10 deletions keras/src/backend/torch/math.py
@@ -81,16 +81,8 @@ def in_top_k(targets, predictions, k):

def logsumexp(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
if axis is None:
max_x = torch.max(x)
return torch.log(torch.sum(torch.exp(x - max_x))) + max_x

max_x = torch.amax(x, dim=axis, keepdim=True)
result = (
torch.log(torch.sum(torch.exp(x - max_x), dim=axis, keepdim=True))
+ max_x
)
return torch.squeeze(result, dim=axis) if not keepdims else result
axis = tuple(range(x.dim())) if axis is None else axis
return torch.logsumexp(x, dim=axis, keepdim=keepdims)


def qr(x, mode="reduced"):
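All three backends drop the hand-rolled max-subtraction implementation in favour of the library-provided `logsumexp` (`jax.scipy.special.logsumexp`, `scipy.special.logsumexp`, and `torch.logsumexp`, with the Torch call mapping `axis=None` to all dimensions since `dim` is required there). A minimal sketch, not part of the PR, showing why the max-subtraction trick matters and that the library call gives the same stable result:

```python
import numpy as np
from scipy.special import logsumexp

x = np.array([1000.0, 1000.0, 1000.0])  # large logits overflow a naive exp()

# Naive formulation: exp(1000.0) overflows to inf, so the result is inf.
naive = np.log(np.sum(np.exp(x)))

# Max-subtraction trick, i.e. what the removed backend code did by hand.
max_x = np.max(x)
stable = np.log(np.sum(np.exp(x - max_x))) + max_x

print(naive)         # inf (overflow)
print(stable)        # ~1001.0986
print(logsumexp(x))  # ~1001.0986, identical to the manual trick
```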
6 changes: 6 additions & 0 deletions keras/src/losses/__init__.py
@@ -8,6 +8,7 @@
from keras.src.losses.losses import CategoricalCrossentropy
from keras.src.losses.losses import CategoricalFocalCrossentropy
from keras.src.losses.losses import CategoricalHinge
from keras.src.losses.losses import Circle
from keras.src.losses.losses import CosineSimilarity
from keras.src.losses.losses import Dice
from keras.src.losses.losses import Hinge
@@ -28,6 +29,7 @@
from keras.src.losses.losses import categorical_crossentropy
from keras.src.losses.losses import categorical_focal_crossentropy
from keras.src.losses.losses import categorical_hinge
from keras.src.losses.losses import circle
from keras.src.losses.losses import cosine_similarity
from keras.src.losses.losses import ctc
from keras.src.losses.losses import dice
@@ -72,6 +74,8 @@
# Image segmentation
Dice,
Tversky,
# Similarity
Circle,
# Sequence
CTC,
# Probabilistic
@@ -97,6 +101,8 @@
# Image segmentation
dice,
tversky,
# Similarity
circle,
# Sequence
ctc,
}
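Adding `Circle`/`circle` to these registries should make the loss resolvable by string identifier like the existing entries; a hypothetical sketch (the `"circle"` spelling is assumed from the function name, not stated in the PR):

```python
import keras

# Assumed: once registered above, the loss can be looked up by name,
# just as "mse" or "categorical_crossentropy" resolve for existing losses.
loss_fn = keras.losses.get("circle")
print(loss_fn)  # expected to be the functional circle loss
```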
180 changes: 180 additions & 0 deletions keras/src/losses/losses.py
@@ -7,6 +7,7 @@
from keras.src.losses.loss import Loss
from keras.src.losses.loss import squeeze_or_expand_to_same_rank
from keras.src.saving import serialization_lib
from keras.src.utils.numerical_utils import build_pos_neg_masks
from keras.src.utils.numerical_utils import normalize


@@ -1403,6 +1404,97 @@ def get_config(self):
return config


@keras_export("keras.losses.Circle")
class Circle(LossFunctionWrapper):
"""Computes Circle Loss between integer labels and L2-normalized embeddings.

This is a metric learning loss designed to minimize within-class distance
and maximize between-class distance in a flexible manner by dynamically
adjusting the penalty strength based on optimization status of each
similarity score.

To use Circle Loss effectively, the model should output embeddings without
an activation function (such as a `Dense` layer with `activation=None`)
followed by a `UnitNormalization` layer to ensure unit-norm embeddings.

Args:
gamma: Scaling factor that determines the largest scale of each
similarity score. Defaults to `80`.
margin: The relaxation factor. Below this distance, negatives are
up-weighted and positives are down-weighted; above it, negatives
are down-weighted and positives are up-weighted.
Defaults to `0.4`.
remove_diagonal: Boolean, whether to remove self-similarities from the
positive mask. Defaults to `True`.
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`. Supported options are
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
sample size, and `"mean_with_sample_weight"` sums the loss and
divides by the sum of the sample weights. `"none"` and `None`
perform no aggregation. Defaults to `"sum_over_batch_size"`.
name: Optional name for the loss instance.
dtype: The dtype of the loss's computations. Defaults to `None`, which
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
`"float32"` unless set to different value
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
provided, then the `compute_dtype` will be utilized.

Examples:
Usage with the `compile()` API:

```python
model = keras.Sequential([
keras.layers.Input(shape=(224, 224, 3)),
keras.layers.Conv2D(16, (3, 3), activation='relu'),
keras.layers.Flatten(),
keras.layers.Dense(64, activation=None), # No activation
keras.layers.UnitNormalization() # L2 normalization
])

model.compile(optimizer="adam", loss=keras.losses.Circle())
```

Reference:
- [Yifan Sun et al., 2020](https://arxiv.org/abs/2002.10857)

"""

def __init__(
self,
gamma=80.0,
margin=0.4,
remove_diagonal=True,
reduction="sum_over_batch_size",
name="circle",
dtype=None,
):
super().__init__(
circle,
name=name,
reduction=reduction,
dtype=dtype,
gamma=gamma,
margin=margin,
remove_diagonal=remove_diagonal,
)
self.gamma = gamma
self.margin = margin
self.remove_diagonal = remove_diagonal

def get_config(self):
config = Loss.get_config(self)
config.update(
{
"gamma": self.gamma,
"margin": self.margin,
"remove_diagonal": self.remove_diagonal,
}
)
return config
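Besides the `compile()` example in the docstring, the class can presumably be called directly on integer labels and unit-norm embeddings; a small sketch with made-up data:

```python
import numpy as np
import keras

# Four samples from two classes; embeddings must be L2-normalized.
labels = np.array([0, 0, 1, 1])
embeddings = np.random.normal(size=(4, 8)).astype("float32")
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

loss_fn = keras.losses.Circle(gamma=80.0, margin=0.4)
loss_value = loss_fn(labels, embeddings)  # scalar, reduced over the batch
print(float(loss_value))
```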


def convert_binary_labels_to_hinge(y_true):
"""Converts binary labels into -1/1 for hinge loss/metric calculation."""
are_zeros = ops.equal(y_true, 0)
@@ -2406,3 +2498,91 @@ def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
)

return 1 - tversky


@keras_export("keras.losses.circle")
def circle(
y_true,
y_pred,
ref_labels=None,
ref_embeddings=None,
remove_diagonal=True,
gamma=80,
margin=0.4,
):
"""Computes the Circle loss.

It is designed to minimize within-class distances and maximize between-class
distances in L2 normalized embedding space.

Args:
y_true: Tensor with ground truth labels in integer format.
y_pred: Tensor with predicted L2 normalized embeddings.
ref_labels: Optional integer tensor with labels for reference
embeddings. If `None`, defaults to `y_true`.
ref_embeddings: Optional tensor with L2 normalized reference embeddings.
If `None`, defaults to `y_pred`.
remove_diagonal: Boolean, whether to remove self-similarities from
positive mask. Defaults to `True`.
gamma: Float, scaling factor for the loss. Defaults to `80`.
margin: Float, relaxation factor for the loss. Defaults to `0.4`.

Returns:
Circle loss value.
"""
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.cast(y_true, "int32")
ref_embeddings = (
y_pred
if ref_embeddings is None
else ops.convert_to_tensor(ref_embeddings)
)
ref_labels = y_true if ref_labels is None else ops.cast(ref_labels, "int32")

optim_pos = margin
optim_neg = 1 + margin
delta_pos = margin
delta_neg = 1 - margin

pairwise_cosine_distances = 1 - ops.matmul(
y_pred, ops.transpose(ref_embeddings)
)

pairwise_cosine_distances = ops.maximum(pairwise_cosine_distances, 0.0)
positive_mask, negative_mask = build_pos_neg_masks(
y_true,
ref_labels,
remove_diagonal=remove_diagonal,
)
positive_mask = ops.cast(
positive_mask, dtype=pairwise_cosine_distances.dtype
)
negative_mask = ops.cast(
negative_mask, dtype=pairwise_cosine_distances.dtype
)

pos_weights = optim_pos + pairwise_cosine_distances
pos_weights = pos_weights * positive_mask
pos_weights = ops.maximum(pos_weights, 0.0)
neg_weights = optim_neg - pairwise_cosine_distances
neg_weights = neg_weights * negative_mask
neg_weights = ops.maximum(neg_weights, 0.0)

pos_dists = delta_pos - pairwise_cosine_distances
neg_dists = delta_neg - pairwise_cosine_distances

pos_wdists = -1 * gamma * pos_weights * pos_dists
neg_wdists = gamma * neg_weights * neg_dists

p_loss = ops.logsumexp(
ops.where(positive_mask, pos_wdists, float("-inf")),
axis=1,
)
n_loss = ops.logsumexp(
ops.where(negative_mask, neg_wdists, float("-inf")),
axis=1,
)

circle_loss = ops.softplus(p_loss + n_loss)
backend.set_keras_mask(circle_loss, circle_loss > 0)
Collaborator commented:
Depending on how the loss is used, the mask might not be taken into account. How critical is it?

ma7555 (Contributor, author) replied on Nov 5, 2024:

This masking behaviour is used to mask samples in the batch that have no negative/positive pairs (a solo class with nothing to compare to). When this happens, you want to eliminate the sample from the loss: it has a loss value of 0, and not masking it can affect the sum_over_batch_size reduction (making the loss lower than it really is). In PyTorch, they use the AvgNonZeroReducer, which takes the mean only of values that are not zero.

I won't say it is critical because:

  1. Using a data sampler for pair generation like TFDataSampler or any similar sampler solves this (which is standard in metric learning).
  2. Using a larger batch size.
  3. Not an issue if it happens every now and then.
  4. In the tensorflow-similarity version, no masking is applied.
  5. If it happens all the time, it means the data-feeding sampler is the problem, not the loss function itself.

return circle_loss
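A toy calculation, not code from the PR, illustrating the author's point above: a sample with no positive/negative pairs contributes a loss of exactly 0, so averaging over the whole batch understates the loss, whereas excluding masked entries matches the "mean of non-zero values" behaviour described for PyTorch's AvgNonZeroReducer:

```python
import numpy as np

# Hypothetical per-sample Circle losses; the last sample is a solo class
# with no pairs, so its loss is 0 and the mask excludes it.
per_sample_loss = np.array([1.2, 0.9, 1.5, 0.0])
mask = per_sample_loss > 0

mean_over_batch = per_sample_loss.mean()        # 0.9, diluted by the solo sample
mean_over_valid = per_sample_loss[mask].mean()  # 1.2, what the mask preserves
print(mean_over_batch, mean_over_valid)
```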