Skip to content

Commit 6cdb7e3

Browse files
authored
Support unique_with_counts (#2195)
* support UniqueWithCounts * unique_with_counts test --------- Signed-off-by: Salvetti, Francesco <[email protected]>
1 parent 6dda2bb commit 6cdb7e3

File tree

3 files changed

+62
-15
lines changed

3 files changed

+62
-15
lines changed

support_status.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@
267267
| Transpose | 1 ~ 17 |
268268
| TruncateDiv | 1 ~ 17 |
269269
| Unique | 11 ~ 17 |
270+
| UniqueWithCounts | 11 ~ 18 |
270271
| Unpack | 1 ~ 17 |
271272
| UnsortedSegmentMax | 11 ~ 17 |
272273
| UnsortedSegmentMin | 11 ~ 17 |

tests/test_backend.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5047,6 +5047,39 @@ def func(x):
50475047
return y1, y2
50485048
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val})
50495049

5050+
@check_opset_min_version(11, "Unique")
5051+
def test_unique_with_counts(self):
5052+
x_val = np.array([1, 2, 8, 1, 2, 2, 7, 7, 7, 1], dtype=np.float32)
5053+
def func(x):
5054+
x1_, x2_, x3_ = tf.unique_with_counts(x)
5055+
y1 = tf.identity(x1_, name=_TFOUTPUT)
5056+
y2 = tf.identity(x2_, name=_TFOUTPUT1)
5057+
y3 = tf.identity(x3_, name=_TFOUTPUT2)
5058+
return y1, y2, y3
5059+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val})
5060+
5061+
@check_opset_min_version(11, "Unique")
5062+
def test_unique_with_counts_out_int64(self):
5063+
x_val = np.array([2, 3, 3, 6, 4, 1, 1], dtype=np.float32)
5064+
def func(x):
5065+
x1_, x2_, x3_ = tf.unique_with_counts(x, out_idx=tf.int64)
5066+
y1 = tf.identity(x1_, name=_TFOUTPUT)
5067+
y2 = tf.identity(x2_, name=_TFOUTPUT1)
5068+
y3 = tf.identity(x3_, name=_TFOUTPUT2)
5069+
return y1, y2, y3
5070+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val})
5071+
5072+
@check_opset_min_version(11, "Unique")
5073+
def test_unique_with_counts_out_int32(self):
5074+
x_val = np.array([2, 3, 3, 6, 4, 1, 1], dtype=np.float32)
5075+
def func(x):
5076+
x1_, x2_, x3_ = tf.unique_with_counts(x, out_idx=tf.int32)
5077+
y1 = tf.identity(x1_, name=_TFOUTPUT)
5078+
y2 = tf.identity(x2_, name=_TFOUTPUT1)
5079+
y3 = tf.identity(x3_, name=_TFOUTPUT2)
5080+
return y1, y2, y3
5081+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val})
5082+
50505083
@check_opset_min_version(11, "Unique")
50515084
def test_bincount(self):
50525085
x_val = np.array([5, 2, 3, 1, 3, 2, 7, 5, 9, 10], dtype=np.int32)

tf2onnx/onnx_opset/tensor.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2364,32 +2364,45 @@ def version_10(cls, ctx, node, **kwargs):
23642364

23652365

23662366
@tf_op("Unique", onnx_op="Unique")
2367+
@tf_op("UniqueWithCounts", onnx_op="Unique")
23672368
class Unique:
2369+
int_cast = [TensorProto.BOOL, TensorProto.INT32, TensorProto.INT16, TensorProto.UINT8,
2370+
TensorProto.UINT16, TensorProto.UINT32, TensorProto.UINT64]
2371+
dtype_map = {k: TensorProto.INT64 for k in int_cast}
2372+
dtype_map[TensorProto.DOUBLE] = TensorProto.FLOAT
2373+
23682374
@classmethod
23692375
def version_11(cls, ctx, node, **kwargs):
23702376
# opset 11 supports explicitly
2371-
dtypes = node.output_dtypes
23722377
node_name = node.name
23732378
node_inputs = node.input
23742379
node_outputs = node.output
2380+
inp_dtype = ctx.get_dtype(node.input[0])
2381+
23752382
ctx.remove_node(node_name)
2376-
if dtypes[0] in [TensorProto.INT32, TensorProto.INT16, TensorProto.UINT8, TensorProto.UINT16]:
2377-
inp_cast = ctx.make_node("Cast", [node_inputs[0]], attr={'to': TensorProto.INT64}).output[0]
2383+
2384+
# due to ORT missing implementations we need to cast INT inputs to INT64 and FLOAT inputs to FLOAT32
2385+
if inp_dtype in cls.dtype_map:
2386+
inp_cast = ctx.make_node("Cast", [node_inputs[0]], attr={'to': cls.dtype_map[inp_dtype]}).output[0]
23782387
node_inputs[0] = inp_cast
2379-
new_node = ctx.make_node("Unique", node_inputs, name=node_name, output_count=3, attr={'sorted': 0})
2388+
2389+
new_node = ctx.make_node("Unique", node_inputs, name=node_name, attr={'sorted': 0},
2390+
outputs=[utils.make_name("y"), utils.make_name("idx_first"),
2391+
utils.make_name("idx"), utils.make_name("counts")])
23802392
ctx.replace_all_inputs(node_outputs[0], new_node.output[0])
23812393
ctx.replace_all_inputs(node_outputs[1], new_node.output[2])
2382-
if ctx.get_dtype(new_node.output[0]) != dtypes[0]:
2383-
ctx.insert_new_node_on_output("Cast", new_node.output[0], name=utils.make_name(node.name) + "_cast",
2384-
to=dtypes[0])
2385-
if len(node_outputs) > 1:
2386-
# cast to int64 if needed
2387-
if dtypes[1] != onnx_pb.TensorProto.INT64:
2388-
cast_node = ctx.insert_new_node_on_output("Cast", new_node.output[2],
2389-
name=utils.make_name(node.name) + "_cast",
2390-
to=dtypes[1])
2391-
ctx.set_dtype(cast_node.output[0], dtypes[1])
2392-
ctx.copy_shape(new_node.output[2], cast_node.output[0])
2394+
if len(node_outputs) == 3: # we need counts too (UniqueWithCounts)
2395+
ctx.replace_all_inputs(node_outputs[2], new_node.output[3])
2396+
if ctx.get_dtype(new_node.output[0]) != inp_dtype:
2397+
ctx.insert_new_node_on_output("Cast", new_node.output[0], to=inp_dtype,
2398+
name=utils.make_name(node.name) + "_cast")
2399+
2400+
# cast idx and counts if needed
2401+
out_dtype = node.get_attr_value('out_idx')
2402+
if out_dtype != TensorProto.INT64:
2403+
for i in range(1, len(node_outputs)):
2404+
cast_node = ctx.insert_new_node_on_output("Cast", new_node.output[i+1], to=out_dtype,
2405+
name=utils.make_name(node.name) + "_cast")
23932406

23942407

23952408
@tf_op(["Bincount", "DenseBincount"])

0 commit comments

Comments
 (0)