Skip to content

Commit 8029c78

Browse files
authored
Added IS_THREAD_SAFE backend flag (#20383)
* Fix #20382 * Added `IS_THREAD_SAFE` backend flag * Reference to tensorflow/tensorflow#78338
1 parent 96158dc commit 8029c78

File tree

10 files changed

+44
-0
lines changed

10 files changed

+44
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import concurrent
2+
3+
import numpy as np
4+
5+
from keras.src import backend
6+
from keras.src import ops
7+
from keras.src import testing
8+
9+
10+
class TestThreadSafe(testing.TestCase):
11+
def test_is_thread_safe(self):
12+
if backend.IS_THREAD_SAFE:
13+
executor = concurrent.futures.ThreadPoolExecutor()
14+
15+
def sum(x, axis):
16+
return ops.sum(x, axis=axis)
17+
18+
futures = []
19+
20+
for i in range(10000):
21+
futures.clear()
22+
x = ops.convert_to_tensor(np.random.rand(100, 100))
23+
futures.append(executor.submit(sum, x, 1))
24+
x = ops.convert_to_tensor(np.random.rand(100))
25+
futures.append(executor.submit(sum, x, 0))
26+
concurrent.futures.wait(
27+
futures, return_when=concurrent.futures.ALL_COMPLETED
28+
)
29+
[future.result() for future in futures]

keras/src/backend/jax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from keras.src.backend.jax import nn
88
from keras.src.backend.jax import numpy
99
from keras.src.backend.jax import random
10+
from keras.src.backend.jax.core import IS_THREAD_SAFE
1011
from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS
1112
from keras.src.backend.jax.core import Variable
1213
from keras.src.backend.jax.core import cast

keras/src/backend/jax/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from keras.src.backend.jax import distribution_lib
1515

1616
SUPPORTS_SPARSE_TENSORS = True
17+
IS_THREAD_SAFE = True
1718

1819

1920
class Variable(KerasVariable):

keras/src/backend/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from keras.src.backend.numpy import nn
77
from keras.src.backend.numpy import numpy
88
from keras.src.backend.numpy import random
9+
from keras.src.backend.numpy.core import IS_THREAD_SAFE
910
from keras.src.backend.numpy.core import SUPPORTS_SPARSE_TENSORS
1011
from keras.src.backend.numpy.core import Variable
1112
from keras.src.backend.numpy.core import cast

keras/src/backend/numpy/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from keras.src.backend.common.symbolic_scope import SymbolicScope
1616

1717
SUPPORTS_SPARSE_TENSORS = False
18+
IS_THREAD_SAFE = True
1819

1920

2021
class Variable(KerasVariable):

keras/src/backend/tensorflow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from keras.src.backend.tensorflow import numpy
88
from keras.src.backend.tensorflow import random
99
from keras.src.backend.tensorflow import tensorboard
10+
from keras.src.backend.tensorflow.core import IS_THREAD_SAFE
1011
from keras.src.backend.tensorflow.core import SUPPORTS_SPARSE_TENSORS
1112
from keras.src.backend.tensorflow.core import Variable
1213
from keras.src.backend.tensorflow.core import cast

keras/src/backend/tensorflow/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from keras.src.utils.naming import auto_name
1919

2020
SUPPORTS_SPARSE_TENSORS = True
21+
# https://github.com/tensorflow/tensorflow/issues/78338
22+
IS_THREAD_SAFE = False
2123

2224

2325
class Variable(

keras/src/backend/torch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from keras.src.backend.torch import nn
2323
from keras.src.backend.torch import numpy
2424
from keras.src.backend.torch import random
25+
from keras.src.backend.torch.core import IS_THREAD_SAFE
2526
from keras.src.backend.torch.core import SUPPORTS_SPARSE_TENSORS
2627
from keras.src.backend.torch.core import Variable
2728
from keras.src.backend.torch.core import cast

keras/src/backend/torch/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from keras.src.backend.config import floatx
2121

2222
SUPPORTS_SPARSE_TENSORS = False
23+
IS_THREAD_SAFE = True
2324

2425
# Some operators such as 'aten::_foreach_mul_.Scalar'
2526
# are not currently implemented for the MPS device.

keras/src/callbacks/callback_list.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import concurrent.futures
22

3+
from keras.src import backend
34
from keras.src import tree
45
from keras.src import utils
56
from keras.src.api_export import keras_export
@@ -39,6 +40,9 @@ def __init__(
3940
"""
4041
self.callbacks = tree.flatten(callbacks) if callbacks else []
4142
self._executor = None
43+
self._async_train = False
44+
self._async_test = False
45+
self._async_predict = False
4246
self._futures = []
4347
self._configure_async_dispatch(callbacks)
4448
self._add_default_callbacks(add_history, add_progbar)
@@ -53,6 +57,8 @@ def set_params(self, params):
5357

5458
def _configure_async_dispatch(self, callbacks):
5559
# Determine whether callbacks can be dispatched asynchronously.
60+
if not backend.IS_THREAD_SAFE:
61+
return
5662
async_train = True
5763
async_test = True
5864
async_predict = True

0 commit comments

Comments
 (0)