Skip to content

Commit a3a368d

Browse files
Addition of Sparsemax activation (#20558)
* add: sparsemax ops * add: sparsemax api references to inits * add: sparsemax tests * edit: changes after test * edit: test case * rename: function in numpy * add: pointers to rest inits * edit: docstrings * change: x to logits in docstring
1 parent 75522e4 commit a3a368d

File tree

16 files changed

+217
-1
lines changed

16 files changed

+217
-1
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ dist/**
1818
examples/**/*.jpg
1919
.python-version
2020
.coverage
21-
*coverage.xml
21+
*coverage.xml
22+
.ruff_cache

keras/api/_tf_keras/keras/activations/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from keras.src.activations.activations import softplus
3434
from keras.src.activations.activations import softsign
3535
from keras.src.activations.activations import sparse_plus
36+
from keras.src.activations.activations import sparsemax
3637
from keras.src.activations.activations import squareplus
3738
from keras.src.activations.activations import tanh
3839
from keras.src.activations.activations import tanh_shrink

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
from keras.src.ops.nn import softsign
101101
from keras.src.ops.nn import sparse_categorical_crossentropy
102102
from keras.src.ops.nn import sparse_plus
103+
from keras.src.ops.nn import sparsemax
103104
from keras.src.ops.nn import squareplus
104105
from keras.src.ops.nn import tanh_shrink
105106
from keras.src.ops.numpy import abs

keras/api/_tf_keras/keras/ops/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,6 @@
4545
from keras.src.ops.nn import softsign
4646
from keras.src.ops.nn import sparse_categorical_crossentropy
4747
from keras.src.ops.nn import sparse_plus
48+
from keras.src.ops.nn import sparsemax
4849
from keras.src.ops.nn import squareplus
4950
from keras.src.ops.nn import tanh_shrink

keras/api/activations/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from keras.src.activations.activations import softplus
3434
from keras.src.activations.activations import softsign
3535
from keras.src.activations.activations import sparse_plus
36+
from keras.src.activations.activations import sparsemax
3637
from keras.src.activations.activations import squareplus
3738
from keras.src.activations.activations import tanh
3839
from keras.src.activations.activations import tanh_shrink

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
from keras.src.ops.nn import softsign
101101
from keras.src.ops.nn import sparse_categorical_crossentropy
102102
from keras.src.ops.nn import sparse_plus
103+
from keras.src.ops.nn import sparsemax
103104
from keras.src.ops.nn import squareplus
104105
from keras.src.ops.nn import tanh_shrink
105106
from keras.src.ops.numpy import abs

keras/api/ops/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,6 @@
4545
from keras.src.ops.nn import softsign
4646
from keras.src.ops.nn import sparse_categorical_crossentropy
4747
from keras.src.ops.nn import sparse_plus
48+
from keras.src.ops.nn import sparsemax
4849
from keras.src.ops.nn import squareplus
4950
from keras.src.ops.nn import tanh_shrink

keras/src/activations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from keras.src.activations.activations import softplus
2525
from keras.src.activations.activations import softsign
2626
from keras.src.activations.activations import sparse_plus
27+
from keras.src.activations.activations import sparsemax
2728
from keras.src.activations.activations import squareplus
2829
from keras.src.activations.activations import tanh
2930
from keras.src.activations.activations import tanh_shrink
@@ -59,6 +60,7 @@
5960
mish,
6061
log_softmax,
6162
log_sigmoid,
63+
sparsemax,
6264
}
6365

6466
ALL_OBJECTS_DICT = {fn.__name__: fn for fn in ALL_OBJECTS}

keras/src/activations/activations.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,3 +617,28 @@ def log_softmax(x, axis=-1):
617617
axis: Integer, axis along which the softmax is applied.
618618
"""
619619
return ops.log_softmax(x, axis=axis)
620+
621+
622+
@keras_export(["keras.activations.sparsemax"])
623+
def sparsemax(x, axis=-1):
624+
"""Sparsemax activation function.
625+
626+
For each batch `i`, and class `j`,
627+
sparsemax activation function is defined as:
628+
629+
`sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0).`
630+
631+
Args:
632+
x: Input tensor.
633+
axis: `int`, axis along which the sparsemax operation is applied.
634+
635+
Returns:
636+
A tensor, output of sparsemax transformation. Has the same type and
637+
shape as `x`.
638+
639+
Reference:
640+
641+
- [Martins et al., 2016](https://arxiv.org/abs/1602.02068)
642+
"""
643+
x = backend.convert_to_tensor(x)
644+
return ops.sparsemax(x, axis)

keras/src/activations/activations_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,55 @@ def test_linear(self):
896896
x_int32 = np.random.randint(-10, 10, (10, 5)).astype(np.int32)
897897
self.assertAllClose(x_int32, activations.linear(x_int32))
898898

899+
def test_sparsemax(self):
900+
# result check with 1d
901+
x_1d = np.linspace(1, 12, num=12)
902+
expected_result = np.zeros_like(x_1d)
903+
expected_result[-1] = 1.0
904+
self.assertAllClose(expected_result, activations.sparsemax(x_1d))
905+
906+
# result check with 2d
907+
x_2d = np.linspace(1, 12, num=12).reshape(-1, 2)
908+
expected_result = np.zeros_like(x_2d)
909+
expected_result[:, -1] = 1.0
910+
self.assertAllClose(expected_result, activations.sparsemax(x_2d))
911+
912+
# result check with 3d
913+
x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
914+
expected_result = np.zeros_like(x_3d)
915+
expected_result[:, :, -1] = 1.0
916+
self.assertAllClose(expected_result, activations.sparsemax(x_3d))
917+
918+
# result check with axis=-2 with 2d input
919+
x_2d = np.linspace(1, 12, num=12).reshape(-1, 2)
920+
expected_result = np.zeros_like(x_2d)
921+
expected_result[-1, :] = 1.0
922+
self.assertAllClose(
923+
expected_result, activations.sparsemax(x_2d, axis=-2)
924+
)
925+
926+
# result check with axis=-2 with 3d input
927+
x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
928+
expected_result = np.ones_like(x_3d)
929+
self.assertAllClose(
930+
expected_result, activations.sparsemax(x_3d, axis=-2)
931+
)
932+
933+
# result check with axis=-3 with 3d input
934+
x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
935+
expected_result = np.zeros_like(x_3d)
936+
expected_result[-1, :, :] = 1.0
937+
self.assertAllClose(
938+
expected_result, activations.sparsemax(x_3d, axis=-3)
939+
)
940+
941+
# result check with axis=-3 with 4d input
942+
x_4d = np.linspace(1, 12, num=12).reshape(-1, 1, 1, 2)
943+
expected_result = np.ones_like(x_4d)
944+
self.assertAllClose(
945+
expected_result, activations.sparsemax(x_4d, axis=-3)
946+
)
947+
899948
def test_get_method(self):
900949
obj = activations.get("relu")
901950
self.assertEqual(obj, activations.relu)

0 commit comments

Comments
 (0)