File tree Expand file tree Collapse file tree 10 files changed +44
-0
lines changed Expand file tree Collapse file tree 10 files changed +44
-0
lines changed Original file line number Diff line number Diff line change
1
+ import concurrent
2
+
3
+ import numpy as np
4
+
5
+ from keras .src import backend
6
+ from keras .src import ops
7
+ from keras .src import testing
8
+
9
+
10
+ class TestThreadSafe (testing .TestCase ):
11
+ def test_is_thread_safe (self ):
12
+ if backend .IS_THREAD_SAFE :
13
+ executor = concurrent .futures .ThreadPoolExecutor ()
14
+
15
+ def sum (x , axis ):
16
+ return ops .sum (x , axis = axis )
17
+
18
+ futures = []
19
+
20
+ for i in range (10000 ):
21
+ futures .clear ()
22
+ x = ops .convert_to_tensor (np .random .rand (100 , 100 ))
23
+ futures .append (executor .submit (sum , x , 1 ))
24
+ x = ops .convert_to_tensor (np .random .rand (100 ))
25
+ futures .append (executor .submit (sum , x , 0 ))
26
+ concurrent .futures .wait (
27
+ futures , return_when = concurrent .futures .ALL_COMPLETED
28
+ )
29
+ [future .result () for future in futures ]
Original file line number Diff line number Diff line change 7
7
from keras .src .backend .jax import nn
8
8
from keras .src .backend .jax import numpy
9
9
from keras .src .backend .jax import random
10
+ from keras .src .backend .jax .core import IS_THREAD_SAFE
10
11
from keras .src .backend .jax .core import SUPPORTS_SPARSE_TENSORS
11
12
from keras .src .backend .jax .core import Variable
12
13
from keras .src .backend .jax .core import cast
Original file line number Diff line number Diff line change 14
14
from keras .src .backend .jax import distribution_lib
15
15
16
16
SUPPORTS_SPARSE_TENSORS = True
17
+ IS_THREAD_SAFE = True
17
18
18
19
19
20
class Variable (KerasVariable ):
Original file line number Diff line number Diff line change 6
6
from keras .src .backend .numpy import nn
7
7
from keras .src .backend .numpy import numpy
8
8
from keras .src .backend .numpy import random
9
+ from keras .src .backend .numpy .core import IS_THREAD_SAFE
9
10
from keras .src .backend .numpy .core import SUPPORTS_SPARSE_TENSORS
10
11
from keras .src .backend .numpy .core import Variable
11
12
from keras .src .backend .numpy .core import cast
Original file line number Diff line number Diff line change 15
15
from keras .src .backend .common .symbolic_scope import SymbolicScope
16
16
17
17
SUPPORTS_SPARSE_TENSORS = False
18
+ IS_THREAD_SAFE = True
18
19
19
20
20
21
class Variable (KerasVariable ):
Original file line number Diff line number Diff line change 7
7
from keras .src .backend .tensorflow import numpy
8
8
from keras .src .backend .tensorflow import random
9
9
from keras .src .backend .tensorflow import tensorboard
10
+ from keras .src .backend .tensorflow .core import IS_THREAD_SAFE
10
11
from keras .src .backend .tensorflow .core import SUPPORTS_SPARSE_TENSORS
11
12
from keras .src .backend .tensorflow .core import Variable
12
13
from keras .src .backend .tensorflow .core import cast
Original file line number Diff line number Diff line change 18
18
from keras .src .utils .naming import auto_name
19
19
20
20
SUPPORTS_SPARSE_TENSORS = True
21
+ # https://github.com/tensorflow/tensorflow/issues/78338
22
+ IS_THREAD_SAFE = False
21
23
22
24
23
25
class Variable (
Original file line number Diff line number Diff line change 22
22
from keras .src .backend .torch import nn
23
23
from keras .src .backend .torch import numpy
24
24
from keras .src .backend .torch import random
25
+ from keras .src .backend .torch .core import IS_THREAD_SAFE
25
26
from keras .src .backend .torch .core import SUPPORTS_SPARSE_TENSORS
26
27
from keras .src .backend .torch .core import Variable
27
28
from keras .src .backend .torch .core import cast
Original file line number Diff line number Diff line change 20
20
from keras .src .backend .config import floatx
21
21
22
22
SUPPORTS_SPARSE_TENSORS = False
23
+ IS_THREAD_SAFE = True
23
24
24
25
# Some operators such as 'aten::_foreach_mul_.Scalar'
25
26
# are not currently implemented for the MPS device.
Original file line number Diff line number Diff line change 1
1
import concurrent .futures
2
2
3
+ from keras .src import backend
3
4
from keras .src import tree
4
5
from keras .src import utils
5
6
from keras .src .api_export import keras_export
@@ -39,6 +40,9 @@ def __init__(
39
40
"""
40
41
self .callbacks = tree .flatten (callbacks ) if callbacks else []
41
42
self ._executor = None
43
+ self ._async_train = False
44
+ self ._async_test = False
45
+ self ._async_predict = False
42
46
self ._futures = []
43
47
self ._configure_async_dispatch (callbacks )
44
48
self ._add_default_callbacks (add_history , add_progbar )
@@ -53,6 +57,8 @@ def set_params(self, params):
53
57
54
58
def _configure_async_dispatch (self , callbacks ):
55
59
# Determine whether callbacks can be dispatched asynchronously.
60
+ if not backend .IS_THREAD_SAFE :
61
+ return
56
62
async_train = True
57
63
async_test = True
58
64
async_predict = True
You can’t perform that action at this time.
0 commit comments