Skip to content

Commit 1d5774e

Browse files
authored
Add parameter axis to tversky loss (#20563)
* Add axis to tversky loss * Add tests for tversky loss * Fiz line too long error * Reformat code
1 parent a3a368d commit 1d5774e

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

keras/src/losses/losses.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,7 @@ def __init__(
13881388
beta=0.5,
13891389
reduction="sum_over_batch_size",
13901390
name="tversky",
1391+
axis=None,
13911392
dtype=None,
13921393
):
13931394
super().__init__(
@@ -1397,13 +1398,17 @@ def __init__(
13971398
dtype=dtype,
13981399
alpha=alpha,
13991400
beta=beta,
1401+
axis=axis,
14001402
)
14011403
self.alpha = alpha
14021404
self.beta = beta
1405+
self.axis = axis
14031406

14041407
def get_config(self):
14051408
config = Loss.get_config(self)
1406-
config.update({"alpha": self.alpha, "beta": self.beta})
1409+
config.update(
1410+
{"alpha": self.alpha, "beta": self.beta, "axis": self.axis}
1411+
)
14071412
return config
14081413

14091414

@@ -2465,7 +2470,7 @@ def dice(y_true, y_pred, axis=None):
24652470

24662471

24672472
@keras_export("keras.losses.tversky")
2468-
def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
2473+
def tversky(y_true, y_pred, alpha=0.5, beta=0.5, axis=None):
24692474
"""Computes the Tversky loss value between `y_true` and `y_pred`.
24702475
24712476
This loss function is weighted by the alpha and beta coefficients
@@ -2479,6 +2484,7 @@ def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
24792484
y_pred: tensor of predicted targets.
24802485
alpha: coefficient controlling incidence of false positives.
24812486
beta: coefficient controlling incidence of false negatives.
2487+
axis: tuple for which dimensions the loss is calculated.
24822488
24832489
Returns:
24842490
Tversky loss value.
@@ -2490,12 +2496,13 @@ def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
24902496
y_pred = ops.convert_to_tensor(y_pred)
24912497
y_true = ops.cast(y_true, y_pred.dtype)
24922498

2493-
inputs = ops.reshape(y_true, [-1])
2494-
targets = ops.reshape(y_pred, [-1])
2499+
inputs = y_true
2500+
targets = y_pred
2501+
2502+
intersection = ops.sum(inputs * targets, axis=axis)
2503+
fp = ops.sum((1 - targets) * inputs, axis=axis)
2504+
fn = ops.sum(targets * (1 - inputs), axis=axis)
24952505

2496-
intersection = ops.sum(inputs * targets)
2497-
fp = ops.sum((1 - targets) * inputs)
2498-
fn = ops.sum(targets * (1 - inputs))
24992506
tversky = ops.divide(
25002507
intersection,
25012508
intersection + fp * alpha + fn * beta + backend.epsilon(),

keras/src/losses/losses_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,6 +1630,16 @@ def test_binary_segmentation(self):
16301630
output = losses.Tversky()(y_true, y_pred)
16311631
self.assertAllClose(output, 0.77777773)
16321632

1633+
def test_binary_segmentation_with_axis(self):
1634+
y_true = np.array(
1635+
[[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]]
1636+
)
1637+
y_pred = np.array(
1638+
[[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]]
1639+
)
1640+
output = losses.Tversky(axis=(1, 2, 3), reduction=None)(y_true, y_pred)
1641+
self.assertAllClose(output, [0.5, 0.75757575])
1642+
16331643
def test_binary_segmentation_custom_coefficients(self):
16341644
y_true = np.array(
16351645
([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
@@ -1640,6 +1650,18 @@ def test_binary_segmentation_custom_coefficients(self):
16401650
output = losses.Tversky(alpha=0.2, beta=0.8)(y_true, y_pred)
16411651
self.assertAllClose(output, 0.7916667)
16421652

1653+
def test_binary_segmentation_custom_coefficients_with_axis(self):
1654+
y_true = np.array(
1655+
[[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]]
1656+
)
1657+
y_pred = np.array(
1658+
[[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]]
1659+
)
1660+
output = losses.Tversky(
1661+
alpha=0.2, beta=0.8, axis=(1, 2, 3), reduction=None
1662+
)(y_true, y_pred)
1663+
self.assertAllClose(output, [0.5, 0.7222222])
1664+
16431665
def test_dtype_arg(self):
16441666
y_true = np.array(([[1, 2], [1, 2]]))
16451667
y_pred = np.array(([[4, 1], [6, 1]]))

0 commit comments

Comments
 (0)