Skip to content

Commit 2f9078e

Browse files
authored
add script to run optimizer on onnx files (#1538)
* add script to run optimizer on onnx files Signed-off-by: Guenther Schmuelling <[email protected]> * pylint Signed-off-by: Guenther Schmuelling <[email protected]>
1 parent cba8f8d commit 2f9078e

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

tools/onnx-optimize.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""
5+
A simple tool to try optimizations on onnx graphs.
6+
This makes use of the fact that tensorflow-onnx internal graph representation is onnx
7+
so all graph, rewrite, matching and utility libaries do work which makes things easy.
8+
"""
9+
10+
# pylint: disable=invalid-name,missing-docstring, unused-argument
11+
12+
import argparse
13+
import logging
14+
15+
import onnx
16+
from onnx import helper
17+
18+
from tf2onnx.graph import GraphUtil
19+
from tf2onnx import logging, optimizer
20+
21+
22+
logging.basicConfig(level=logging.INFO)
23+
logger = logging.getLogger("onnx-optimize")
24+
25+
26+
def get_args():
27+
"""Parse commandline."""
28+
parser = argparse.ArgumentParser()
29+
parser.add_argument("--input", required=True, help="onnx input model file")
30+
parser.add_argument("--output", help="output model file")
31+
args = parser.parse_args()
32+
return args
33+
34+
35+
def load_graph(fname):
36+
model_proto = onnx.ModelProto()
37+
with open(fname, "rb") as f:
38+
data = f.read()
39+
model_proto.ParseFromString(data)
40+
g = GraphUtil.create_graph_from_onnx_model(model_proto)
41+
return g, model_proto
42+
43+
44+
def main():
45+
args = get_args()
46+
47+
g, org_model_proto = load_graph(args.input)
48+
49+
g = optimizer.optimize_graph(g)
50+
51+
onnx_graph = g.make_graph(org_model_proto.graph.doc_string + " (+tf2onnx/onnx-optimize)")
52+
53+
kwargs = {"producer_name": org_model_proto.producer_name,
54+
"producer_version": org_model_proto.producer_version,
55+
"opset_imports": org_model_proto.opset_import,
56+
"ir_version": org_model_proto.ir_version}
57+
58+
model_proto = helper.make_model(onnx_graph, **kwargs)
59+
60+
# write onnx graph
61+
if args.output:
62+
with open(args.output, "wb") as f:
63+
f.write(model_proto.SerializeToString())
64+
65+
66+
if __name__ == "__main__":
67+
main()

0 commit comments

Comments
 (0)