diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py index 2f969f77e..2a30eeae2 100644 --- a/tf2onnx/tf_loader.py +++ b/tf2onnx/tf_loader.py @@ -290,8 +290,8 @@ def from_function(func, input_names, output_names, large_model=False): def freeze_session(sess, input_names=None, output_names=None, get_tables=False): """Freezes the state of a session into a pruned computation graph.""" - output_node_names = [i.split(':')[:-1][0] for i in output_names] - keep_var_names = [i.split(':')[:-1][0] for i in input_names] + output_node_names = [i.split(':')[0] for i in output_names] + keep_var_names = [i.split(':')[0] for i in input_names] with sess.graph.as_default(): output_node_names = output_node_names or [] output_node_names += [v.op.name for v in tf_global_variables()]