
Commit 14f4a5f

Fix _mul_handler in TransposeOptimizer (#2088) (#2152)
Do not fuse Conv and Mul if the constant input of the Mul does not have shape [N] or [1, ..., 1, N]. Signed-off-by: cosine <[email protected]>
1 parent 26bdcff commit 14f4a5f
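
Why the shape restriction matters: a per-output-channel multiplier can be folded into the Conv kernel, but a multiplier that varies over batch or spatial positions cannot. The following is a minimal, illustrative NumPy sketch (not part of the commit); it models the conv as a 1x1 kernel applied to an NHWC tensor via a matmul over the channel axis.

import numpy as np

# Mul(Conv(x, W), c) == Conv(x, W * c) only when c carries one scale per output channel,
# i.e. its shape is (N,) or (1, ..., 1, N) after the Transpose to NHWC.
x = np.random.randn(1, 5, 5, 2).astype(np.float32)   # NHWC input
w = np.random.randn(2, 3).astype(np.float32)         # C_in x C_out kernel of a 1x1 conv
c = np.random.randn(1, 1, 1, 3).astype(np.float32)   # broadcasts over N, H, W only

unfused = (x @ w) * c                  # Conv followed by Mul
fused = x @ (w * c.reshape(1, 3))      # Mul folded into the weights
assert np.allclose(unfused, fused, atol=1e-5)

# A constant whose batch or spatial dimensions are larger than 1 may still broadcast
# against the conv output, but the scale then differs per position and cannot be
# absorbed into W, so the optimizer must not fuse it.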

File tree

2 files changed (+38 lines, -0 lines)


tests/test_optimizers.py

Lines changed: 35 additions & 0 deletions
@@ -1186,6 +1186,41 @@ def test_transpose_add_with_conv_2(self, input_shape, weights_shape, output_shap
                                              "W": np.random.randn(*weights_shape).astype(np.float32)},
                                    model_proto, remaining_transpose_num=0)

+    @parameterized.expand([
+        ((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), (1, 3, 1, 1), [0, 2, 3, 1], [0, 3, 1, 2]),
+        ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1, 1, 3, 3, 3), (1, 3, 1, 1, 1), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
+        ((1, 1, 5, 5), (1, 1, 3, 3), (3, 3, 3, 3), (3, 1, 3, 3), [0, 2, 3, 1], [0, 3, 1, 2]),
+        ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (3, 3, 3, 3, 3), (3, 3, 1, 3, 3), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
+        ((1, 1, 5, 5), (1, 1, 3, 3), (1, 3, 3, 3), (1, 1, 3, 3), [0, 2, 3, 1], [0, 3, 1, 2]),
+        ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1, 3, 3, 3, 3), (1, 3, 1, 3, 3), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
+    ])
+    def test_transpose_mul_with_conv(self, input_shape, weights_shape, output_shape,
+                                     const_shape, perm_input, perm_output):
+        const_b_val = np.random.randn(*const_shape).astype(np.float32)
+        const_b = helper.make_tensor("const_b", TensorProto.FLOAT, const_shape, const_b_val.flatten())
+        const_b_node = helper.make_node("Constant", [], ["const_b"], value=const_b, name="const_b")
+
+        const_w_val = np.random.randn(*weights_shape).astype(np.float32)
+        const_w = helper.make_tensor("const_w", TensorProto.FLOAT, weights_shape, const_w_val.flatten())
+        const_w_node = helper.make_node("Constant", [], ["const_w"], value=const_w, name="const_w")
+
+        node0 = helper.make_node("Conv", ["x", "const_w"], ["X"], name="conv", pads=[0] * 2 * (len(input_shape) - 2))
+        node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
+        node2 = helper.make_node("Mul", ["Y", "const_b"], ["Z"], name="mul")
+        node3 = helper.make_node("Transpose", ["Z"], ["res"], perm=perm_output, name="trans_2")
+
+        graph = helper.make_graph(
+            [const_b_node, const_w_node, node0, node1, node2, node3],
+            "transpose-mul-test-with-conv",
+            [helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, output_shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        remaining_transpose_num = 1 if const_shape[0] != 1 or any(shape != 1 for shape in const_shape[2:]) else 0
+        self.run_transpose_compare(["res"], {"x": np.random.randn(*input_shape).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=remaining_transpose_num)
+
     @parameterized.expand([
         ((2, 3, 4), [2, 0, 1], [1, 2, 0]),
         ((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 3 additions & 0 deletions
@@ -573,6 +573,9 @@ def _mul_handler(self, trans, node):
             # make sure conv don't have bias set
             can_opt = t_p.type == "Conv" and t_p.inputs[1].is_const() and len(t_p.input) == 2 and trans_rank == 4
             can_opt = can_opt and self._nodes_has_single_consumer_node([t_p])
+            # make sure multiplier with shape (N,) or (1, N) or (1, 1, N) ....
+            can_opt = can_opt and trans.get_attr_value("perm") == NCHW_TO_NHWC \
+                and all(shape == 1 for shape in multiplier.shape[:-1])
             if can_opt:
                 conv = t_p
                 numpy_val = conv.inputs[1].get_tensor_value(as_list=False)
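
Stated on its own, the added guard keeps the Conv/Mul fusion enabled only when the Transpose perm is NCHW_TO_NHWC and every dimension of the multiplier except the last equals 1. A hypothetical standalone sketch of that shape predicate (illustration only, not code from the repository):

# Hypothetical helper name; the real check lives inline in _mul_handler.
def multiplier_shape_allows_fusion(shape):
    # One scale per output channel (the last NHWC axis); all leading dims must be 1.
    return all(dim == 1 for dim in shape[:-1])

assert multiplier_shape_allows_fusion((3,))
assert multiplier_shape_allows_fusion((1, 1, 1, 3))
assert not multiplier_shape_allows_fusion((3, 1, 3, 3))   # varies over other axes: keep the Transpose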
