diff --git a/tf2onnx/rewriter/random_normal_rewriter.py b/tf2onnx/rewriter/random_normal_rewriter.py index 6d106907e..3691b5f00 100644 --- a/tf2onnx/rewriter/random_normal_rewriter.py +++ b/tf2onnx/rewriter/random_normal_rewriter.py @@ -33,16 +33,22 @@ def rewrite_random_normal(g, ops): match_results = list(matcher.match_ops(ops)) for match in match_results: output = match.get_op('output') - if output.type == 'Add': + input2 = match.get_op('input2') + is_output = False + for output_name in g.outputs: + # input2 and output can not be output node. + if input2.name in output_name or output.name in output_name: + is_output = True + break + if is_output: + continue + if output.type == 'Add' and input2.type == 'Mul': # pattern 1 mean = output.inputs[1].get_tensor_value() + scale = input2.inputs[1].get_tensor_value() else: # pattern 2 mean = 0.0 - input2 = match.get_op('input2') - if input2.type == 'Mul': - scale = input2.inputs[1].get_tensor_value() - else: scale = 1.0 dtype = g.get_dtype(output.output[0]) op_name = utils.make_name("RandomNormal") diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py index 7269a26e0..d9d72a8dc 100644 --- a/tf2onnx/tf_loader.py +++ b/tf2onnx/tf_loader.py @@ -687,6 +687,10 @@ def tf_optimize_grappler(input_names, output_names, graph_def): 'constfold', 'function' ] + if is_tf2(): + # add for tf2.x lstm optimization. + rewrite_options.optimizers.append('dependency') + if Version(tf.__version__) >= Version("2.5"): # This flag disables folding QDQ nodes around constants in the network (eg: around conv/FC weights) rewrite_options.experimental_disable_folding_quantization_emulation = True @@ -771,8 +775,8 @@ def toposort(data): try: func = function_def_to_graph(fdef, input_shapes=input_shapes) except: # pylint: disable=bare-except - # if there is a missmatch between caller and function use the functions shape - logger.warning("shape missmatch between caller and function: %s", k) + # if there is a mismatch between caller and function use the functions shape + logger.warning("shape mismatch between caller and function: %s", k) func = function_def_to_graph(fdef) _FUNCTIONS[k] = func _, _, _, _, _, tfunctions = tflist_to_onnx(func, {})