Skip to content

Commit 81d4f6a

Browse files
committed
skip tfjs 3.17 tests
Signed-off-by: Deyu Huang <[email protected]>
1 parent 772dbe6 commit 81d4f6a

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

tests/common.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
"check_onnxruntime_backend",
2525
"check_tf_min_version",
2626
"check_tf_max_version",
27+
"check_tfjs_min_version",
28+
"check_tfjs_max_version",
2729
"skip_tf_versions",
2830
"skip_tf_cpu",
2931
"check_onnxruntime_min_version",
@@ -272,6 +274,29 @@ def requires_custom_ops(message=""):
272274
can_import = False
273275
return unittest.skipIf(not can_import, reason)
274276

277+
def check_tfjs_max_version(max_accepted_version, message=""):
278+
""" Skip if tfjs_version > max_required_version """
279+
config = get_test_config()
280+
reason = _append_message("conversion requires tensorflowjs <= {}".format(max_accepted_version), message)
281+
try:
282+
import tensorflowjs
283+
can_import = True
284+
except ModuleNotFoundError:
285+
can_import = False
286+
return unittest.skipIf(can_import and not config.skip_tfjs_tests and \
287+
tensorflowjs.__version__ > LooseVersion(max_accepted_version), reason)
288+
289+
def check_tfjs_min_version(min_required_version, message=""):
290+
""" Skip if tjs_version < min_required_version """
291+
config = get_test_config()
292+
reason = _append_message("conversion requires tensorflowjs >= {}".format(min_required_version), message)
293+
try:
294+
import tensorflowjs
295+
can_import = True
296+
except ModuleNotFoundError:
297+
can_import = False
298+
return unittest.skipIf(can_import and not config.skip_tfjs_tests and \
299+
tensorflowjs.__version__ < LooseVersion(min_required_version), reason)
275300

276301
def check_tf_max_version(max_accepted_version, message=""):
277302
""" Skip if tf_version > max_required_version """

tests/test_cond.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def false_fn():
118118
output_names_with_port = ["output:0"]
119119
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
120120

121+
@check_tfjs_max_version("3.15", "failed when tfjs version > 3.15")
121122
def test_cond_in_while_loop(self):
122123
def func(i, inputs):
123124
inputs_2 = tf.identity(inputs)

tests/test_loops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import tensorflow as tf
88

99
from backend_test_base import Tf2OnnxBackendTestBase
10-
from common import unittest_main, check_tf_min_version, check_tf_max_version, check_onnxruntime_min_version
10+
from common import unittest_main, check_tf_min_version, check_tf_max_version, \
11+
check_onnxruntime_min_version, check_tfjs_max_version
1112
from tf2onnx.tf_loader import is_tf2
1213

1314

@@ -66,6 +67,7 @@ def func(i):
6667
x_val = np.array(3, dtype=np.int32)
6768
self.run_test_case(func, {_INPUT: x_val}, [], [_OUTPUT], rtol=1e-06)
6869

70+
@check_tfjs_max_version("3.15", "failed when tfjs version > 3.15")
6971
def test_while_loop_with_ta_write(self):
7072
def func(i):
7173
output_ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
@@ -159,6 +161,7 @@ def b(i, res, res2):
159161
output_names_with_port = ["i:0", "x:0", "y:0"]
160162
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
161163

164+
@check_tfjs_max_version("3.15", "failed when tfjs version > 3.15")
162165
def test_while_loop_with_ta_read_and_write(self):
163166
def func(i, inputs):
164167
inputs_2 = tf.identity(inputs)
@@ -183,6 +186,7 @@ def b(i, out_ta):
183186
output_names_with_port = ["i:0", "output_ta:0"]
184187
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
185188

189+
@check_tfjs_max_version("3.15", "failed when tfjs version > 3.15")
186190
def test_while_loop_with_multi_scan_outputs(self):
187191
def func(i, inputs1, inputs2):
188192
inputs1_ = tf.identity(inputs1)
@@ -217,6 +221,7 @@ def b(i, out_ta, out_ta2):
217221
output_names_with_port = ["i:0", "output_ta:0", "output_ta2:0"]
218222
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
219223

224+
@check_tfjs_max_version("3.15", "failed when tfjs version > 3.15")
220225
@check_onnxruntime_min_version(
221226
"0.5.0",
222227
"disable this case due to onnxruntime loop issue: https://github.com/microsoft/onnxruntime/issues/1272"

0 commit comments

Comments
 (0)