Skip to content

Commit dd373a1

Browse files
pagran and fatcat-z authored
Added support for int64 -> string CategoryMapper updated (#2181)
* add int64->string support Signed-off-by: pagran <[email protected]> --------- Signed-off-by: pagran <[email protected]> Co-authored-by: Jay Zhang <[email protected]>
1 parent df57e4b commit dd373a1

File tree

2 files changed

+54
-9
lines changed

2 files changed

+54
-9
lines changed

tests/test_backend.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5440,6 +5440,24 @@ def func(query_holder):
54405440
self._run_test_case(func, [_OUTPUT], {_INPUT: query}, as_session=True)
54415441
os.remove(filnm)
54425442

5443+
@check_opset_min_version(8, "CategoryMapper")
5444+
@skip_tfjs("TFJS does not initialize table")
5445+
@skip_onnx_checker("ONNX can't do type inference on CategoryMapper")
5446+
def test_hashtable_lookup_invert(self):
5447+
filnm = "vocab.tmp"
5448+
words = ["apple", "pear", "banana", "cherry", "grape"]
5449+
query = np.array([3], dtype=np.int64)
5450+
with open(filnm, "w") as f:
5451+
for word in words:
5452+
f.write(word + "\n")
5453+
def func(query_holder):
5454+
hash_table = lookup_ops.index_to_string_table_from_file(filnm)
5455+
lookup_results = hash_table.lookup(query_holder)
5456+
ret = tf.identity(lookup_results, name=_TFOUTPUT)
5457+
return ret
5458+
self._run_test_case(func, [_OUTPUT], {_INPUT: query}, as_session=True)
5459+
os.remove(filnm)
5460+
54435461
@check_opset_min_version(8, "CategoryMapper")
54445462
@skip_tfjs("TFJS does not initialize table")
54455463
def test_hashtable_lookup_const(self):
@@ -5458,6 +5476,24 @@ def func():
54585476
self._run_test_case(func, [_OUTPUT], {}, as_session=True)
54595477
os.remove(filnm)
54605478

5479+
@check_opset_min_version(8, "CategoryMapper")
5480+
@skip_tfjs("TFJS does not initialize table")
5481+
def test_hashtable_lookup_invert_const(self):
5482+
filnm = "vocab.tmp"
5483+
words = ["apple", "pear", "banana", "cherry", "grape"]
5484+
query_val = np.array([3, 2], dtype=np.int64).reshape((1, 2, 1))
5485+
with open(filnm, "w", encoding='UTF-8') as f:
5486+
for word in words:
5487+
f.write(word + "\n")
5488+
def func():
5489+
hash_table = lookup_ops.index_to_string_table_from_file(filnm)
5490+
query = tf.constant(query_val)
5491+
lookup_results = hash_table.lookup(query)
5492+
ret = tf.identity(lookup_results, name=_TFOUTPUT)
5493+
return ret
5494+
self._run_test_case(func, [_OUTPUT], {}, as_session=True)
5495+
os.remove(filnm)
5496+
54615497
@skip_tfjs("TFJS does not initialize table")
54625498
def test_hashtable_size(self):
54635499
filnm = "vocab.tmp"

tf2onnx/custom_opsets/onnx_ml.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,18 @@ def version_8(cls, ctx, node, initialized_tables, **kwargs):
4242

4343
dtype = ctx.get_dtype(node.output[0])
4444
in_dtype = ctx.get_dtype(node.input[1])
45-
utils.make_sure(dtype == TensorProto.INT64 and in_dtype == TensorProto.STRING,
46-
"Only lookup tables of type string->int64 are currently supported.")
45+
utils.make_sure((dtype == TensorProto.INT64 and in_dtype == TensorProto.STRING) or
46+
(dtype == TensorProto.STRING and in_dtype == TensorProto.INT64),
47+
f"Only lookup tables of type string<->int64 are currently supported.")
48+
49+
if in_dtype == TensorProto.STRING:
50+
cats_strings, cats_int64s = initialized_tables[shared_name]
51+
default_key = 'default_int64'
52+
else:
53+
cats_int64s, cats_strings = initialized_tables[shared_name]
54+
default_key = 'default_string'
55+
attr = {'cats_int64s': cats_int64s, 'cats_strings': cats_strings, default_key: default_val}
4756

48-
cats_strings, cats_int64s = initialized_tables[shared_name]
4957
shape = ctx.get_shape(node.input[1])
5058

5159
node_name = node.name
@@ -56,18 +64,19 @@ def version_8(cls, ctx, node, initialized_tables, **kwargs):
5664
# Handle explicitly since const folding doesn't work for tables
5765
key_np = node.inputs[1].get_tensor_value(as_list=False)
5866
ctx.remove_node(node.name)
59-
key_to_val = dict(zip(cats_strings, cats_int64s))
60-
def lookup_value(key):
61-
return key_to_val.get(key.encode("UTF-8"), default_val_np)
62-
lookup_result = np.vectorize(lookup_value)(key_np)
67+
if in_dtype == TensorProto.STRING:
68+
key_to_val = dict(zip(cats_strings, cats_int64s))
69+
lookup_result = np.vectorize(lambda key: key_to_val.get(key.encode("UTF-8"), default_val_np))(key_np)
70+
else:
71+
key_to_val = dict(zip(cats_int64s, cats_strings))
72+
lookup_result = np.vectorize(lambda key: key_to_val.get(key, default_val_np))(key_np).astype(object)
6373
onnx_tensor = numpy_helper.from_array(lookup_result, node_name)
6474
ctx.make_node("Const", name=node_name, inputs=[], outputs=node_outputs,
6575
attr={"value": onnx_tensor}, shapes=[lookup_result.shape], dtypes=[dtype])
6676
else:
6777
ctx.remove_node(node.name)
6878
ctx.make_node("CategoryMapper", domain=constants.AI_ONNX_ML_DOMAIN,
69-
name=node_name, inputs=[node_inputs[1]], outputs=node_outputs,
70-
attr={'cats_int64s': cats_int64s, 'cats_strings': cats_strings, 'default_int64': default_val},
79+
name=node_name, inputs=[node_inputs[1]], outputs=node_outputs, attr=attr,
7180
shapes=[shape], dtypes=[dtype])
7281

7382
customer_nodes = ctx.find_output_consumers(table_node.output[0])

0 commit comments

Comments (0)