@@ -1186,6 +1186,41 @@ def test_transpose_add_with_conv_2(self, input_shape, weights_shape, output_shap
1186
1186
"W" : np .random .randn (* weights_shape ).astype (np .float32 )},
1187
1187
model_proto , remaining_transpose_num = 0 )
1188
1188
1189
+ @parameterized .expand ([
1190
+ ((1 , 1 , 5 , 5 ), (1 , 1 , 3 , 3 ), (1 , 1 , 3 , 3 ), (1 , 3 , 1 , 1 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1191
+ ((1 , 1 , 5 , 5 , 5 ), (1 , 1 , 3 , 3 , 3 ), (1 , 1 , 3 , 3 , 3 ), (1 , 3 , 1 , 1 , 1 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
1192
+ ((1 , 1 , 5 , 5 ), (1 , 1 , 3 , 3 ), (3 , 3 , 3 , 3 ), (3 , 1 , 3 , 3 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1193
+ ((1 , 1 , 5 , 5 , 5 ), (1 , 1 , 3 , 3 , 3 ), (3 , 3 , 3 , 3 , 3 ), (3 , 3 , 1 , 3 , 3 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
1194
+ ((1 , 1 , 5 , 5 ), (1 , 1 , 3 , 3 ), (1 , 3 , 3 , 3 ), (1 , 1 , 3 , 3 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1195
+ ((1 , 1 , 5 , 5 , 5 ), (1 , 1 , 3 , 3 , 3 ), (1 , 3 , 3 , 3 , 3 ), (1 , 3 , 1 , 3 , 3 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
1196
+ ])
1197
+ def test_transpose_mul_with_conv (self , input_shape , weights_shape , output_shape ,
1198
+ const_shape , perm_input , perm_output ):
1199
+ const_b_val = np .random .randn (* const_shape ).astype (np .float32 )
1200
+ const_b = helper .make_tensor ("const_b" , TensorProto .FLOAT , const_shape , const_b_val .flatten ())
1201
+ const_b_node = helper .make_node ("Constant" , [], ["const_b" ], value = const_b , name = "const_b" )
1202
+
1203
+ const_w_val = np .random .randn (* weights_shape ).astype (np .float32 )
1204
+ const_w = helper .make_tensor ("const_w" , TensorProto .FLOAT , weights_shape , const_w_val .flatten ())
1205
+ const_w_node = helper .make_node ("Constant" , [], ["const_w" ], value = const_w , name = "const_w" )
1206
+
1207
+ node0 = helper .make_node ("Conv" , ["x" , "const_w" ], ["X" ], name = "conv" , pads = [0 ] * 2 * (len (input_shape ) - 2 ))
1208
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = perm_input , name = "trans_1" )
1209
+ node2 = helper .make_node ("Mul" , ["Y" , "const_b" ], ["Z" ], name = "mul" )
1210
+ node3 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = perm_output , name = "trans_2" )
1211
+
1212
+ graph = helper .make_graph (
1213
+ [const_b_node , const_w_node , node0 , node1 , node2 , node3 ],
1214
+ "transpose-mul-test-with-conv" ,
1215
+ [helper .make_tensor_value_info ("x" , TensorProto .FLOAT , input_shape )],
1216
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , output_shape )],
1217
+ )
1218
+
1219
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
1220
+ remaining_transpose_num = 1 if const_shape [0 ] != 1 or any (shape != 1 for shape in const_shape [2 :]) else 0
1221
+ self .run_transpose_compare (["res" ], {"x" : np .random .randn (* input_shape ).astype (np .float32 )},
1222
+ model_proto , remaining_transpose_num = remaining_transpose_num )
1223
+
1189
1224
@parameterized .expand ([
1190
1225
((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
1191
1226
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
0 commit comments