From c9ee19e14eff65c11c15101cf77bed9c79abc1f0 Mon Sep 17 00:00:00 2001 From: "Salvetti, Francesco" Date: Fri, 7 Jul 2023 15:26:25 +0000 Subject: [PATCH 1/3] fix wrong shapes in loop body inputs if shape invariances are set in TF Signed-off-by: Salvetti, Francesco --- tf2onnx/onnx_opset/controlflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py index b6f70cedf..7aae50c29 100644 --- a/tf2onnx/onnx_opset/controlflow.py +++ b/tf2onnx/onnx_opset/controlflow.py @@ -571,7 +571,8 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_ g.set_dtype(func_inputs[0], onnx_pb.TensorProto.INT64) g.inputs = [g.get_node_by_output(inp) for inp in func_inputs] - for p, c in zip(loop_node.input, func_inputs): + # we should use outputs shape, not inputs, since there may be shape invariants + for p, c in zip(loop_node.output, func_inputs[2:]): g.copy_shape(p, c) for i, node in enumerate(g.inputs): From 6858c4d02e40a27703180822b67fab0518a1a710 Mon Sep 17 00:00:00 2001 From: "Salvetti, Francesco" Date: Mon, 24 Jul 2023 10:20:53 +0000 Subject: [PATCH 2/3] fix and enable test for TF2 Signed-off-by: Salvetti, Francesco --- tests/test_loops.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_loops.py b/tests/test_loops.py index 410bee378..5fbd39e11 100644 --- a/tests/test_loops.py +++ b/tests/test_loops.py @@ -286,15 +286,12 @@ def func(x, y): self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5) @check_tf_min_version("1.9") - @check_tf_max_version("1.15") def test_simple_while_loop_var_shape(self): # test for while_loop with variant shape variables - # may not meet ONNX Loop spec - # Note: this is not working on tf2 itself. def func(i): const = tf.constant(np.array([2], dtype=np.int32)) c = lambda i: tf.reduce_all(tf.shape(i) < 10) - b = lambda i: tf.concat([i, const], 0) + b = lambda i: [tf.concat([i, const], 0)] r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])]) return tf.identity(r, name="output") input_names_with_port = ["input_1:0"] From d27a1020f4126d423b872d1d8f24e893dced249c Mon Sep 17 00:00:00 2001 From: "Salvetti, Francesco" Date: Mon, 24 Jul 2023 15:53:04 +0000 Subject: [PATCH 3/3] skip tflite due to infinite loop Signed-off-by: Salvetti, Francesco --- tests/test_loops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_loops.py b/tests/test_loops.py index 5fbd39e11..f09c56c61 100644 --- a/tests/test_loops.py +++ b/tests/test_loops.py @@ -7,8 +7,8 @@ import tensorflow as tf from backend_test_base import Tf2OnnxBackendTestBase -from common import unittest_main, check_tf_min_version, check_tf_max_version, \ - check_onnxruntime_min_version, check_tfjs_max_version +from common import unittest_main, check_tf_min_version, \ + check_onnxruntime_min_version, check_tfjs_max_version, skip_tflite from tf2onnx.tf_loader import is_tf2 @@ -286,6 +286,7 @@ def func(x, y): self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5) @check_tf_min_version("1.9") + @skip_tflite("infinite loop with tflite") def test_simple_while_loop_var_shape(self): # test for while_loop with variant shape variables def func(i):