@@ -42,10 +42,18 @@ def version_8(cls, ctx, node, initialized_tables, **kwargs):
42
42
43
43
dtype = ctx .get_dtype (node .output [0 ])
44
44
in_dtype = ctx .get_dtype (node .input [1 ])
45
- utils .make_sure (dtype == TensorProto .INT64 and in_dtype == TensorProto .STRING ,
46
- "Only lookup tables of type string->int64 are currently supported." )
45
+ utils .make_sure ((dtype == TensorProto .INT64 and in_dtype == TensorProto .STRING ) or
46
+ (dtype == TensorProto .STRING and in_dtype == TensorProto .INT64 ),
47
+ f"Only lookup tables of type string<->int64 are currently supported." )
48
+
49
+ if in_dtype == TensorProto .STRING :
50
+ cats_strings , cats_int64s = initialized_tables [shared_name ]
51
+ default_key = 'default_int64'
52
+ else :
53
+ cats_int64s , cats_strings = initialized_tables [shared_name ]
54
+ default_key = 'default_string'
55
+ attr = {'cats_int64s' : cats_int64s , 'cats_strings' : cats_strings , default_key : default_val }
47
56
48
- cats_strings , cats_int64s = initialized_tables [shared_name ]
49
57
shape = ctx .get_shape (node .input [1 ])
50
58
51
59
node_name = node .name
@@ -56,18 +64,19 @@ def version_8(cls, ctx, node, initialized_tables, **kwargs):
56
64
# Handle explicitly since const folding doesn't work for tables
57
65
key_np = node .inputs [1 ].get_tensor_value (as_list = False )
58
66
ctx .remove_node (node .name )
59
- key_to_val = dict (zip (cats_strings , cats_int64s ))
60
- def lookup_value (key ):
61
- return key_to_val .get (key .encode ("UTF-8" ), default_val_np )
62
- lookup_result = np .vectorize (lookup_value )(key_np )
67
+ if in_dtype == TensorProto .STRING :
68
+ key_to_val = dict (zip (cats_strings , cats_int64s ))
69
+ lookup_result = np .vectorize (lambda key : key_to_val .get (key .encode ("UTF-8" ), default_val_np ))(key_np )
70
+ else :
71
+ key_to_val = dict (zip (cats_int64s , cats_strings ))
72
+ lookup_result = np .vectorize (lambda key : key_to_val .get (key , default_val_np ))(key_np ).astype (object )
63
73
onnx_tensor = numpy_helper .from_array (lookup_result , node_name )
64
74
ctx .make_node ("Const" , name = node_name , inputs = [], outputs = node_outputs ,
65
75
attr = {"value" : onnx_tensor }, shapes = [lookup_result .shape ], dtypes = [dtype ])
66
76
else :
67
77
ctx .remove_node (node .name )
68
78
ctx .make_node ("CategoryMapper" , domain = constants .AI_ONNX_ML_DOMAIN ,
69
- name = node_name , inputs = [node_inputs [1 ]], outputs = node_outputs ,
70
- attr = {'cats_int64s' : cats_int64s , 'cats_strings' : cats_strings , 'default_int64' : default_val },
79
+ name = node_name , inputs = [node_inputs [1 ]], outputs = node_outputs , attr = attr ,
71
80
shapes = [shape ], dtypes = [dtype ])
72
81
73
82
customer_nodes = ctx .find_output_consumers (table_node .output [0 ])
0 commit comments