
Commit 98f72d4

add the tf1 keras missing optimization
Signed-off-by: hwangdeyu <[email protected]>
1 parent 621b4b1 commit 98f72d4

4 files changed: +25 additions, −23 deletions

tests/keras2onnx_unit_tests/test_layers.py

Lines changed: 0 additions & 22 deletions

@@ -1198,28 +1198,6 @@ def test_conv2d_transpose_2(runner, padding):
     assert runner(onnx_model.graph.name, onnx_model, data, expected)


-def test_conv2d_bias_add(runner):
-    size = 128
-    input_img = Input((size, size, 64))
-    x = tf.keras.layers.ZeroPadding2D(
-        padding=(0, 4),
-        data_format="channels_first",
-        name="padding")(input_img)
-    y = tf.keras.layers.Conv2D(
-        filters=1,
-        kernel_size=(9, 9),
-        strides=(1, 1),
-        use_bias=True,
-        data_format="channels_first",
-        name="conv2d",
-    )(x)
-    model = tf.keras.Model(inputs=input_img, outputs=y)
-    data = np.random.rand(1, size, size, 64).astype(np.float32)
-    onnx_model = convert_keras(model, model.name)
-    expected = model.predict(data)
-    assert runner(onnx_model.graph.name, onnx_model, data, expected)
-
-
 def test_conv2d_padding_same(conv2_runner):
     conv2_runner(3, 5, (2, 2), (1, 1), (5, 5), padding='same')
     conv2_runner(8, 16, (1, 1), (2, 2), (60, 60), padding='same')

tests/test_backend.py

Lines changed: 20 additions & 0 deletions

@@ -740,6 +740,26 @@ def func(x):
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
                             graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))

+    @check_tf_min_version("1.15")
+    def test_conv2d_biasadd_rewriter(self):
+        x_shape = [2, 3, 32, 16]
+        x_val = make_xval(x_shape)
+        def func(x):
+            middles = tf.keras.layers.ZeroPadding2D(padding=(0, 4),
+                                                    data_format="channels_first",
+                                                    name="padding")(x)
+            t = tf.keras.layers.Conv2D(
+                filters=768,
+                kernel_size=3,
+                strides=1,
+                use_bias=True,
+                data_format="channels_first",
+                name="conv2d",
+            )(middles)
+            return tf.identity(t, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True, constant_fold=False,
+                            graph_validator=lambda g: check_op_count(g, "Add", 0, disabled=False))
+
     @check_tf_min_version("1.15")
     def test_conv2d_dilations_rewriter(self):
         x_shape = [2, 32, 16, 3]
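
For reference, a rough standalone sketch (not part of this diff) of checking the same fusion through tf2onnx's public API: it builds the same ZeroPadding2D → Conv2D(use_bias=True) pattern as a Keras model, converts it, and counts leftover Add nodes, mirroring the graph_validator above. The opset value and layer/variable names are illustrative.

import tensorflow as tf
import tf2onnx

inp = tf.keras.Input(shape=(3, 32, 16), name="x")
mid = tf.keras.layers.ZeroPadding2D(padding=(0, 4), data_format="channels_first")(inp)
out = tf.keras.layers.Conv2D(filters=768, kernel_size=3, strides=1, use_bias=True,
                             data_format="channels_first")(mid)
model = tf.keras.Model(inputs=inp, outputs=out)

# If the BiasAdd rewriter kicks in, the bias is folded into the Conv node and
# no standalone Add should remain in the exported graph.
model_proto, _ = tf2onnx.convert.from_keras(model, opset=13)
add_nodes = [n for n in model_proto.graph.node if n.op_type == "Add"]
print("standalone Add nodes:", len(add_nodes))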

tf2onnx/convert.py

Lines changed: 4 additions & 0 deletions

@@ -373,6 +373,10 @@ def _from_keras_tf1(model, input_signature=None, opset=None, custom_ops=None, cu

     with tf.device("/cpu:0"):
         frozen_graph, initialized_tables = tf_loader.freeze_session(sess, input_names, output_names, get_tables=True)
+        tf_loader.tf_reset_default_graph()
+        tf.import_graph_def(frozen_graph, name='')
+        sess = tf.keras.backend.get_session(model.outputs)
+        frozen_graph = tf_loader.tf_optimize(input_names, output_names, frozen_graph)
     model_proto, external_tensor_storage = _convert_common(
         frozen_graph,
         name=model.name,
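
With the added lines, the frozen TF1 Keras graph is re-imported and passed through tf_loader.tf_optimize before reaching _convert_common, which is the optimization the commit message says was missing from the tf1 Keras path. A minimal usage sketch, assuming TensorFlow 1.15 so that from_keras takes the _from_keras_tf1 path; the toy model and opset are illustrative:

import tensorflow as tf
import tf2onnx

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(1),
])

# On TF 1.x, from_keras freezes the Keras session graph; with this change the
# frozen graph is also grappler-optimized before the ONNX conversion runs.
model_proto, _ = tf2onnx.convert.from_keras(model, opset=13)
with open("model.onnx", "wb") as f:
    f.write(model_proto.SerializeToString())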

tf2onnx/tf_loader.py

Lines changed: 1 addition & 1 deletion

@@ -681,7 +681,7 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=None):
     rewrite_options = config.graph_options.rewrite_options
     config.graph_options.infer_shapes = True
     # TODO: if we turn on pruning, grappler removes some identities that the tf-1.x lstm rewriter
-    # depends on so for now don't turn this on.
+    # depends on so for now don't turn this on, fold_constant is always enabled now.
     rewrite_options.optimizers[:] = [
         # 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
         'constfold', 'function'
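
For orientation, a rough sketch of how a grappler pass with only the 'constfold' and 'function' optimizers enabled is typically driven over a frozen GraphDef; the helper name and exact shape are illustrative, not the actual tf_optimize_grappler body.

import tensorflow as tf
from tensorflow.core.protobuf import config_pb2, meta_graph_pb2
from tensorflow.python.grappler import tf_optimizer

def run_grappler(graph_def, input_names, output_names):
    config = config_pb2.ConfigProto()
    config.graph_options.infer_shapes = True
    rewrite_options = config.graph_options.rewrite_options
    # pruning stays off so identities the tf-1.x lstm rewriter depends on survive
    rewrite_options.optimizers[:] = ['constfold', 'function']

    # grappler operates on a MetaGraphDef and only keeps nodes reachable from
    # the fetch collection, so list the inputs and outputs explicitly
    meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
    fetch = meta_graph_pb2.CollectionDef()
    for name in input_names + output_names:
        fetch.node_list.value.append(name)
    meta_graph.collection_def["train_op"].CopyFrom(fetch)

    return tf_optimizer.OptimizeGraph(config, meta_graph)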
