Skip to content

Commit 4240eb6

Browse files
Improve lookup table loading for from_keras conversion
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 04d2488 commit 4240eb6

File tree

3 files changed

+57
-40
lines changed

3 files changed

+57
-40
lines changed

tests/backend_test_base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,11 @@ def freeze_and_run_tf(self, func, feed_dict, outputs, as_session, premade_placeh
183183
graph_def = freeze_session(sess,
184184
input_names=list(feed_dict.keys()),
185185
output_names=outputs)
186-
table_names, key_dtypes, value_dtypes = get_hash_table_info(graph_def)
186+
table_info = get_hash_table_info(graph_def)
187187
initialized_tables = {}
188-
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
189-
h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
190-
k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
188+
for info in table_info:
189+
h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=info.shared_name)
190+
k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
191191
initialized_tables[n] = (sess.run(k), sess.run(v))
192192

193193
tf_reset_default_graph()

tf2onnx/tf_loader.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""Methods to load tensorflow graph from graphdef, checkpoint or saved_model."""
55

66
import logging
7+
import uuid
78
from distutils.version import LooseVersion
89

910
import tensorflow as tf
@@ -15,7 +16,8 @@
1516
from tensorflow.python.util import compat
1617

1718
from tf2onnx import utils
18-
from tf2onnx.tf_utils import get_tf_version, tflist_to_onnx, get_hash_table_info, replace_placeholders_with_tables
19+
from tf2onnx.tf_utils import (get_tf_version, tflist_to_onnx, get_hash_table_info, replace_placeholders_with_tables,
20+
HashTableInfo)
1921

2022
logger = logging.getLogger(__name__)
2123

@@ -184,7 +186,7 @@ def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
184186
err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."
185187

186188
# Avoid errors due to bug in TF freezing
187-
removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
189+
removed_resource_to_placeholder, placeholder_to_resource, graph_captures_copy, func_captures_copy = \
188190
_remove_non_variable_resources_from_captures(concrete_func)
189191

190192
try:
@@ -197,16 +199,28 @@ def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
197199
# We might be returning the concrete_func so let's put it back in working order
198200
_restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)
199201

200-
table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
202+
table_info = get_hash_table_info(frozen_graph)
201203
placeholder_to_table_info = {}
202-
_get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
204+
_get_hash_table_info_from_trackable(trackable, table_info,
203205
removed_resource_to_placeholder, placeholder_to_table_info)
204206

205207
initialized_tables = {}
206-
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
207-
h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
208+
for info in table_info:
209+
if info.shared_name is not None:
210+
h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=info.shared_name)
211+
n = info.shared_name
212+
elif info.resource_input in placeholder_to_resource and info.resource_input not in placeholder_to_table_info:
213+
# We found a lookup op with no corresponding HashTable op, but we can associate the placeholder input
214+
# from the op with the resource handle from graph captures and make up a shared_name
215+
h = placeholder_to_resource[info.resource_input]
216+
n = str(uuid.uuid4()).encode()
217+
info.shared_name = n
218+
placeholder_to_table_info[info.resource_input] = info
219+
else:
220+
# Found a lookup op but the corresponding HashTable op has already been found and processed.
221+
continue
208222
try:
209-
k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
223+
k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
210224
initialized_tables[n] = (k.numpy(), v.numpy())
211225
except Exception: # pylint: disable=broad-except
212226
logger.warning("Could not initialize table with shared_name = %r", n)
@@ -260,14 +274,14 @@ def freeze_session(sess, input_names=None, output_names=None, get_tables=False):
260274
for node in graph_def.node:
261275
node.device = ""
262276
graph_def = convert_variables_to_constants(sess, graph_def, output_node_names)
263-
table_names, key_dtypes, value_dtypes = get_hash_table_info(graph_def)
277+
table_info = get_hash_table_info(graph_def)
264278
if get_tables:
265279
initialized_tables = {}
266280
tf.tables_initializer().run(session=sess)
267-
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
268-
h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
281+
for info in table_info:
282+
h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=info.shared_name)
269283
try:
270-
k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
284+
k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
271285
k, v = sess.run([k, v])
272286
initialized_tables[n] = (k, v)
273287
except Exception: # pylint: disable=broad-except
@@ -403,7 +417,7 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
403417
return frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename
404418

405419

406-
def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
420+
def _get_hash_table_info_from_trackable(trackable, table_info,
407421
removed_resource_to_placeholder, placeholder_to_table_info):
408422
# pylint: disable=protected-access
409423
stack = [trackable]
@@ -420,26 +434,22 @@ def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, valu
420434
continue
421435
for t in r.__dict__.values() if hasattr(r, '__dict__') else []:
422436
if isinstance(t, TfStaticHashTableType) and hasattr(t, '_shared_name'):
423-
table_names.append(t._shared_name.encode())
424-
key_dtypes.append(t.key_dtype.as_datatype_enum)
425-
value_dtypes.append(t.value_dtype.as_datatype_enum)
437+
info = HashTableInfo(t._shared_name.encode(), t.key_dtype.as_datatype_enum,
438+
t.value_dtype.as_datatype_enum)
439+
table_info.append(info)
426440
table_handle = id(t.resource_handle)
427441
if table_handle in removed_resource_to_placeholder:
428-
table_info = (table_names[-1], key_dtypes[-1], value_dtypes[-1])
429-
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
442+
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = info
430443
if isinstance(r, TfRestoredResourceType) and hasattr(r, '_create_resource'):
431444
try:
432445
table_handle = id(r.resource_handle)
433446
except Exception: # pylint: disable=broad-except
434447
continue
435448
initializer = r._create_resource.concrete_functions[0].function_def
436-
new_names, new_k_dtypes, new_v_dtypes = get_hash_table_info(initializer.node_def)
437-
table_names.extend(new_names)
438-
key_dtypes.extend(new_k_dtypes)
439-
value_dtypes.extend(new_v_dtypes)
440-
if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
441-
table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
442-
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
449+
new_table_info = get_hash_table_info(initializer.node_def)
450+
table_info.extend(new_table_info)
451+
if table_handle in removed_resource_to_placeholder and len(new_table_info) == 1:
452+
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = new_table_info[0]
443453

444454

445455
def _remove_non_variable_resources_from_captures(concrete_func):
@@ -449,6 +459,7 @@ def _remove_non_variable_resources_from_captures(concrete_func):
449459
"""
450460
# pylint: disable=protected-access
451461
resource_id_to_placeholder = {}
462+
placeholder_to_resource = {}
452463
graph_captures_copy = None
453464
func_captures_copy = None
454465
if hasattr(concrete_func.graph, '_captures') and hasattr(concrete_func, '_captured_inputs'):
@@ -459,6 +470,7 @@ def _remove_non_variable_resources_from_captures(concrete_func):
459470
val_tensor, name_tensor = v
460471
if val_tensor.dtype == tf.resource and id(val_tensor) not in variable_handles:
461472
resource_id_to_placeholder[id(val_tensor)] = name_tensor.name.split(':')[0]
473+
placeholder_to_resource[name_tensor.name.split(':')[0]] = val_tensor
462474
del concrete_func.graph._captures[k]
463475
for i in reversed(range(len(concrete_func._captured_inputs))):
464476
if concrete_func._captured_inputs[i] is val_tensor:
@@ -472,7 +484,7 @@ def _remove_non_variable_resources_from_captures(concrete_func):
472484
else:
473485
logger.warning(
474486
"Could not search for non-variable resources. Concrete function internal representation may have changed.")
475-
return resource_id_to_placeholder, graph_captures_copy, func_captures_copy
487+
return resource_id_to_placeholder, placeholder_to_resource, graph_captures_copy, func_captures_copy
476488

477489

478490
def _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy):

tf2onnx/tf_utils.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,13 @@ def is_huge_shape(x):
278278
logger.info("Computed %d values for constant folding", len(outputs_to_values))
279279
return outputs_to_values, outputs_to_dtypes
280280

281+
class HashTableInfo:
    """Record describing one hash table, or one lookup op reading from a table, found in a TF graph.

    Attributes:
        shared_name: the table's shared_name (bytes as read from the node attr). It is None for
            entries built from a LookupTableFindV2 op, whose table name is not known up front;
            callers may assign it later.
        key_dtype: datatype enum of the table keys.
        val_dtype: datatype enum of the table values.
        resource_input: name of the resource input feeding the lookup op (n.input[0]), or None
            when the entry was built from a HashTableV2/MutableHashTableV2 node.
    """

    def __init__(self, shared_name, key_dtype, val_dtype, resource_input=None):
        self.shared_name = shared_name
        self.key_dtype = key_dtype
        self.val_dtype = val_dtype
        self.resource_input = resource_input

    def __repr__(self):
        # Readable representation so entries are identifiable in logs and debuggers.
        return "%s(shared_name=%r, key_dtype=%r, val_dtype=%r, resource_input=%r)" % (
            type(self).__name__, self.shared_name, self.key_dtype, self.val_dtype, self.resource_input)
287+
281288
def get_hash_table_info(nodes_or_graph_def):
282289
"""
283290
Return a list of HashTableInfo objects (shared_name, key_dtype, val_dtype, resource_input) for all hash tables declared, and all LookupTableFindV2 ops found, in the graph_def
@@ -287,18 +294,16 @@ def get_hash_table_info(nodes_or_graph_def):
287294
nodes = nodes_or_graph_def.node
288295
else:
289296
nodes = nodes_or_graph_def
290-
names = []
291-
key_dtypes = []
292-
val_dtypes = []
297+
info = []
293298
for n in nodes:
299+
if n.op == "LookupTableFindV2":
300+
info.append(HashTableInfo(None, n.attr['Tin'].type, n.attr['Tout'].type, n.input[0]))
294301
if n.op in ["HashTableV2", "MutableHashTableV2"]:
295302
if all(k in n.attr for k in ['shared_name', 'key_dtype', 'value_dtype']):
296303
name = n.attr['shared_name'].s
297304
if name != b'':
298-
names.append(name)
299-
key_dtypes.append(n.attr['key_dtype'].type)
300-
val_dtypes.append(n.attr['value_dtype'].type)
301-
return names, key_dtypes, val_dtypes
305+
info.append(HashTableInfo(name, n.attr['key_dtype'].type, n.attr['value_dtype'].type))
306+
return info
302307

303308
def replace_placeholders_with_tables(graph_def, placeholder_to_table_info):
304309
"""
@@ -307,13 +312,13 @@ def replace_placeholders_with_tables(graph_def, placeholder_to_table_info):
307312
"""
308313
for n in graph_def.node:
309314
if n.op == "Placeholder" and n.name in placeholder_to_table_info:
310-
name, key_dtype, val_dtype = placeholder_to_table_info[n.name]
315+
info = placeholder_to_table_info[n.name]
311316
for a in list(n.attr):
312317
del n.attr[a]
313318
n.op = "HashTableV2"
314-
n.attr['shared_name'].s = name
315-
n.attr['key_dtype'].type = key_dtype
316-
n.attr['value_dtype'].type = val_dtype
319+
n.attr['shared_name'].s = info.shared_name
320+
n.attr['key_dtype'].type = info.key_dtype
321+
n.attr['value_dtype'].type = info.val_dtype
317322

318323
def read_tf_node_def_attrs(node_def, input_dtypes, input_shapes):
319324
"""Given a tf node def, returns a dict of attribute names to values"""

0 commit comments

Comments
 (0)