@@ -1344,6 +1344,19 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
1344
1344
off_on_value = ctx .make_node ("Concat" , [off_value , on_value ], attr = {"axis" : 0 }).output [0 ]
1345
1345
1346
1346
indices = node .input [0 ]
1347
+ indices_rank = ctx .get_rank (indices )
1348
+
1349
+ # Add a special support for 0-rank indices, to do so we have to expand the dimension to 1
1350
+ # before the one hot encoding and remove it after.
1351
+ if indices_rank == 0 :
1352
+ dims = ctx .make_const (name = utils .make_name ('dims' ), np_val = np .array ([1 ], dtype = np .int64 ))
1353
+ indices = ctx .make_node ("Expand" , [indices , dims .name ]).output [0 ]
1354
+
1355
+ # Axis 0 is supported by TensorFlow for the one-hot encoding of a 0-rank tensor. It should behave
1356
+ # as if axis has been set to -1 so we artificially set it as is here.
1357
+ if node .get_attr ('axis' ).i == 0 :
1358
+ node .set_attr ('axis' , - 1 )
1359
+
1347
1360
if ctx .is_target (constants .TARGET_RS6 ) \
1348
1361
and ctx .get_dtype (indices ) != onnx_pb .TensorProto .INT64 :
1349
1362
indices = ctx .make_node ("Cast" , [indices ], attr = {"to" : onnx_pb .TensorProto .INT64 }).output [0 ]
@@ -1367,6 +1380,26 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
1367
1380
ctx .set_dtype (new_node .output [0 ], output_dtype )
1368
1381
ctx .set_shape (new_node .output [0 ], ctx .get_shape (node .output [0 ]))
1369
1382
1383
+ # Remove the dimension artificially added in order to support 0-rank indices
1384
+ if indices_rank == 0 :
1385
+ nodes = [node ]
1386
+ name = utils .make_name (node .name )
1387
+ shape = ctx .get_shape (node .output [0 ])
1388
+ dtype = ctx .get_dtype (node .output [0 ])
1389
+ squeeze_node = GraphBuilder (ctx ).make_squeeze (
1390
+ {
1391
+ "axes" : [0 ],
1392
+ 'data' : node .output [0 ]
1393
+ },
1394
+ name = name ,
1395
+ dtypes = [dtype ],
1396
+ shapes = [shape ],
1397
+ return_node = True )
1398
+ ctx .insert_node_on_output (squeeze_node )
1399
+
1400
+ nodes .append (squeeze_node )
1401
+ ctx .update_node_shape_dtype (node , override = True )
1402
+
1370
1403
@classmethod
1371
1404
def version_9 (cls , ctx , node , ** kwargs ):
1372
1405
cls .any_version_after9 (9 , ctx , node , ** kwargs )
0 commit comments