Skip to content

Commit 2849c50

Browse files
Refactor tfonnx
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent d18d3f7 commit 2849c50

File tree

2 files changed

+90
-74
lines changed

2 files changed

+90
-74
lines changed

tf2onnx/graph.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,23 @@ def is_const(self, output):
874874
def get_tensor_value(self, output, as_list=True):
875875
return self.get_node_by_output(output).get_tensor_value(as_list)
876876

877+
def rename_tensors(self, tensors_to_rename):
878+
"""Replace tensor names within nodes and graph inputs/outputs"""
879+
def rename_list(l):
880+
return [tensors_to_rename.get(t, t) for t in l]
881+
882+
def rename_keys(d):
883+
return {tensors_to_rename.get(k, k): v for k, v in d.items()}
884+
885+
self._output_to_node_name = rename_keys(self._output_to_node_name)
886+
self._output_to_consumers = rename_keys(self._output_to_consumers)
887+
self._dtypes = rename_keys(self._dtypes)
888+
self._output_shapes = rename_keys(self._output_shapes)
889+
self.outputs = rename_list(self.outputs)
890+
for node in self._nodes:
891+
node._input = rename_list(node._input)
892+
node._output = rename_list(node._output)
893+
877894
def change_node_name(self, node, new_name):
878895
"""Remove node in current graph."""
879896
utils.make_sure(new_name not in self._nodes_by_name, "node %s not unique ", new_name)
@@ -1232,15 +1249,23 @@ def follow_inputs(self, node, num, space=""):
12321249
return []
12331250
return val
12341251

1235-
def dump_node_statistics(self):
1252+
def dump_node_statistics(self, include_attrs=False, include_subgraphs=True):
1253+
"""Return a counter of op types (and optionally attribute names) within the graph"""
12361254
op_cnt = collections.Counter()
1255+
attr_cnt = collections.Counter()
12371256
for n in self.get_nodes():
12381257
op_cnt[n.type] += 1
1258+
for k in n.attr.keys():
1259+
attr_cnt[k] += 1
12391260
body_graphs = n.get_body_graphs()
1240-
if body_graphs:
1261+
if body_graphs and include_subgraphs:
12411262
for b_g in body_graphs.values():
1242-
op_cnt += b_g.dump_node_statistics()
1263+
g_op_cnt, g_attr_cnt = b_g.dump_node_statistics(include_attrs=True, include_subgraphs=True)
1264+
op_cnt += g_op_cnt
1265+
attr_cnt += g_attr_cnt
12431266

1267+
if include_attrs:
1268+
return op_cnt, attr_cnt
12441269
return op_cnt
12451270

12461271
def remove_input(self, node, to_be_removed, input_index=None):

tf2onnx/tfonnx.py

Lines changed: 62 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
# pylint: disable=useless-return,broad-except,logging-not-lazy,unused-argument,missing-docstring
3434
# pylint: disable=unused-variable
3535

36-
def fold_constants_using_tf(g, outputs_to_values, outputs_to_dtypes):
36+
def fold_constants_using_tf(g, outputs_to_values):
3737
ops = list(g.get_nodes())
3838
# pylint: disable=too-many-nested-blocks
3939
keep_looking = True
@@ -409,14 +409,13 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
409409
del verbose
410410

411411
opset = utils.find_opset(opset)
412-
if not is_subgraph:
413-
logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
414-
get_tf_version(), utils.get_onnx_version(), tf2onnx.__version__, tf2onnx.version.git_version[:6])
415-
logger.info("Using opset <onnx, %s>", opset)
416-
if opset > schemas.get_max_supported_opset_version():
417-
logger.warning("Currently installed onnx package %s is too low to support opset %s, "
418-
"please upgrade onnx package to avoid potential conversion issue.",
419-
utils.get_onnx_version(), opset)
412+
logger.info("Using tensorflow=%s, onnx=%s, tf2onnx=%s/%s",
413+
get_tf_version(), utils.get_onnx_version(), tf2onnx.__version__, tf2onnx.version.git_version[:6])
414+
logger.info("Using opset <onnx, %s>", opset)
415+
if opset > schemas.get_max_supported_opset_version():
416+
logger.warning("Currently installed onnx package %s is too low to support opset %s, "
417+
"please upgrade onnx package to avoid potential conversion issue.",
418+
utils.get_onnx_version(), opset)
420419

421420
if shape_override is None:
422421
shape_override = {}
@@ -440,34 +439,17 @@ def check_io(input_names, output_names, output_shapes):
440439
non_exists)
441440
raise ValueError("Inputs/Outputs Not Found")
442441

443-
def rename_tensors_in_dict(d):
444-
if tensors_to_rename is None:
445-
return d
446-
return {tensors_to_rename.get(k, k): v for k, v in d.items()}
447-
448-
def rename_tensors_in_list(tensors):
449-
if tensors_to_rename is None or tensors is None:
450-
return tensors
451-
return [tensors_to_rename.get(t, t) for t in tensors]
452-
453-
def rename_tensors_in_nodes(onnx_nodes):
454-
if tensors_to_rename is None:
455-
return
456-
for n in onnx_nodes:
457-
n.input[:] = rename_tensors_in_list(n.input)
458-
n.output[:] = rename_tensors_in_list(n.output)
459-
460442
if tflite_path is not None:
461443
tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
462444
main_g = None
463-
inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
445+
subgraphs = []
464446
for i, tfl_graph in enumerate(tflite_graphs):
465447
is_main_g = i == len(tflite_graphs) - 1
466448
prefix = '' if is_main_g else tfl_graph.Name().decode() + '_'
467449
tensor_shapes_from_interpreter = None
468450
if is_main_g:
469451
tensor_shapes_from_interpreter = tensor_shapes
470-
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
452+
onnx_nodes, _, _, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
471453
parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
472454
g_inputs = f_inputs
473455
g_outputs = f_outputs
@@ -478,63 +460,74 @@ def rename_tensors_in_nodes(onnx_nodes):
478460
g_inputs = input_names
479461
if output_names is not None:
480462
g_outputs = output_names
481-
rename_tensors_in_nodes(onnx_nodes)
482-
g_inputs = rename_tensors_in_list(g_inputs)
483-
g_outputs = rename_tensors_in_list(g_outputs)
484-
output_shapes = rename_tensors_in_dict(output_shapes)
485-
dtypes = rename_tensors_in_dict(dtypes)
486-
g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, g_inputs, g_outputs, is_subgraph)
487-
fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
488-
g_outputs, {}, {}, {}, op_cnt, attr_cnt, is_tflite=True, dequantize=dequantize)
489-
fg.graph_name = graph_name
463+
g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, g_inputs, g_outputs,
464+
not is_main_g, graph_name)
490465
if is_main_g:
491-
main_g = fg
466+
main_g = g
492467
else:
493-
set_function(graph_name, fg)
494-
495-
return main_g
496-
497-
is_func = is_function(tf_graph)
498-
if not is_func:
499-
tf_graph = infer_shape(tf_graph, shape_override)
468+
subgraphs.append(g)
500469

501-
outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)
470+
g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
471+
target, {}, tensors_to_rename, is_tflite=True, dequantize=dequantize)
472+
return g
502473

503-
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
504-
tensorflow_to_onnx(tf_graph, shape_override, const_node_values, ignore_default, use_default)
505474
if not is_subgraph:
506475
# make tf2onnx internal subgraphs from the tensorflow subgraphs
507476
ordered_func = resolve_functions(tf_graph)
477+
subgraphs = []
508478
for func in ordered_func:
509479
f_inputs_names = [t.name for t in func.inputs]
510480
f_output_names = [t.name for t in func.outputs]
511-
fg = process_tf_graph(func, continue_on_error, False, target, opset,
512-
custom_op_handlers, custom_rewriter,
513-
extra_opset, shape_override, inputs_as_nchw,
514-
f_inputs_names, f_output_names, is_subgraph=True,
515-
const_node_values=const_node_values, tensors_to_rename=tensors_to_rename,
516-
initialized_tables=initialized_tables)
517-
fg.graph_name = func.name
518-
set_function(func.name, fg)
481+
482+
outputs_to_values, _ = compute_const_folding_using_tf(func, const_node_values, output_names)
483+
484+
onnx_nodes, _, _, output_shapes, dtypes, _ = \
485+
tensorflow_to_onnx(func, shape_override, const_node_values, ignore_default, use_default)
486+
487+
fg = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, f_inputs_names, f_output_names,
488+
is_subgraph=True, graph_name=func.name)
489+
fold_constants_using_tf(fg, outputs_to_values)
490+
subgraphs.append(fg)
491+
492+
is_func = is_function(tf_graph)
493+
if not is_func:
494+
tf_graph = infer_shape(tf_graph, shape_override)
495+
496+
outputs_to_values, _ = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)
497+
498+
onnx_nodes, _, _, output_shapes, dtypes, _ = \
499+
tensorflow_to_onnx(tf_graph, shape_override, const_node_values, ignore_default, use_default)
519500

520501
check_io(input_names, output_names, output_shapes)
502+
main_g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, input_names, output_names,
503+
is_subgraph)
504+
fold_constants_using_tf(main_g, outputs_to_values)
505+
g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter,
506+
target, initialized_tables, tensors_to_rename)
507+
return g
521508

522-
if not is_subgraph:
523-
rename_tensors_in_nodes(onnx_nodes)
524-
input_names = rename_tensors_in_list(input_names)
525-
output_names = rename_tensors_in_list(output_names)
526-
output_shapes = rename_tensors_in_dict(output_shapes)
527-
dtypes = rename_tensors_in_dict(dtypes)
528-
inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
529-
g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, input_names, output_names, is_subgraph)
530-
g = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
531-
output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt)
509+
510+
def process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
511+
initialized_tables, tensors_to_rename, is_tflite=False, dequantize=False):
512+
513+
if tensors_to_rename is not None:
514+
main_g.rename_tensors(tensors_to_rename)
515+
inputs_as_nchw = [tensors_to_rename.get(t, t) for t in inputs_as_nchw]
516+
517+
for g in subgraphs:
518+
fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
519+
initialized_tables, is_tflite, dequantize)
520+
set_function(fg.graph_name, fg)
521+
g = process_parsed_graph(main_g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
522+
initialized_tables, is_tflite,
523+
dequantize)
532524
return g
533525

534526

535527
def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target,
536-
output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt,
537-
is_tflite=False, dequantize=False):
528+
initialized_tables, is_tflite=False, dequantize=False):
529+
530+
op_cnt, attr_cnt = g.dump_node_statistics(include_attrs=True, include_subgraphs=False)
538531

539532
if is_tflite:
540533
tfl_rewriters = []
@@ -587,8 +580,6 @@ def compat_handler(ctx, node, **kwargs):
587580
if inputs_as_nchw:
588581
transpose_inputs(g, inputs_as_nchw)
589582

590-
fold_constants_using_tf(g, outputs_to_values, outputs_to_dtypes)
591-
592583
# pre-processing graph rewrites
593584
# bi-directional re-writer should be placed after single directional re-writer
594585
rewriters = [
@@ -626,7 +617,7 @@ def compat_handler(ctx, node, **kwargs):
626617
run_rewriters(g, rewriters, continue_on_error)
627618

628619
# some nodes may already copied into inner Graph, so remove them from main Graph.
629-
g.delete_unused_nodes(output_names)
620+
g.delete_unused_nodes(g.outputs)
630621
topological_sort(g, continue_on_error)
631622

632623
mapped_op, unmapped_op, exceptions = \

0 commit comments

Comments
 (0)