 from keras.src.losses.loss import Loss
 from keras.src.losses.loss import squeeze_or_expand_to_same_rank
 from keras.src.saving import serialization_lib
+from keras.src.utils.numerical_utils import build_pos_neg_masks
 from keras.src.utils.numerical_utils import normalize


@@ -1403,6 +1404,97 @@ def get_config(self):
         return config


+@keras_export("keras.losses.Circle")
+class Circle(LossFunctionWrapper):
+    """Computes Circle Loss between integer labels and L2-normalized embeddings.
+
+    This is a metric learning loss designed to minimize within-class distance
+    and maximize between-class distance in a flexible manner, by dynamically
+    adjusting the penalty strength based on the optimization status of each
+    similarity score.
+
+    To use Circle Loss effectively, the model should output embeddings without
+    an activation function (such as a `Dense` layer with `activation=None`),
+    followed by a `UnitNormalization` layer to ensure unit-norm embeddings.
+
+    Args:
+        gamma: Scaling factor that determines the largest scale of each
+            similarity score. Defaults to `80`.
+        margin: Relaxation factor. Below this distance, negatives are
+            up-weighted and positives are down-weighted; above it, negatives
+            are down-weighted and positives are up-weighted. Defaults to
+            `0.4`.
+        remove_diagonal: Boolean, whether to remove self-similarities from
+            the positive mask. Defaults to `True`.
+        reduction: Type of reduction to apply to the loss. In almost all
+            cases this should be `"sum_over_batch_size"`. Supported options
+            are `"sum"`, `"sum_over_batch_size"`, `"mean"`,
+            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
+            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by
+            the sample size, and `"mean_with_sample_weight"` sums the loss
+            and divides by the sum of the sample weights. `"none"` and
+            `None` perform no aggregation. Defaults to
+            `"sum_over_batch_size"`.
+        name: Optional name for the loss instance.
+        dtype: The dtype of the loss's computations. Defaults to `None`,
+            which means using `keras.backend.floatx()`.
+            `keras.backend.floatx()` is `"float32"` unless set to a
+            different value (via `keras.backend.set_floatx()`). If a
+            `keras.DTypePolicy` is provided, then the `compute_dtype` will
+            be utilized.
+
+    Examples:
+
+    Usage with the `compile()` API:
+
+    ```python
+    model = keras.Sequential([
+        keras.layers.Input(shape=(224, 224, 3)),
+        keras.layers.Conv2D(16, (3, 3), activation="relu"),
+        keras.layers.Flatten(),
+        keras.layers.Dense(64, activation=None),  # No activation
+        keras.layers.UnitNormalization(),  # L2 normalization
+    ])
+
+    model.compile(optimizer="adam", loss=keras.losses.Circle())
+    ```
+
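+    Standalone usage (a minimal sketch; the labels and embeddings are
+    illustrative):
+
+    ```python
+    y_true = np.array([1, 1, 2, 2])
+    y_pred = keras.utils.normalize(np.random.random((4, 16)))
+    loss = keras.losses.Circle(gamma=80.0, margin=0.4)
+    loss(y_true, y_pred)
+    ```
+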
+    Reference:
+        - [Yifan Sun et al., 2020](https://arxiv.org/abs/2002.10857)
+
+    """
+
+    def __init__(
+        self,
+        gamma=80.0,
+        margin=0.4,
+        remove_diagonal=True,
+        reduction="sum_over_batch_size",
+        name="circle",
+        dtype=None,
+    ):
+        super().__init__(
+            circle,
+            name=name,
+            reduction=reduction,
+            dtype=dtype,
+            gamma=gamma,
+            margin=margin,
+            remove_diagonal=remove_diagonal,
+        )
+        self.gamma = gamma
+        self.margin = margin
+        self.remove_diagonal = remove_diagonal
+
+    def get_config(self):
+        config = Loss.get_config(self)
+        config.update(
+            {
+                "gamma": self.gamma,
+                "margin": self.margin,
+                "remove_diagonal": self.remove_diagonal,
+            }
+        )
+        return config
+
+
 def convert_binary_labels_to_hinge(y_true):
     """Converts binary labels into -1/1 for hinge loss/metric calculation."""
     are_zeros = ops.equal(y_true, 0)
@@ -2406,3 +2498,91 @@ def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
     )

     return 1 - tversky
+
+
+@keras_export("keras.losses.circle")
+def circle(
+    y_true,
+    y_pred,
+    ref_labels=None,
+    ref_embeddings=None,
+    remove_diagonal=True,
+    gamma=80,
+    margin=0.4,
+):
+    """Computes the Circle loss.
+
+    It is designed to minimize within-class distances and maximize
+    between-class distances in an L2-normalized embedding space.
+
+    Args:
+        y_true: Tensor with ground truth labels in integer format.
+        y_pred: Tensor with predicted L2-normalized embeddings.
+        ref_labels: Optional integer tensor with labels for reference
+            embeddings. If `None`, defaults to `y_true`.
+        ref_embeddings: Optional tensor with L2-normalized reference
+            embeddings. If `None`, defaults to `y_pred`.
+        remove_diagonal: Boolean, whether to remove self-similarities from
+            the positive mask. Defaults to `True`.
+        gamma: Float, scaling factor for the loss. Defaults to `80`.
+        margin: Float, relaxation factor for the loss. Defaults to `0.4`.
+
+    Returns:
+        Circle loss value.
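+
+    Example (a minimal sketch; the inputs are illustrative):
+
+    ```python
+    labels = np.array([0, 0, 1, 1])
+    embeddings = keras.utils.normalize(np.random.random((4, 8)))
+    loss_values = keras.losses.circle(labels, embeddings)
+    ```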
+    """
+    y_pred = ops.convert_to_tensor(y_pred)
+    y_true = ops.cast(y_true, "int32")
+    ref_embeddings = (
+        y_pred
+        if ref_embeddings is None
+        else ops.convert_to_tensor(ref_embeddings)
+    )
+    ref_labels = y_true if ref_labels is None else ops.cast(ref_labels, "int32")
+
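+    # Circle loss constants, expressed in cosine-distance space
+    # (distance = 1 - similarity). In the paper's similarity space these
+    # correspond to O_p = 1 + m, O_n = -m, Delta_p = 1 - m and Delta_n = m,
+    # with m = margin.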
+    optim_pos = margin
+    optim_neg = 1 + margin
+    delta_pos = margin
+    delta_neg = 1 - margin
+
+    pairwise_cosine_distances = 1 - ops.matmul(
+        y_pred, ops.transpose(ref_embeddings)
+    )
+
+    pairwise_cosine_distances = ops.maximum(pairwise_cosine_distances, 0.0)
+    positive_mask, negative_mask = build_pos_neg_masks(
+        y_true,
+        ref_labels,
+        remove_diagonal=remove_diagonal,
+    )
+    positive_mask = ops.cast(
+        positive_mask, dtype=pairwise_cosine_distances.dtype
+    )
+    negative_mask = ops.cast(
+        negative_mask, dtype=pairwise_cosine_distances.dtype
+    )
+
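+    # Self-paced weighting: each similarity score receives its own penalty
+    # strength proportional to how far it is from its optimum, clipped at
+    # zero.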
+    pos_weights = optim_pos + pairwise_cosine_distances
+    pos_weights = pos_weights * positive_mask
+    pos_weights = ops.maximum(pos_weights, 0.0)
+    neg_weights = optim_neg - pairwise_cosine_distances
+    neg_weights = neg_weights * negative_mask
+    neg_weights = ops.maximum(neg_weights, 0.0)
+
+    pos_dists = delta_pos - pairwise_cosine_distances
+    neg_dists = delta_neg - pairwise_cosine_distances
+
+    pos_wdists = -1 * gamma * pos_weights * pos_dists
+    neg_wdists = gamma * neg_weights * neg_dists
+
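+    # Masked-out pairs are set to -inf so they contribute nothing to the
+    # per-row log-sum-exp.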
+    p_loss = ops.logsumexp(
+        ops.where(positive_mask, pos_wdists, float("-inf")),
+        axis=1,
+    )
+    n_loss = ops.logsumexp(
+        ops.where(negative_mask, neg_wdists, float("-inf")),
+        axis=1,
+    )
+
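+    # softplus(p_loss + n_loss) is the per-sample Circle loss. Rows with no
+    # valid positive or negative pairs come out as exactly 0 and are masked
+    # out so reductions ignore them.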
+    circle_loss = ops.softplus(p_loss + n_loss)
+    backend.set_keras_mask(circle_loss, circle_loss > 0)
+    return circle_loss