Commit 982f0cb

Add support for one hot encoding of 0-rank tensor
Signed-off-by: Dagnas <[email protected]>
1 parent: 5dfd36f

File tree

  tests/test_backend.py
  tf2onnx/onnx_opset/tensor.py

2 files changed: +44 −0 lines
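For context: tf.one_hot accepts a 0-rank (scalar) index and returns a rank-1 vector of length depth, which is the TensorFlow behavior this commit teaches the converter to handle. A minimal illustration (TF 2.x eager; the values are examples, not part of the commit):

import tensorflow as tf

x = tf.constant(3)          # 0-rank tensor, shape ()
y = tf.one_hot(x, depth=5)  # rank-1 result, shape (5,)
# y evaluates to [0., 0., 0., 1., 0.]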

tests/test_backend.py
11 additions & 0 deletions

@@ -2421,6 +2421,17 @@ def func(x):
         graph = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
         self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")
 
+    @check_opset_min_version(9, "onehot")
+    def test_onehot_rank0(self):
+        depth = 5
+        for np_dtype in [np.int32, np.int64]:
+            x_val = np.array(3, dtype=np_dtype)
+            for axis in [-1, 0]:
+                def func(x):
+                    x_ = tf.one_hot(x, depth, axis=axis)
+                    return tf.identity(x_, name=_TFOUTPUT)
+                self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
     @skip_caffe2_backend("issue undefined dim 1")
     @check_tf_max_version("1.15", "not supported in tf-2.0")
     def test_flatten0(self):
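The test sweeps axis over [-1, 0] because, for a 0-rank index, TensorFlow treats axis=0 the same as axis=-1: the output has only one axis on which to place the encoding. A sketch of the equivalence the test relies on (illustrative, not part of the commit):

import numpy as np
import tensorflow as tf

x = np.array(3, dtype=np.int32)
a = tf.one_hot(x, 5, axis=-1)  # shape (5,)
b = tf.one_hot(x, 5, axis=0)   # identical: only one output axis exists
# Both evaluate to [0., 0., 0., 1., 0.]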

tf2onnx/onnx_opset/tensor.py
33 additions & 0 deletions

@@ -1344,6 +1344,19 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
         off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]
 
         indices = node.input[0]
+        indices_rank = ctx.get_rank(indices)
+
+        # Special-case 0-rank indices: expand the dimension to 1 before
+        # the one-hot encoding and remove it afterwards.
+        if indices_rank == 0:
+            dims = ctx.make_const(name=utils.make_name('dims'), np_val=np.array([1], dtype=np.int64))
+            indices = ctx.make_node("Expand", [indices, dims.name]).output[0]
+
+            # TensorFlow allows axis 0 for the one-hot encoding of a 0-rank tensor and
+            # treats it as if axis had been set to -1, so mirror that here.
+            if node.get_attr('axis').i == 0:
+                node.set_attr('axis', -1)
+
         if ctx.is_target(constants.TARGET_RS6) \
                 and ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64:
             indices = ctx.make_node("Cast", [indices], attr={"to": onnx_pb.TensorProto.INT64}).output[0]

@@ -1367,6 +1380,26 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
         ctx.set_dtype(new_node.output[0], output_dtype)
         ctx.set_shape(new_node.output[0], ctx.get_shape(node.output[0]))
 
+        # Remove the dimension artificially added above to support 0-rank indices.
+        if indices_rank == 0:
+            nodes = [node]
+            name = utils.make_name(node.name)
+            shape = ctx.get_shape(node.output[0])
+            dtype = ctx.get_dtype(node.output[0])
+            squeeze_node = GraphBuilder(ctx).make_squeeze(
+                {
+                    "axes": [0],
+                    "data": node.output[0]
+                },
+                name=name,
+                dtypes=[dtype],
+                shapes=[shape],
+                return_node=True)
+            ctx.insert_node_on_output(squeeze_node)
+
+            nodes.append(squeeze_node)
+            ctx.update_node_shape_dtype(node, override=True)
+
     @classmethod
     def version_9(cls, ctx, node, **kwargs):
         cls.any_version_after9(9, ctx, node, **kwargs)
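The Expand → OneHot → Squeeze pattern above can be mirrored in plain NumPy, which may make the shape bookkeeping easier to follow: the scalar is expanded to shape (1,), one-hot encoded to shape (1, depth), and the artificial leading axis is squeezed back out. A minimal sketch (illustrative names; this is not the tf2onnx API):

import numpy as np

def one_hot_rank0(index, depth, dtype=np.float32):
    indices = np.reshape(np.asarray(index), [1])  # Expand: () -> (1,)
    onehot = np.eye(depth, dtype=dtype)[indices]  # OneHot: (1,) -> (1, depth)
    return np.squeeze(onehot, axis=0)             # Squeeze: (1, depth) -> (depth,)

assert np.array_equal(one_hot_rank0(3, 5), [0., 0., 0., 1., 0.])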
