@@ -116,6 +116,32 @@ def test_transpose_with_concat(self, input_shape, perm, inner_perm):
116
116
}
117
117
self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 1 )
118
118
119
+ @parameterized .expand ([
120
+ ((2 , 3 , 4 , 5 ), [0 , 3 , 1 , 2 ], [0 , 2 , 3 , 1 ]),
121
+ ((2 , 3 , 4 , 5 , 6 ), [0 , 4 , 1 , 2 , 3 ], [0 , 2 , 3 , 4 , 1 ]),
122
+ ((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
123
+ ])
124
+ def test_transpose_with_split (self , input_shape , perm , inner_perm ):
125
+ input_shape_with_trans = [input_shape [i ] for i in perm ]
126
+ for axis in range (len (input_shape )):
127
+ output_before_trans = list (input_shape )
128
+ output_shape = [output_before_trans [i ] for i in perm ]
129
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = inner_perm , name = "trans1" )
130
+ node2 = helper .make_node ("Split" , ["Y" ], ["Z" ], axis = axis , name = "split" )
131
+ node3 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = perm , name = "trans2" )
132
+
133
+ graph = helper .make_graph (
134
+ [node1 , node2 , node3 ],
135
+ "test_transpose_with_split" ,
136
+ [helper .make_tensor_value_info ("X" , TensorProto .INT64 , input_shape_with_trans )],
137
+ [helper .make_tensor_value_info ("res" , TensorProto .INT64 , output_shape )],
138
+ )
139
+
140
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
141
+
142
+ feed_dict = {"X" : np .random .randn (* input_shape_with_trans ).astype (np .int64 )}
143
+ self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 0 )
144
+
119
145
@parameterized .expand ([
120
146
((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
121
147
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
0 commit comments