Skip to content

Commit bd79c6d

Browse files
hwangdeyufatcat-z
andcommitted
Turn on graph tf optimize grappler dependency (#2020)
* Turn on graph tf optimize grappler dependency Signed-off-by: Deyu Huang <[email protected]> * Aviod output name rewrite in random normal Signed-off-by: Deyu Huang <[email protected]> Co-authored-by: Jay Zhang <[email protected]> Signed-off-by: Jay Zhang <[email protected]>
1 parent 29f4cc0 commit bd79c6d

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

tf2onnx/rewriter/random_normal_rewriter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,22 @@ def rewrite_random_normal(g, ops):
3333
match_results = list(matcher.match_ops(ops))
3434
for match in match_results:
3535
output = match.get_op('output')
36-
if output.type == 'Add':
36+
input2 = match.get_op('input2')
37+
is_output = False
38+
for output_name in g.outputs:
39+
# input2 and output can not be output node.
40+
if input2.name in output_name or output.name in output_name:
41+
is_output = True
42+
break
43+
if is_output:
44+
continue
45+
if output.type == 'Add' and input2.type == 'Mul':
3746
# pattern 1
3847
mean = output.inputs[1].get_tensor_value()
48+
scale = input2.inputs[1].get_tensor_value()
3949
else:
4050
# pattern 2
4151
mean = 0.0
42-
input2 = match.get_op('input2')
43-
if input2.type == 'Mul':
44-
scale = input2.inputs[1].get_tensor_value()
45-
else:
4652
scale = 1.0
4753
dtype = g.get_dtype(output.output[0])
4854
op_name = utils.make_name("RandomNormal")

tf2onnx/tf_loader.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,10 @@ def tf_optimize_grappler(input_names, output_names, graph_def):
687687
'constfold', 'function'
688688
]
689689

690+
if is_tf2():
691+
# add for tf2.x lstm optimization.
692+
rewrite_options.optimizers.append('dependency')
693+
690694
if Version(tf.__version__) >= Version("2.5"):
691695
# This flag disables folding QDQ nodes around constants in the network (eg: around conv/FC weights)
692696
rewrite_options.experimental_disable_folding_quantization_emulation = True
@@ -771,8 +775,8 @@ def toposort(data):
771775
try:
772776
func = function_def_to_graph(fdef, input_shapes=input_shapes)
773777
except: # pylint: disable=bare-except
774-
# if there is a missmatch between caller and function use the functions shape
775-
logger.warning("shape missmatch between caller and function: %s", k)
778+
# if there is a mismatch between caller and function use the functions shape
779+
logger.warning("shape mismatch between caller and function: %s", k)
776780
func = function_def_to_graph(fdef)
777781
_FUNCTIONS[k] = func
778782
_, _, _, _, _, tfunctions = tflist_to_onnx(func, {})

0 commit comments

Comments
 (0)