Skip to content

Commit a8b55fa

Browse files
Allow passing optimizers to from_keras conversion function (onnx#1907)
* Allow to select optimizers in from_keras conversion function Signed-off-by: Sagar Shelke <[email protected]> * remove trailing whitespace Signed-off-by: Sagar Shelke <[email protected]> Co-authored-by: Sagar Shelke <[email protected]> Co-authored-by: Deyu Huang <[email protected]>
1 parent ad9af3f commit a8b55fa

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tf2onnx/convert.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None,
402402

403403
def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
404404
custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None,
405-
target=None, large_model=False, output_path=None):
405+
target=None, large_model=False, output_path=None, optimizers=None):
406406
"""Returns a ONNX model_proto for a tf.keras model.
407407
408408
Args:
@@ -420,6 +420,7 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
420420
inputs_as_nchw: transpose inputs in list from nchw to nhwc
421421
large_model: use the ONNX external tensor storage format
422422
output_path: save model to output_path
423+
optimizers: list (subset) of tf2onnx optimizers if applying all optimizers is not desired.
423424
424425
Returns:
425426
An ONNX model_proto and an external_tensor_storage dict.
@@ -492,6 +493,7 @@ def wrap_call(*args, training=False, **kwargs):
492493
opset=opset,
493494
custom_ops=custom_ops,
494495
custom_op_handlers=custom_op_handlers,
496+
optimizers=optimizers,
495497
custom_rewriter=custom_rewriter,
496498
extra_opset=extra_opset,
497499
shape_override=shape_override,

0 commit comments

Comments
 (0)