|
5 | 5 | import json
|
6 | 6 | import logging
|
7 | 7 | import numpy as np
|
| 8 | +from onnx.numpy_helper import to_array |
8 | 9 | from onnx.onnx_pb import TensorProto
|
9 | 10 | from onnx.helper import make_attribute
|
10 | 11 | from tf2onnx import constants, handler
|
@@ -86,6 +87,30 @@ def version_1(cls, ctx, node, **kwargs):
|
86 | 87 | stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
|
87 | 88 | ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]])
|
88 | 89 |
|
| 90 | +@tf_op("ReduceJoin", domain=constants.CONTRIB_OPS_DOMAIN) |
| 91 | +class ReduceJoin: |
| 92 | + @classmethod |
| 93 | + def version_1(cls, ctx, node, **kwargs): |
| 94 | + node.domain = constants.CONTRIB_OPS_DOMAIN |
| 95 | + node.type = "StringJoin" |
| 96 | + axis_node = ctx.get_node_by_output(node.input[1]) |
| 97 | + axis = axis_node.get_attr_value('value') |
| 98 | + utils.make_sure(axis.dims in [[], [1]], "Only a single axis is supported for ReduceJoin node") |
| 99 | + axis = to_array(axis) |
| 100 | + new_axis_node = ctx.make_const(utils.make_name("axis"), np.array(axis, np.int64).reshape((1))) |
| 101 | + separator = node.get_attr_value("separator") |
| 102 | + if isinstance(separator, bytes): |
| 103 | + separator = separator.decode() |
| 104 | + separator_node = ctx.make_const(utils.make_name("separator"), np.array([separator], object)) |
| 105 | + ctx.replace_inputs(node, [node.input[0], separator_node.output[0], new_axis_node.output[0]]) |
| 106 | + keep_dims = node.get_attr_value("keep_dims") |
| 107 | + if keep_dims: |
| 108 | + unsqueeze_node = GraphBuilder(ctx).make_unsqueeze( |
| 109 | + {'data': node.output[0], 'axes': [-1]}, |
| 110 | + name=node.name + '/Unsqueeze' |
| 111 | + ) |
| 112 | + ctx.insert_node_on_output(ctx.get_node_by_output(unsqueeze_node)) |
| 113 | + |
89 | 114 | @tf_op(["Equal", "NotEqual"], domain=constants.CONTRIB_OPS_DOMAIN)
|
90 | 115 | class StringEqual:
|
91 | 116 | @classmethod
|
|
0 commit comments