diff --git a/tests/test_string_ops.py b/tests/test_string_ops.py
index ea6f38034..e27ee8927 100644
--- a/tests/test_string_ops.py
+++ b/tests/test_string_ops.py
@@ -52,6 +52,24 @@ def func(text1, text2, text3):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1, _INPUT1: text_val2, _INPUT2: text_val3})
 
+    @requires_custom_ops("ReduceJoin")
+    def test_reduce_join(self):
+        # Join each row's strings along the last axis with a non-ASCII separator.
+        text_val = np.array([["a", "Test 1 2 3"], ["b", "test test"], ["c", "Hi there Test"]], dtype=str)
+        def func(text):
+            x_ = tf.strings.reduce_join(text, axis=1, separator="±")
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: text_val})
+
+    @requires_custom_ops("ReduceJoin")
+    def test_reduce_join_keepdims(self):
+        # Reducing a non-last axis with keepdims checks that the kept dimension is restored in place.
+        text_val = np.array([["a", "Test 1 2 3"], ["b", "test test"], ["c", "Hi there Test"]], dtype=str)
+        def func(text):
+            x_ = tf.strings.reduce_join(text, axis=0, keepdims=True, separator="±")
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: text_val})
+
     @requires_custom_ops("StringSplit")
     @check_tf_min_version("2.0", "result is sparse not ragged in tf1")
     def test_string_split(self):
diff --git a/tf2onnx/custom_opsets/string_ops.py b/tf2onnx/custom_opsets/string_ops.py
index 303fcd94b..2e777e3c5 100644
--- a/tf2onnx/custom_opsets/string_ops.py
+++ b/tf2onnx/custom_opsets/string_ops.py
@@ -5,6 +5,7 @@
 import json
 import logging
 import numpy as np
+from onnx.numpy_helper import to_array
 from onnx.onnx_pb import TensorProto
 from onnx.helper import make_attribute
 from tf2onnx import constants, handler
@@ -86,6 +87,42 @@ def version_1(cls, ctx, node, **kwargs):
         stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
         ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]])
 
+@tf_op("ReduceJoin", domain=constants.CONTRIB_OPS_DOMAIN)
+class ReduceJoin:
+    """Convert TF ReduceJoin into the contrib-domain StringJoin op."""
+
+    @classmethod
+    def version_1(cls, ctx, node, **kwargs):
+        """Rewrite *node* in place as StringJoin(data, separator, axis).
+
+        The TF ``separator`` attribute and the constant axis input become
+        constant inputs of the rewritten node. When ``keep_dims`` is set, an
+        Unsqueeze is inserted after the node to restore the joined dimension.
+        """
+        node.domain = constants.CONTRIB_OPS_DOMAIN
+        node.type = "StringJoin"
+        axis_node = ctx.get_node_by_output(node.input[1])
+        axis_tensor = axis_node.get_attr_value('value')
+        utils.make_sure(axis_tensor is not None, "ReduceJoin requires a constant axis")
+        utils.make_sure(axis_tensor.dims in [[], [1]], "Only a single axis is supported for ReduceJoin node")
+        # Normalize a scalar or single-element axis to a 1-D int64 tensor.
+        axis = np.array(to_array(axis_tensor), np.int64).reshape((1,))
+        new_axis_node = ctx.make_const(utils.make_name("axis"), axis)
+        separator = node.get_attr_value("separator")
+        if isinstance(separator, bytes):
+            separator = separator.decode()
+        separator_node = ctx.make_const(utils.make_name("separator"), np.array([separator], object))
+        ctx.replace_inputs(node, [node.input[0], separator_node.output[0], new_axis_node.output[0]])
+        keep_dims = node.get_attr_value("keep_dims")
+        if keep_dims:
+            # Re-insert the joined dimension at the reduced axis, not at the end,
+            # so the result keeps TF's keep_dims shape for any axis.
+            unsqueeze_node = GraphBuilder(ctx).make_unsqueeze(
+                {'data': node.output[0], 'axes': [int(axis[0])]},
+                name=node.name + '/Unsqueeze'
+            )
+            ctx.insert_node_on_output(ctx.get_node_by_output(unsqueeze_node))
+
 @tf_op(["Equal", "NotEqual"], domain=constants.CONTRIB_OPS_DOMAIN)
 class StringEqual:
     @classmethod