Skip to content

Commit 6b8411b

Browse files
committed
fix nightly CI failed (#1784)
* Fix onnxruntime-nightly-unittest-matrix CI failed. Co-authored-by: hwangdeyu [email protected] Signed-off-by: hwangdeyu <[email protected]>
1 parent 65aaa2c commit 6b8411b

File tree

2 files changed

+34
-26
lines changed

2 files changed

+34
-26
lines changed

tests/test_backend.py

100644100755
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2830,7 +2830,7 @@ def func(x):
28302830
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)
28312831

28322832
@check_opset_min_version(7, "batchnorm")
2833-
@check_tf_min_version("2.0", "tf-1.x does not support NDHWC")
2833+
@check_tf_min_version("2.4", "tf version above 2.4 supports NDHWC")
28342834
def test_fused_batchnorm_3d(self):
28352835
x_shape = [1, 28, 28, 2, 2]
28362836
x_dtype = np.float32

tf2onnx/rewriter/conv2d_with_add_rewriter.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,39 @@
1313
# pylint: disable=missing-docstring
1414

1515
def rewrite_biasadd_with_conv2d(g, ops):
16-
pattern = \
16+
pattern1 = \
1717
OpTypePattern('BiasAdd', name='biasadd', inputs=[
1818
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
19-
matcher = GraphMatcher(pattern)
20-
match_results = list(matcher.match_ops(ops))
21-
for match in match_results:
22-
biasadd = match.get_op('biasadd')
23-
conv = match.get_op('conv')
24-
25-
#backup the conv and biasadd values
26-
conv_type = conv.type
27-
conv_input = conv.input
28-
conv_attr = conv.attr
29-
dtype = g.get_dtype(conv.output[0])
30-
shape = g.get_shape(conv.output[0])
31-
conv_name = biasadd.name
32-
conv_output = biasadd.output
33-
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]
34-
35-
if len(g.find_output_consumers(conv.output[0])) > 1:
36-
continue
37-
# Remove the Conv and BiasAdd node
38-
g.remove_node(conv.name)
39-
g.remove_node(biasadd.name)
40-
41-
g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
42-
shapes=[shape], dtypes=[dtype], skip_conversion=False)
19+
pattern2 = \
20+
OpTypePattern('BiasAdd', name='biasadd', inputs=[
21+
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*', '*']), '*'], allow_reorder = True)
22+
23+
for pattern in [pattern1, pattern2]:
24+
matcher = GraphMatcher(pattern)
25+
match_results = list(matcher.match_ops(ops))
26+
for match in match_results:
27+
biasadd = match.get_op('biasadd')
28+
conv = match.get_op('conv')
29+
30+
#backup the conv and biasadd values
31+
conv_type = conv.type
32+
conv_input = conv.input
33+
conv_attr = conv.attr
34+
dtype = g.get_dtype(conv.output[0])
35+
shape = g.get_shape(conv.output[0])
36+
conv_name = biasadd.name
37+
conv_output = biasadd.output
38+
if pattern == pattern2:
39+
conv_inputs = [conv_input[0], conv_input[1], conv_input[2], biasadd.input[1]]
40+
else:
41+
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]
42+
43+
if len(g.find_output_consumers(conv.output[0])) > 1:
44+
continue
45+
# Remove the Conv and BiasAdd node
46+
g.remove_node(conv.name)
47+
g.remove_node(biasadd.name)
48+
49+
g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
50+
shapes=[shape], dtypes=[dtype], skip_conversion=False)
4351
return ops

0 commit comments

Comments
 (0)