diff --git a/tests/test_backend.py b/tests/test_backend.py index 1060574c1..225f9461a 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1685,6 +1685,23 @@ def func(x1, x2, x3): return tf.identity(x_, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, "input3:0": x_val3}) + def test_concat_negative_axis_none_shape(self): + x_val = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).reshape((2, 3)) + y_val = np.array([7.0, 8.0, 9.0, 10.0, 11.0, 12.0], dtype=np.float32).reshape((2, 3)) + s1_val = np.array([1, 1], dtype=np.int32) + s2_val = np.array([1, 1], dtype=np.int32) + def func(): + x = tf_placeholder(tf.float32, [2, 3], name=_TFINPUT) + y = tf_placeholder(tf.float32, [2, 3], name=_TFINPUT1) + s1 = tf_placeholder(tf.int32, [2], name="input3") + s2 = tf_placeholder(tf.int32, [2], name="input4") + s = tf.add(s1, s2) + x_with_none_shape = tf.slice(x, [0, 0], s) + t = tf.concat([x_with_none_shape, y], -1) + return tf.identity(t, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, "input3:0": s1_val, "input4:0": s2_val}, + as_session=True, premade_placeholders=True) + def test_concat_const_string(self): x_val1 = np.array([["Hello world", "abc"], ["def", "♦♥♠♣"]], dtype=str) const_val = np.array([["Hello there", "wxyz"], ["", "π"]], dtype=str) diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index ca4daf0d1..bb9b0cb52 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -297,8 +297,13 @@ def version_1(cls, ctx, node, **kwargs): ctx.remove_input(node, node.input[-1], len(node.input) - 1) if axis_val < 0: # onnxruntime does not support -1 axis, but TF supports. - input_shape = ctx.get_shape(node.input[0]) - utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0])) + input_shape = None + for node_input in node.input: + input_shape = ctx.get_shape(node_input) + if input_shape is not None: + break + utils.make_sure(input_shape is not None, + "the shapes of the following inputs are None: {}".format(', '.join(node.input))) axis_val = len(input_shape) + axis_val node.set_attr("axis", axis_val)