Skip to content

Commit 535f74c

Browse files
authored
Adding tf.strings.reduce_join mapping (#2091)
* Add support to tf.strings.reduce_join * Support for ReduceJoin op * unit test for reduce_join Signed-off-by: Salvetti, Francesco <[email protected]>
1 parent ce29107 commit 535f74c

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

tests/test_string_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ def func(text1, text2, text3):
5252
return tf.identity(x_, name=_TFOUTPUT)
5353
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1, _INPUT1: text_val2, _INPUT2: text_val3})
5454

55+
@requires_custom_ops("ReduceJoin")
56+
def test_reduce_join(self):
57+
text_val = np.array([["a", "Test 1 2 3"], ["b", "test test"], ["c", "Hi there Test"]], dtype=np.str)
58+
def func(text):
59+
x_ = tf.strings.reduce_join(text, axis=1, separator="±")
60+
return tf.identity(x_, name=_TFOUTPUT)
61+
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val})
62+
5563
@requires_custom_ops("StringSplit")
5664
@check_tf_min_version("2.0", "result is sparse not ragged in tf1")
5765
def test_string_split(self):

tf2onnx/custom_opsets/string_ops.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import logging
77
import numpy as np
8+
from onnx.numpy_helper import to_array
89
from onnx.onnx_pb import TensorProto
910
from onnx.helper import make_attribute
1011
from tf2onnx import constants, handler
@@ -86,6 +87,30 @@ def version_1(cls, ctx, node, **kwargs):
8687
stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
8788
ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]])
8889

90+
@tf_op("ReduceJoin", domain=constants.CONTRIB_OPS_DOMAIN)
91+
class ReduceJoin:
92+
@classmethod
93+
def version_1(cls, ctx, node, **kwargs):
94+
node.domain = constants.CONTRIB_OPS_DOMAIN
95+
node.type = "StringJoin"
96+
axis_node = ctx.get_node_by_output(node.input[1])
97+
axis = axis_node.get_attr_value('value')
98+
utils.make_sure(axis.dims in [[], [1]], "Only a single axis is supported for ReduceJoin node")
99+
axis = to_array(axis)
100+
new_axis_node = ctx.make_const(utils.make_name("axis"), np.array(axis, np.int64).reshape((1)))
101+
separator = node.get_attr_value("separator")
102+
if isinstance(separator, bytes):
103+
separator = separator.decode()
104+
separator_node = ctx.make_const(utils.make_name("separator"), np.array([separator], object))
105+
ctx.replace_inputs(node, [node.input[0], separator_node.output[0], new_axis_node.output[0]])
106+
keep_dims = node.get_attr_value("keep_dims")
107+
if keep_dims:
108+
unsqueeze_node = GraphBuilder(ctx).make_unsqueeze(
109+
{'data': node.output[0], 'axes': [-1]},
110+
name=node.name + '/Unsqueeze'
111+
)
112+
ctx.insert_node_on_output(ctx.get_node_by_output(unsqueeze_node))
113+
89114
@tf_op(["Equal", "NotEqual"], domain=constants.CONTRIB_OPS_DOMAIN)
90115
class StringEqual:
91116
@classmethod

0 commit comments

Comments
 (0)