4
4
"""Methods to load tensorflow graph from graphdef, checkpoint or saved_model."""
5
5
6
6
import logging
7
+ import uuid
7
8
from distutils .version import LooseVersion
8
9
9
10
import tensorflow as tf
15
16
from tensorflow .python .util import compat
16
17
17
18
from tf2onnx import utils
18
- from tf2onnx .tf_utils import get_tf_version , tflist_to_onnx , get_hash_table_info , replace_placeholders_with_tables
19
+ from tf2onnx .tf_utils import (get_tf_version , tflist_to_onnx , get_hash_table_info , replace_placeholders_with_tables ,
20
+ HashTableInfo )
19
21
20
22
logger = logging .getLogger (__name__ )
21
23
@@ -184,7 +186,7 @@ def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
184
186
err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."
185
187
186
188
# Avoid errors due to bug in TF freezing
187
- removed_resource_to_placeholder , graph_captures_copy , func_captures_copy = \
189
+ removed_resource_to_placeholder , placeholder_to_resource , graph_captures_copy , func_captures_copy = \
188
190
_remove_non_variable_resources_from_captures (concrete_func )
189
191
190
192
try :
@@ -197,16 +199,28 @@ def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
197
199
# We might be returning the concrete_func so let's put it back in working order
198
200
_restore_captured_resources (concrete_func , graph_captures_copy , func_captures_copy )
199
201
200
- table_names , key_dtypes , value_dtypes = get_hash_table_info (frozen_graph )
202
+ table_info = get_hash_table_info (frozen_graph )
201
203
placeholder_to_table_info = {}
202
- _get_hash_table_info_from_trackable (trackable , table_names , key_dtypes , value_dtypes ,
204
+ _get_hash_table_info_from_trackable (trackable , table_info ,
203
205
removed_resource_to_placeholder , placeholder_to_table_info )
204
206
205
207
initialized_tables = {}
206
- for n , k_dtype , val_dtype in zip (table_names , key_dtypes , value_dtypes ):
207
- h = lookup_ops .hash_table_v2 (k_dtype , val_dtype , shared_name = n )
208
+ for info in table_info :
209
+ if info .shared_name is not None :
210
+ h = lookup_ops .hash_table_v2 (info .key_dtype , info .val_dtype , shared_name = info .shared_name )
211
+ n = info .shared_name
212
+ elif info .resource_input in placeholder_to_resource and info .resource_input not in placeholder_to_table_info :
213
+ # We found a lookup op with no corresponding HashTable op, but we can associate the placeholder input
214
+ # from the op with the resource handle from graph captures and make up a shared_name
215
+ h = placeholder_to_resource [info .resource_input ]
216
+ n = str (uuid .uuid4 ()).encode ()
217
+ info .shared_name = n
218
+ placeholder_to_table_info [info .resource_input ] = info
219
+ else :
220
+ # Found a lookup op but the corresponding HashTable op has already been found and processed.
221
+ continue
208
222
try :
209
- k , v = lookup_ops .lookup_table_export_v2 (h , k_dtype , val_dtype )
223
+ k , v = lookup_ops .lookup_table_export_v2 (h , info . key_dtype , info . val_dtype )
210
224
initialized_tables [n ] = (k .numpy (), v .numpy ())
211
225
except Exception : # pylint: disable=broad-except
212
226
logger .warning ("Could not initialize table with shared_name = %r" , n )
@@ -260,14 +274,14 @@ def freeze_session(sess, input_names=None, output_names=None, get_tables=False):
260
274
for node in graph_def .node :
261
275
node .device = ""
262
276
graph_def = convert_variables_to_constants (sess , graph_def , output_node_names )
263
- table_names , key_dtypes , value_dtypes = get_hash_table_info (graph_def )
277
+ table_info = get_hash_table_info (graph_def )
264
278
if get_tables :
265
279
initialized_tables = {}
266
280
tf .tables_initializer ().run (session = sess )
267
- for n , k_dtype , val_dtype in zip ( table_names , key_dtypes , value_dtypes ) :
268
- h = lookup_ops .hash_table_v2 (k_dtype , val_dtype , shared_name = n )
281
+ for info in table_info :
282
+ h = lookup_ops .hash_table_v2 (info . key_dtype , info . val_dtype , shared_name = info . shared_name )
269
283
try :
270
- k , v = lookup_ops .lookup_table_export_v2 (h , k_dtype , val_dtype )
284
+ k , v = lookup_ops .lookup_table_export_v2 (h , info . key_dtype , info . val_dtype )
271
285
k , v = sess .run ([k , v ])
272
286
initialized_tables [n ] = (k , v )
273
287
except Exception : # pylint: disable=broad-except
@@ -403,7 +417,7 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
403
417
return frozen_graph , input_names , output_names , initialized_tables , tensors_to_rename
404
418
405
419
406
- def _get_hash_table_info_from_trackable (trackable , table_names , key_dtypes , value_dtypes ,
420
+ def _get_hash_table_info_from_trackable (trackable , table_info ,
407
421
removed_resource_to_placeholder , placeholder_to_table_info ):
408
422
# pylint: disable=protected-access
409
423
stack = [trackable ]
@@ -420,26 +434,22 @@ def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, valu
420
434
continue
421
435
for t in r .__dict__ .values () if hasattr (r , '__dict__' ) else []:
422
436
if isinstance (t , TfStaticHashTableType ) and hasattr (t , '_shared_name' ):
423
- table_names . append (t ._shared_name .encode ())
424
- key_dtypes . append ( t . key_dtype .as_datatype_enum )
425
- value_dtypes .append (t . value_dtype . as_datatype_enum )
437
+ info = HashTableInfo (t ._shared_name .encode (), t . key_dtype . as_datatype_enum ,
438
+ t . value_dtype .as_datatype_enum )
439
+ table_info .append (info )
426
440
table_handle = id (t .resource_handle )
427
441
if table_handle in removed_resource_to_placeholder :
428
- table_info = (table_names [- 1 ], key_dtypes [- 1 ], value_dtypes [- 1 ])
429
- placeholder_to_table_info [removed_resource_to_placeholder [table_handle ]] = table_info
442
+ placeholder_to_table_info [removed_resource_to_placeholder [table_handle ]] = info
430
443
if isinstance (r , TfRestoredResourceType ) and hasattr (r , '_create_resource' ):
431
444
try :
432
445
table_handle = id (r .resource_handle )
433
446
except Exception : # pylint: disable=broad-except
434
447
continue
435
448
initializer = r ._create_resource .concrete_functions [0 ].function_def
436
- new_names , new_k_dtypes , new_v_dtypes = get_hash_table_info (initializer .node_def )
437
- table_names .extend (new_names )
438
- key_dtypes .extend (new_k_dtypes )
439
- value_dtypes .extend (new_v_dtypes )
440
- if table_handle in removed_resource_to_placeholder and len (new_names ) == 1 :
441
- table_info = (new_names [0 ], new_k_dtypes [0 ], new_v_dtypes [0 ])
442
- placeholder_to_table_info [removed_resource_to_placeholder [table_handle ]] = table_info
449
+ new_table_info = get_hash_table_info (initializer .node_def )
450
+ table_info .extend (new_table_info )
451
+ if table_handle in removed_resource_to_placeholder and len (new_table_info ) == 1 :
452
+ placeholder_to_table_info [removed_resource_to_placeholder [table_handle ]] = new_table_info [0 ]
443
453
444
454
445
455
def _remove_non_variable_resources_from_captures (concrete_func ):
@@ -449,6 +459,7 @@ def _remove_non_variable_resources_from_captures(concrete_func):
449
459
"""
450
460
# pylint: disable=protected-access
451
461
resource_id_to_placeholder = {}
462
+ placeholder_to_resource = {}
452
463
graph_captures_copy = None
453
464
func_captures_copy = None
454
465
if hasattr (concrete_func .graph , '_captures' ) and hasattr (concrete_func , '_captured_inputs' ):
@@ -459,6 +470,7 @@ def _remove_non_variable_resources_from_captures(concrete_func):
459
470
val_tensor , name_tensor = v
460
471
if val_tensor .dtype == tf .resource and id (val_tensor ) not in variable_handles :
461
472
resource_id_to_placeholder [id (val_tensor )] = name_tensor .name .split (':' )[0 ]
473
+ placeholder_to_resource [name_tensor .name .split (':' )[0 ]] = val_tensor
462
474
del concrete_func .graph ._captures [k ]
463
475
for i in reversed (range (len (concrete_func ._captured_inputs ))):
464
476
if concrete_func ._captured_inputs [i ] is val_tensor :
@@ -472,7 +484,7 @@ def _remove_non_variable_resources_from_captures(concrete_func):
472
484
else :
473
485
logger .warning (
474
486
"Could not search for non-variable resources. Concrete function internal representation may have changed." )
475
- return resource_id_to_placeholder , graph_captures_copy , func_captures_copy
487
+ return resource_id_to_placeholder , placeholder_to_resource , graph_captures_copy , func_captures_copy
476
488
477
489
478
490
def _restore_captured_resources (concrete_func , graph_captures_copy , func_captures_copy ):
0 commit comments