Skip to content

Commit 621b4b1

Browse files
committed
fix Conv2D Bias Add fuse
Signed-off-by: hwangdeyu <[email protected]>
1 parent 72d6460 commit 621b4b1

File tree

3 files changed

+57
-26
lines changed

3 files changed

+57
-26
lines changed

tests/keras2onnx_unit_tests/test_layers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,28 @@ def test_conv2d_transpose_2(runner, padding):
11981198
assert runner(onnx_model.graph.name, onnx_model, data, expected)
11991199

12001200

1201+
def test_conv2d_bias_add(runner):
1202+
size = 128
1203+
input_img = Input((size, size, 64))
1204+
x = tf.keras.layers.ZeroPadding2D(
1205+
padding=(0, 4),
1206+
data_format="channels_first",
1207+
name="padding")(input_img)
1208+
y = tf.keras.layers.Conv2D(
1209+
filters=1,
1210+
kernel_size=(9, 9),
1211+
strides=(1, 1),
1212+
use_bias=True,
1213+
data_format="channels_first",
1214+
name="conv2d",
1215+
)(x)
1216+
model = tf.keras.Model(inputs=input_img, outputs=y)
1217+
data = np.random.rand(1, size, size, 64).astype(np.float32)
1218+
onnx_model = convert_keras(model, model.name)
1219+
expected = model.predict(data)
1220+
assert runner(onnx_model.graph.name, onnx_model, data, expected)
1221+
1222+
12011223
def test_conv2d_padding_same(conv2_runner):
12021224
conv2_runner(3, 5, (2, 2), (1, 1), (5, 5), padding='same')
12031225
conv2_runner(8, 16, (1, 1), (2, 2), (60, 60), padding='same')

tf2onnx/rewriter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
__all__ = [
3030
"rewrite_cond",
31-
"rewrite_conv2d_with_pad",
3231
"rewrite_dropout",
3332
"rewrite_eye",
3433
"rewrite_flatten",
@@ -49,6 +48,7 @@
4948
"rewrite_quantize_and_dequantize",
5049
"rewrite_layer_normalization",
5150
"rewrite_conv_dilations",
51+
"rewrite_conv2d_with_pad",
5252
"rewrite_ragged_variant_shape",
5353
"rewriter_lstm_tf2",
5454
"rewrite_gru_tf2",

tf2onnx/rewriter/conv2d_with_add_rewriter.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,40 @@
1313
# pylint: disable=missing-docstring
1414

1515
def rewrite_biasadd_with_conv2d(g, ops):
16-
pattern = \
16+
pattern1 = \
1717
OpTypePattern('BiasAdd', name='biasadd', inputs=[
1818
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
19-
matcher = GraphMatcher(pattern)
20-
match_results = list(matcher.match_ops(ops))
21-
for match in match_results:
22-
biasadd = match.get_op('biasadd')
23-
conv = match.get_op('conv')
24-
25-
#backup the conv and biasadd values
26-
conv_type = conv.type
27-
conv_input = conv.input
28-
conv_attr = conv.attr
29-
dtype = g.get_dtype(conv.output[0])
30-
shape = g.get_shape(conv.output[0])
31-
conv_name = biasadd.name
32-
conv_output = biasadd.output
33-
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]
34-
35-
if len(g.find_output_consumers(conv.output[0])) > 1:
36-
continue
37-
# Remove the Conv and BiasAdd node
38-
g.remove_node(conv.name)
39-
g.remove_node(biasadd.name)
40-
41-
g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
42-
shapes=[shape], dtypes=[dtype], skip_conversion=False)
19+
pattern2 = \
20+
OpTypePattern('BiasAdd', name='biasadd', inputs=[
21+
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=[
22+
'*', '*', '*']), '*'], allow_reorder=True)
23+
24+
for pattern in [pattern1, pattern2]:
25+
matcher = GraphMatcher(pattern)
26+
match_results = list(matcher.match_ops(ops))
27+
for match in match_results:
28+
biasadd = match.get_op('biasadd')
29+
conv = match.get_op('conv')
30+
31+
# Backup the conv and biasadd values
32+
conv_type = conv.type
33+
conv_input = conv.input
34+
conv_attr = conv.attr
35+
dtype = g.get_dtype(conv.output[0])
36+
shape = g.get_shape(conv.output[0])
37+
conv_name = biasadd.name
38+
conv_output = biasadd.output
39+
if pattern == pattern2:
40+
conv_inputs = [conv_input[0], conv_input[1], conv_input[2], biasadd.input[1]]
41+
else:
42+
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]
43+
44+
if len(g.find_output_consumers(conv.output[0])) > 1:
45+
continue
46+
# Remove the Conv and BiasAdd node
47+
g.remove_node(conv.name)
48+
g.remove_node(biasadd.name)
49+
50+
g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
51+
shapes=[shape], dtypes=[dtype], skip_conversion=False)
4352
return ops

0 commit comments

Comments
 (0)