Skip to content

Commit d7fe8bb

Browse files
Merge pull request #1090 from onnx/tom/ConvertLargeModels
Added support for converting large models
2 parents 62f1e70 + 297352a commit d7fe8bb

File tree

8 files changed

+134
-27
lines changed

8 files changed

+134
-27
lines changed

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,12 @@ Only valid with parameter `--saved_model`. Specifies which signature to use with
193193

194194
Only valid with parameter `--saved_model`. If a model contains a list of concrete functions, under the function name `__call__` (as can be viewed using the command `saved_model_cli show --all`), this parameter is a 0-based integer specifying which function in that list should be converted. This parameter takes priority over `--signature_def`, which will be ignored.
195195

196+
#### --large_model
197+
198+
(This option is experimental and is valid only for TF2.x models.)
199+
200+
Only valid with parameter `--saved_model`. When set, creates a zip file containing the ONNX protobuf model and large tensor values stored externally. This allows for converting models that exceed the 2 GB protobuf limit. Unzip the archive before opening the model in onnxruntime.
201+
196202
#### --target
197203

198204
Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value.
@@ -274,7 +280,8 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
274280
opset=None, custom_op_handlers=None,
275281
custom_rewriter=None, extra_opset=None,
276282
shape_override=None, inputs_as_nchw=None,
277-
input_names=None, output_names=None):
283+
input_names=None, output_names=None,
284+
const_node_values=None):
278285
"""Convert tensorflow graph to onnx graph.
279286
Args:
280287
tf_graph: tensorflow graph
@@ -289,6 +296,7 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
289296
inputs_as_nchw: transpose inputs in list from nhwc to nchw
290297
input_names: list of input node names in graph, input name format as node_name:port_id
291298
output_names: list of output node names in graph, output name format as node_name:port_id
299+
const_node_values: an optional dict mapping node names to tensor values
292300
Return:
293301
onnx graph
294302
"""

tests/backend_test_base.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from tf2onnx import optimizer
2727
from tf2onnx.tf_loader import tf_reset_default_graph, tf_session, tf_placeholder, from_function, freeze_session
2828
from tf2onnx.tf_loader import tf_optimize, is_tf2
29+
from tf2onnx.tf_utils import compress_graph_def
30+
from tf2onnx.graph import ExternalTensorStorage
2931

3032

3133
class Tf2OnnxBackendTestBase(unittest.TestCase):
@@ -72,9 +74,10 @@ def run_onnxruntime(self, model_path, inputs, output_names):
7274
results = m.run(output_names, inputs)
7375
return results
7476

75-
def run_backend(self, g, outputs, input_dict):
76-
model_proto = g.make_model("test")
77-
model_path = self.save_onnx_model(model_proto, input_dict)
77+
def run_backend(self, g, outputs, input_dict, large_model=False):
78+
tensor_storage = ExternalTensorStorage() if large_model else None
79+
model_proto = g.make_model("test", external_tensor_storage=tensor_storage)
80+
model_path = self.save_onnx_model(model_proto, input_dict, external_tensor_storage=tensor_storage)
7881

7982
if self.config.backend == "onnxruntime":
8083
y = self.run_onnxruntime(model_path, input_dict, outputs)
@@ -86,7 +89,8 @@ def run_backend(self, g, outputs, input_dict):
8689

8790
def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
8891
convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
89-
check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False):
92+
check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
93+
large_model=False):
9094
# optional - passed to process_tf_graph
9195
if process_args is None:
9296
process_args = {}
@@ -121,7 +125,9 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
121125
concrete_func = tf.function(func, input_signature=tuple(input_tensors))
122126
concrete_func = concrete_func.get_concrete_function()
123127
graph_def = from_function(concrete_func,
124-
input_names=list(feed_dict.keys()), output_names=output_names_with_port)
128+
input_names=list(feed_dict.keys()),
129+
output_names=output_names_with_port,
130+
large_model=large_model)
125131
else:
126132
#
127133
# use graph to execute the tensorflow func
@@ -151,6 +157,9 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
151157

152158
tf_reset_default_graph()
153159
with tf_session() as sess:
160+
const_node_values = None
161+
if large_model:
162+
const_node_values = compress_graph_def(graph_def)
154163
tf.import_graph_def(graph_def, name='')
155164

156165
if self.config.is_debug_mode:
@@ -161,9 +170,11 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
161170
g = process_tf_graph(sess.graph, opset=self.config.opset,
162171
input_names=list(feed_dict.keys()),
163172
output_names=output_names_with_port,
164-
target=self.config.target, **process_args)
173+
target=self.config.target,
174+
const_node_values=const_node_values,
175+
**process_args)
165176
g = optimizer.optimize_graph(g)
166-
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict)
177+
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
167178

168179
for expected_val, actual_val in zip(expected, actual):
169180
if check_value:
@@ -180,10 +191,11 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
180191

181192
return g
182193

183-
def save_onnx_model(self, model_proto, feed_dict, postfix=""):
194+
def save_onnx_model(self, model_proto, feed_dict, postfix="", external_tensor_storage=None):
184195
target_path = utils.save_onnx_model(self.test_data_directory, self._testMethodName + postfix, feed_dict,
185196
model_proto, include_test_data=self.config.is_debug_mode,
186-
as_text=self.config.is_debug_mode)
197+
as_text=self.config.is_debug_mode,
198+
external_tensor_storage=external_tensor_storage)
187199

188200
self.logger.debug("create model file: %s", target_path)
189201
return target_path

tests/test_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,15 @@ def func(x):
820820
return tf.identity(x_, name=_TFOUTPUT)
821821
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
822822

823+
@check_tf_min_version("2.2")
824+
def test_large_model_format(self):
825+
x_val = np.array([2.0], dtype=np.float32)
826+
y_const = np.arange(2000, dtype=np.float32)
827+
def func(x):
828+
x_ = tf.multiply(x, tf.constant(y_const))
829+
return tf.identity(x_, name=_TFOUTPUT)
830+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, large_model=True)
831+
823832
@check_target('rs6', 'GatherNd')
824833
def test_gathernd(self):
825834
x_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)

tests/test_convert.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import unittest
99

1010
from tf2onnx import convert
11-
11+
from common import check_tf_min_version
1212

1313
def run_test_case(args):
1414
""" run case and clean up """
@@ -33,6 +33,18 @@ def test_convert_saved_model(self):
3333
'--output',
3434
'converted_saved_model.onnx']))
3535

36+
@check_tf_min_version("2.2")
37+
def test_convert_large_model(self):
38+
""" convert saved model to onnx large model format """
39+
self.assertTrue(run_test_case(['',
40+
'--large_model',
41+
'--saved-model',
42+
'tests/models/regression/saved_model',
43+
'--tag',
44+
'serve',
45+
'--output',
46+
'converted_saved_model.zip']))
47+
3648
def test_convert_graphdef(self):
3749
""" convert graphdef """
3850
self.assertTrue(run_test_case(['',

tf2onnx/convert.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from tf2onnx.tfonnx import process_tf_graph
2323
from tf2onnx import constants, logging, utils, optimizer
2424
from tf2onnx import tf_loader
25+
from tf2onnx.graph import ExternalTensorStorage
26+
from tf2onnx.tf_utils import compress_graph_def
2527

2628
# pylint: disable=unused-argument
2729

@@ -53,6 +55,7 @@ def get_args():
5355
help="For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)")
5456
parser.add_argument("--checkpoint", help="input from checkpoint")
5557
parser.add_argument("--keras", help="input from keras model")
58+
parser.add_argument("--large_model", help="use the large model format (for models > 2GB)", action="store_true")
5659
parser.add_argument("--output", help="output model file")
5760
parser.add_argument("--inputs", help="model input_names")
5861
parser.add_argument("--outputs", help="model output_names")
@@ -129,7 +132,8 @@ def main():
129132
model_path = args.checkpoint
130133
if args.saved_model:
131134
graph_def, inputs, outputs = tf_loader.from_saved_model(
132-
args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function)
135+
args.saved_model, args.inputs, args.outputs, args.tag,
136+
args.signature_def, args.concrete_function, args.large_model)
133137
model_path = args.saved_model
134138
if args.keras:
135139
graph_def, inputs, outputs = tf_loader.from_keras(
@@ -141,6 +145,9 @@ def main():
141145
logger.info("outputs: %s", outputs)
142146

143147
with tf.Graph().as_default() as tf_graph:
148+
const_node_values = None
149+
if args.large_model:
150+
const_node_values = compress_graph_def(graph_def)
144151
tf.import_graph_def(graph_def, name='')
145152
with tf_loader.tf_session(graph=tf_graph):
146153
g = process_tf_graph(tf_graph,
@@ -152,17 +159,24 @@ def main():
152159
shape_override=args.shape_override,
153160
input_names=inputs,
154161
output_names=outputs,
155-
inputs_as_nchw=args.inputs_as_nchw)
162+
inputs_as_nchw=args.inputs_as_nchw,
163+
const_node_values=const_node_values)
156164

157165
onnx_graph = optimizer.optimize_graph(g)
158-
model_proto = onnx_graph.make_model("converted from {}".format(model_path))
166+
167+
tensor_storage = ExternalTensorStorage() if args.large_model else None
168+
model_proto = onnx_graph.make_model("converted from {}".format(model_path), external_tensor_storage=tensor_storage)
159169

160170
# write onnx graph
161171
logger.info("")
162172
logger.info("Successfully converted TensorFlow model %s to ONNX", model_path)
163173
if args.output:
164-
utils.save_protobuf(args.output, model_proto)
165-
logger.info("ONNX model is saved at %s", args.output)
174+
if args.large_model:
175+
utils.save_onnx_zip(args.output, model_proto, tensor_storage)
176+
logger.info("Zipped ONNX model is saved at %s. Unzip before opening in onnxruntime.", args.output)
177+
else:
178+
utils.save_protobuf(args.output, model_proto)
179+
logger.info("ONNX model is saved at %s", args.output)
166180
else:
167181
logger.info("To export ONNX model to file, please run with `--output` option")
168182

tf2onnx/tf_loader.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,32 @@ def inputs_without_resource(sess, input_names):
9595
pass
9696
return input_names
9797

98+
def convert_variables_to_constants_large_model(func):
99+
# For large models we use internal tf methods as a hack
100+
101+
if tf.__version__.startswith("2.2."):
102+
try:
103+
from tensorflow.python.framework.convert_to_constants import \
104+
_convert_variables_to_constants_v2_impl # pylint: disable=protected-access
105+
except ImportError:
106+
_not_implemented_tf_placeholder("_convert_variables_to_constants_v2_impl")()
107+
frozen_graph_def, _ = \
108+
_convert_variables_to_constants_v2_impl(func, lower_control_flow=False, aggressive_inlining=False)
109+
return frozen_graph_def
110+
111+
try:
112+
from tensorflow.python.framework.convert_to_constants import \
113+
_FunctionConverterData, _replace_variables_by_constants # pylint: disable=protected-access
114+
except ImportError:
115+
_not_implemented_tf_placeholder("_replace_variables_by_constants")()
116+
converter_data = _FunctionConverterData(func=func, lower_control_flow=False, aggressive_inlining=False)
117+
frozen_graph_def, _ = _replace_variables_by_constants(converter_data=converter_data)
118+
return frozen_graph_def
119+
120+
def from_function(func, input_names, output_names, large_model=False):
121+
if large_model:
122+
return convert_variables_to_constants_large_model(func)
98123

99-
def from_function(func, input_names, output_names):
100124
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
101125
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
102126
# output_names = [i.name for i in frozen_func.outputs]
@@ -223,7 +247,8 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
223247
return frozen_graph, input_names, output_names
224248

225249

226-
def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def, concrete_function_index):
250+
def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def,
251+
concrete_function_index, large_model):
227252
"""Load tensorflow graph from saved_model."""
228253

229254
wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
@@ -234,6 +259,7 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
234259
err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
235260
err_no_sig = "No signatures found in model. Try --concrete_function instead."
236261
err_sig_nomatch = "Specified signature not in model %s"
262+
err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."
237263

238264
if tag is None:
239265
tag = ['serve']
@@ -274,18 +300,25 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
274300
if output_names:
275301
outputs = list(set(output_names) & set(outputs))
276302

277-
frozen_graph = from_function(concrete_func, inputs, outputs)
303+
try:
304+
frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
305+
except ValueError as e:
306+
if "exceeds maximum protobuf size of 2GB" in str(e):
307+
raise ValueError(err_large_model)
308+
raise e
309+
278310
return frozen_graph, inputs, outputs
279311

280312

281-
def from_saved_model(model_path, input_names, output_names, tag=None, signatures=None, concrete_function=None):
313+
def from_saved_model(model_path, input_names, output_names, tag=None,
314+
signatures=None, concrete_function=None, large_model=False):
282315
"""Load tensorflow graph from saved_model."""
283316
if signatures is None:
284317
signatures = []
285318
tf_reset_default_graph()
286319
if is_tf2():
287320
frozen_graph, input_names, output_names = \
288-
_from_saved_model_v2(model_path, input_names, output_names, tag, signatures, concrete_function)
321+
_from_saved_model_v2(model_path, input_names, output_names, tag, signatures, concrete_function, large_model)
289322
else:
290323
with tf_session() as sess:
291324
frozen_graph, input_names, output_names = \

tf2onnx/tfonnx.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def run_rewriters(g, funcs, continue_on_error):
334334
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
335335
opset=None, custom_op_handlers=None, custom_rewriter=None,
336336
extra_opset=None, shape_override=None, inputs_as_nchw=None,
337-
input_names=None, output_names=None, is_subgraph=False):
337+
input_names=None, output_names=None, is_subgraph=False, const_node_values=None):
338338
"""Convert tensorflow graph to onnx graph.
339339
Args:
340340
tf_graph: tensorflow graph
@@ -349,6 +349,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
349349
inputs_as_nchw: transpose inputs in list from nhwc to nchw
350350
input_names: list of input node names in graph, input name format as node_name:port_id
351351
output_names: list of output node names in graph, output name format as node_name:port_id
352+
const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
352353
Return:
353354
onnx graph
354355
"""
@@ -377,7 +378,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
377378
if target is None:
378379
target = constants.DEFAULT_TARGET
379380

380-
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = tensorflow_to_onnx(tf_graph, shape_override)
381+
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
382+
tensorflow_to_onnx(tf_graph, shape_override, const_node_values)
381383
if not is_subgraph:
382384
# make tf2onnx internal subgraphs from the tensorflow subgraphs
383385
ordered_func = resolve_functions(tf_graph)
@@ -387,7 +389,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
387389
fg = process_tf_graph(func, continue_on_error, False, target, opset,
388390
custom_op_handlers, custom_rewriter,
389391
extra_opset, shape_override, inputs_as_nchw,
390-
f_inputs_names, f_output_names, is_subgraph=True)
392+
f_inputs_names, f_output_names, is_subgraph=True,
393+
const_node_values=const_node_values)
391394
fg.graph_name = func.name
392395
fg.func_inputs = f_inputs_names
393396
set_function(func.name, fg)

tf2onnx/utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import re
1414
import shutil
1515
import tempfile
16+
import zipfile
1617

1718
import requests
1819
from requests.adapters import HTTPAdapter
@@ -161,7 +162,8 @@ def find_opset(opset):
161162
return opset
162163

163164

164-
def save_onnx_model(save_path_root, onnx_file_name, feed_dict, model_proto, include_test_data=False, as_text=False):
165+
def save_onnx_model(save_path_root, onnx_file_name, feed_dict, model_proto, include_test_data=False, as_text=False,
166+
external_tensor_storage=None):
165167
"""Save onnx model as file. Save a pbtxt file as well if as_text is True"""
166168
save_path = save_path_root
167169
if not os.path.exists(save_path):
@@ -181,12 +183,26 @@ def save_onnx_model(save_path_root, onnx_file_name, feed_dict, model_proto, incl
181183
save_protobuf(data_full_path, t)
182184
i += 1
183185

184-
target_path = os.path.join(save_path, onnx_file_name + ".onnx")
185-
save_protobuf(target_path, model_proto)
186+
if external_tensor_storage is None:
187+
target_path = os.path.join(save_path, onnx_file_name + ".onnx")
188+
save_protobuf(target_path, model_proto)
189+
else:
190+
zip_path = os.path.join(save_path, onnx_file_name + ".zip")
191+
save_onnx_zip(zip_path, model_proto, external_tensor_storage)
192+
with zipfile.ZipFile(zip_path, 'r') as z:
193+
z.extractall(save_path)
194+
target_path = os.path.join(save_path, "__MODEL_PROTO.onnx")
195+
186196
if as_text:
187197
save_protobuf(target_path + ".pbtxt", model_proto, as_text=True)
198+
188199
return target_path
189200

201+
def save_onnx_zip(target_path, model_proto, external_tensor_storage):
202+
with zipfile.ZipFile(target_path, 'w') as z:
203+
z.writestr("__MODEL_PROTO.onnx", model_proto.SerializeToString())
204+
for k, v in external_tensor_storage.name_to_tensor_data.items():
205+
z.writestr(k, v)
190206

191207
def make_sure(bool_val, error_msg, *args):
192208
if not bool_val:

0 commit comments

Comments
 (0)