@@ -2364,32 +2364,45 @@ def version_10(cls, ctx, node, **kwargs):
2364
2364
2365
2365
2366
2366
@tf_op ("Unique" , onnx_op = "Unique" )
2367
+ @tf_op ("UniqueWithCounts" , onnx_op = "Unique" )
2367
2368
class Unique :
2369
+ int_cast = [TensorProto .BOOL , TensorProto .INT32 , TensorProto .INT16 , TensorProto .UINT8 ,
2370
+ TensorProto .UINT16 , TensorProto .UINT32 , TensorProto .UINT64 ]
2371
+ dtype_map = {k : TensorProto .INT64 for k in int_cast }
2372
+ dtype_map [TensorProto .DOUBLE ] = TensorProto .FLOAT
2373
+
2368
2374
@classmethod
2369
2375
def version_11 (cls , ctx , node , ** kwargs ):
2370
2376
# opset 11 supports explicitly
2371
- dtypes = node .output_dtypes
2372
2377
node_name = node .name
2373
2378
node_inputs = node .input
2374
2379
node_outputs = node .output
2380
+ inp_dtype = ctx .get_dtype (node .input [0 ])
2381
+
2375
2382
ctx .remove_node (node_name )
2376
- if dtypes [0 ] in [TensorProto .INT32 , TensorProto .INT16 , TensorProto .UINT8 , TensorProto .UINT16 ]:
2377
- inp_cast = ctx .make_node ("Cast" , [node_inputs [0 ]], attr = {'to' : TensorProto .INT64 }).output [0 ]
2383
+
2384
+ # due to ORT missing implementations we need to cast INT inputs to INT64 and FLOAT inputs to FLOAT32
2385
+ if inp_dtype in cls .dtype_map :
2386
+ inp_cast = ctx .make_node ("Cast" , [node_inputs [0 ]], attr = {'to' : cls .dtype_map [inp_dtype ]}).output [0 ]
2378
2387
node_inputs [0 ] = inp_cast
2379
- new_node = ctx .make_node ("Unique" , node_inputs , name = node_name , output_count = 3 , attr = {'sorted' : 0 })
2388
+
2389
+ new_node = ctx .make_node ("Unique" , node_inputs , name = node_name , attr = {'sorted' : 0 },
2390
+ outputs = [utils .make_name ("y" ), utils .make_name ("idx_first" ),
2391
+ utils .make_name ("idx" ), utils .make_name ("counts" )])
2380
2392
ctx .replace_all_inputs (node_outputs [0 ], new_node .output [0 ])
2381
2393
ctx .replace_all_inputs (node_outputs [1 ], new_node .output [2 ])
2382
- if ctx .get_dtype (new_node .output [0 ]) != dtypes [0 ]:
2383
- ctx .insert_new_node_on_output ("Cast" , new_node .output [0 ], name = utils .make_name (node .name ) + "_cast" ,
2384
- to = dtypes [0 ])
2385
- if len (node_outputs ) > 1 :
2386
- # cast to int64 if needed
2387
- if dtypes [1 ] != onnx_pb .TensorProto .INT64 :
2388
- cast_node = ctx .insert_new_node_on_output ("Cast" , new_node .output [2 ],
2389
- name = utils .make_name (node .name ) + "_cast" ,
2390
- to = dtypes [1 ])
2391
- ctx .set_dtype (cast_node .output [0 ], dtypes [1 ])
2392
- ctx .copy_shape (new_node .output [2 ], cast_node .output [0 ])
2394
+ if len (node_outputs ) == 3 : # we need counts too (UniqueWithCounts)
2395
+ ctx .replace_all_inputs (node_outputs [2 ], new_node .output [3 ])
2396
+ if ctx .get_dtype (new_node .output [0 ]) != inp_dtype :
2397
+ ctx .insert_new_node_on_output ("Cast" , new_node .output [0 ], to = inp_dtype ,
2398
+ name = utils .make_name (node .name ) + "_cast" )
2399
+
2400
+ # cast idx and counts if needed
2401
+ out_dtype = node .get_attr_value ('out_idx' )
2402
+ if out_dtype != TensorProto .INT64 :
2403
+ for i in range (1 , len (node_outputs )):
2404
+ cast_node = ctx .insert_new_node_on_output ("Cast" , new_node .output [i + 1 ], to = out_dtype ,
2405
+ name = utils .make_name (node .name ) + "_cast" )
2393
2406
2394
2407
2395
2408
@tf_op (["Bincount" , "DenseBincount" ])
0 commit comments