33
33
# pylint: disable=useless-return,broad-except,logging-not-lazy,unused-argument,missing-docstring
34
34
# pylint: disable=unused-variable
35
35
36
- def fold_constants_using_tf (g , outputs_to_values , outputs_to_dtypes ):
36
+ def fold_constants_using_tf (g , outputs_to_values ):
37
37
ops = list (g .get_nodes ())
38
38
# pylint: disable=too-many-nested-blocks
39
39
keep_looking = True
@@ -409,14 +409,13 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
409
409
del verbose
410
410
411
411
opset = utils .find_opset (opset )
412
- if not is_subgraph :
413
- logger .info ("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s" ,
414
- get_tf_version (), utils .get_onnx_version (), tf2onnx .__version__ , tf2onnx .version .git_version [:6 ])
415
- logger .info ("Using opset <onnx, %s>" , opset )
416
- if opset > schemas .get_max_supported_opset_version ():
417
- logger .warning ("Currently installed onnx package %s is too low to support opset %s, "
418
- "please upgrade onnx package to avoid potential conversion issue." ,
419
- utils .get_onnx_version (), opset )
412
+ logger .info ("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s" ,
413
+ get_tf_version (), utils .get_onnx_version (), tf2onnx .__version__ , tf2onnx .version .git_version [:6 ])
414
+ logger .info ("Using opset <onnx, %s>" , opset )
415
+ if opset > schemas .get_max_supported_opset_version ():
416
+ logger .warning ("Currently installed onnx package %s is too low to support opset %s, "
417
+ "please upgrade onnx package to avoid potential conversion issue." ,
418
+ utils .get_onnx_version (), opset )
420
419
421
420
if shape_override is None :
422
421
shape_override = {}
@@ -440,34 +439,17 @@ def check_io(input_names, output_names, output_shapes):
440
439
non_exists )
441
440
raise ValueError ("Inputs/Outputs Not Found" )
442
441
443
- def rename_tensors_in_dict (d ):
444
- if tensors_to_rename is None :
445
- return d
446
- return {tensors_to_rename .get (k , k ): v for k , v in d .items ()}
447
-
448
- def rename_tensors_in_list (tensors ):
449
- if tensors_to_rename is None or tensors is None :
450
- return tensors
451
- return [tensors_to_rename .get (t , t ) for t in tensors ]
452
-
453
- def rename_tensors_in_nodes (onnx_nodes ):
454
- if tensors_to_rename is None :
455
- return
456
- for n in onnx_nodes :
457
- n .input [:] = rename_tensors_in_list (n .input )
458
- n .output [:] = rename_tensors_in_list (n .output )
459
-
460
442
if tflite_path is not None :
461
443
tflite_graphs , opcodes , model , tensor_shapes = read_tflite_model (tflite_path )
462
444
main_g = None
463
- inputs_as_nchw = rename_tensors_in_list ( inputs_as_nchw )
445
+ subgraphs = []
464
446
for i , tfl_graph in enumerate (tflite_graphs ):
465
447
is_main_g = i == len (tflite_graphs ) - 1
466
448
prefix = '' if is_main_g else tfl_graph .Name ().decode () + '_'
467
449
tensor_shapes_from_interpreter = None
468
450
if is_main_g :
469
451
tensor_shapes_from_interpreter = tensor_shapes
470
- onnx_nodes , op_cnt , attr_cnt , output_shapes , dtypes , f_inputs , f_outputs , graph_name = \
452
+ onnx_nodes , _ , _ , output_shapes , dtypes , f_inputs , f_outputs , graph_name = \
471
453
parse_tflite_graph (tfl_graph , opcodes , model , prefix , tensor_shapes_from_interpreter )
472
454
g_inputs = f_inputs
473
455
g_outputs = f_outputs
@@ -478,63 +460,74 @@ def rename_tensors_in_nodes(onnx_nodes):
478
460
g_inputs = input_names
479
461
if output_names is not None :
480
462
g_outputs = output_names
481
- rename_tensors_in_nodes (onnx_nodes )
482
- g_inputs = rename_tensors_in_list (g_inputs )
483
- g_outputs = rename_tensors_in_list (g_outputs )
484
- output_shapes = rename_tensors_in_dict (output_shapes )
485
- dtypes = rename_tensors_in_dict (dtypes )
486
- g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , g_inputs , g_outputs , is_subgraph )
487
- fg = process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
488
- g_outputs , {}, {}, {}, op_cnt , attr_cnt , is_tflite = True , dequantize = dequantize )
489
- fg .graph_name = graph_name
463
+ g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , g_inputs , g_outputs ,
464
+ not is_main_g , graph_name )
490
465
if is_main_g :
491
- main_g = fg
466
+ main_g = g
492
467
else :
493
- set_function (graph_name , fg )
494
-
495
- return main_g
496
-
497
- is_func = is_function (tf_graph )
498
- if not is_func :
499
- tf_graph = infer_shape (tf_graph , shape_override )
468
+ subgraphs .append (g )
500
469
501
- outputs_to_values , outputs_to_dtypes = compute_const_folding_using_tf (tf_graph , const_node_values , output_names )
470
+ g = process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
471
+ target , {}, tensors_to_rename , is_tflite = True , dequantize = dequantize )
472
+ return g
502
473
503
- onnx_nodes , op_cnt , attr_cnt , output_shapes , dtypes , _ = \
504
- tensorflow_to_onnx (tf_graph , shape_override , const_node_values , ignore_default , use_default )
505
474
if not is_subgraph :
506
475
# make tf2onnx internal subgraphs from the tensorflow subgraphs
507
476
ordered_func = resolve_functions (tf_graph )
477
+ subgraphs = []
508
478
for func in ordered_func :
509
479
f_inputs_names = [t .name for t in func .inputs ]
510
480
f_output_names = [t .name for t in func .outputs ]
511
- fg = process_tf_graph (func , continue_on_error , False , target , opset ,
512
- custom_op_handlers , custom_rewriter ,
513
- extra_opset , shape_override , inputs_as_nchw ,
514
- f_inputs_names , f_output_names , is_subgraph = True ,
515
- const_node_values = const_node_values , tensors_to_rename = tensors_to_rename ,
516
- initialized_tables = initialized_tables )
517
- fg .graph_name = func .name
518
- set_function (func .name , fg )
481
+
482
+ outputs_to_values , _ = compute_const_folding_using_tf (func , const_node_values , output_names )
483
+
484
+ onnx_nodes , _ , _ , output_shapes , dtypes , _ = \
485
+ tensorflow_to_onnx (func , shape_override , const_node_values , ignore_default , use_default )
486
+
487
+ fg = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , f_inputs_names , f_output_names ,
488
+ is_subgraph = True , graph_name = func .name )
489
+ fold_constants_using_tf (fg , outputs_to_values )
490
+ subgraphs .append (fg )
491
+
492
+ is_func = is_function (tf_graph )
493
+ if not is_func :
494
+ tf_graph = infer_shape (tf_graph , shape_override )
495
+
496
+ outputs_to_values , _ = compute_const_folding_using_tf (tf_graph , const_node_values , output_names )
497
+
498
+ onnx_nodes , _ , _ , output_shapes , dtypes , _ = \
499
+ tensorflow_to_onnx (tf_graph , shape_override , const_node_values , ignore_default , use_default )
519
500
520
501
check_io (input_names , output_names , output_shapes )
502
+ main_g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , input_names , output_names ,
503
+ is_subgraph )
504
+ fold_constants_using_tf (main_g , outputs_to_values )
505
+ g = process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter ,
506
+ target , initialized_tables , tensors_to_rename )
507
+ return g
521
508
522
- if not is_subgraph :
523
- rename_tensors_in_nodes (onnx_nodes )
524
- input_names = rename_tensors_in_list (input_names )
525
- output_names = rename_tensors_in_list (output_names )
526
- output_shapes = rename_tensors_in_dict (output_shapes )
527
- dtypes = rename_tensors_in_dict (dtypes )
528
- inputs_as_nchw = rename_tensors_in_list (inputs_as_nchw )
529
- g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , input_names , output_names , is_subgraph )
530
- g = process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
531
- output_names , initialized_tables , outputs_to_values , outputs_to_dtypes , op_cnt , attr_cnt )
509
+
510
+ def process_graphs (main_g , subgraphs , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
511
+ initialized_tables , tensors_to_rename , is_tflite = False , dequantize = False ):
512
+
513
+ if tensors_to_rename is not None :
514
+ main_g .rename_tensors (tensors_to_rename )
515
+ inputs_as_nchw = [tensors_to_rename .get (t , t ) for t in inputs_as_nchw ]
516
+
517
+ for g in subgraphs :
518
+ fg = process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
519
+ initialized_tables , is_tflite , dequantize )
520
+ set_function (fg .graph_name , fg )
521
+ g = process_parsed_graph (main_g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
522
+ initialized_tables , is_tflite ,
523
+ dequantize )
532
524
return g
533
525
534
526
535
527
def process_parsed_graph (g , custom_op_handlers , inputs_as_nchw , continue_on_error , custom_rewriter , target ,
536
- output_names , initialized_tables , outputs_to_values , outputs_to_dtypes , op_cnt , attr_cnt ,
537
- is_tflite = False , dequantize = False ):
528
+ initialized_tables , is_tflite = False , dequantize = False ):
529
+
530
+ op_cnt , attr_cnt = g .dump_node_statistics (include_attrs = True , include_subgraphs = False )
538
531
539
532
if is_tflite :
540
533
tfl_rewriters = []
@@ -587,8 +580,6 @@ def compat_handler(ctx, node, **kwargs):
587
580
if inputs_as_nchw :
588
581
transpose_inputs (g , inputs_as_nchw )
589
582
590
- fold_constants_using_tf (g , outputs_to_values , outputs_to_dtypes )
591
-
592
583
# pre-processing graph rewrites
593
584
# bi-directional re-writer should be placed after single directional re-writer
594
585
rewriters = [
@@ -626,7 +617,7 @@ def compat_handler(ctx, node, **kwargs):
626
617
run_rewriters (g , rewriters , continue_on_error )
627
618
628
619
# some nodes may already copied into inner Graph, so remove them from main Graph.
629
- g .delete_unused_nodes (output_names )
620
+ g .delete_unused_nodes (g . outputs )
630
621
topological_sort (g , continue_on_error )
631
622
632
623
mapped_op , unmapped_op , exceptions = \
0 commit comments