Skip to content

Commit 98ac5a8

Browse files
committed
add while loop cond subgrpah test
1 parent 43e5501 commit 98ac5a8

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

tests/test_loops.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from backend_test_base import Tf2OnnxBackendTestBase
1010
from common import unittest_main, check_tf_min_version, check_tf_max_version, \
11-
check_onnxruntime_min_version, check_tfjs_max_version
11+
check_onnxruntime_min_version, check_tfjs_max_version, skip_tflite
1212
from tf2onnx.tf_loader import is_tf2
1313

1414

@@ -302,6 +302,21 @@ def func(i):
302302
output_names_with_port = ["output:0"]
303303
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
304304

305+
@skip_tflite("shape inference fails with tflite")
306+
def test_while_loop_cond_subgraphs(self):
307+
# test for while_loop with subgraphs in cond
308+
def func(x):
309+
x_dim = tf.shape(x)[0]
310+
r = tf.cast(tf.zeros(1), x.dtype)
311+
for i in range(tf.constant(10)):
312+
if i == x_dim:
313+
break
314+
r += x[i]
315+
return tf.identity(r, name="output")
316+
input_names_with_port = ["input_1:0"]
317+
feed_dict = {"input_1:0": np.arange(0, 15, dtype=np.int32)}
318+
output_names_with_port = ["output:0"]
319+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
305320

306321
if __name__ == '__main__':
307-
unittest_main()
322+
unittest_main()

0 commit comments

Comments
 (0)