@@ -116,6 +116,52 @@ def test_transpose_with_concat(self, input_shape, perm, inner_perm):
116
116
}
117
117
self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 1 )
118
118
119
+ @parameterized .expand ([
120
+ ((2 , 3 , 4 , 5 ), [0 , 3 , 1 , 2 ], [0 , 2 , 3 , 1 ]),
121
+ ((2 , 3 , 4 , 5 , 6 ), [0 , 4 , 1 , 2 , 3 ], [0 , 2 , 3 , 4 , 1 ]),
122
+ ((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
123
+ ])
124
+ def test_transpose_with_split (self , input_shape , perm , inner_perm ):
125
+ input_shape_with_trans = [input_shape [i ] for i in perm ]
126
+ output_before_trans = list (input_shape )
127
+ output_shape = [output_before_trans [i ] for i in perm ]
128
+ for axis in range (len (input_shape )):
129
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = inner_perm , name = "trans1" )
130
+ node2 = helper .make_node ("Split" , ["Y" ], ["Z" ], axis = axis , name = "split" )
131
+ node3 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = perm , name = "trans2" )
132
+
133
+ graph = helper .make_graph (
134
+ [node1 , node2 , node3 ],
135
+ "test_transpose_with_split" ,
136
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape_with_trans )],
137
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , output_shape )],
138
+ )
139
+
140
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
141
+ feed_dict = {"X" : np .random .randn (* input_shape_with_trans ).astype (np .float32 )}
142
+ self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 0 )
143
+
144
+ @parameterized .expand ([
145
+ ((1 , - 1 ), (1 , 1710 ), (1710 ,), [1 , 0 ]),
146
+ ((3 , 1 , 1 , 5 , - 1 ), (3 , 1 , 1 , 5 , 6 ), (3 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ]),
147
+ ])
148
+ @check_opset_max_version (12 , "split attribute changed to input in opset 13" )
149
+ def test_transpose_with_split_dynamic_shape (self , input_shape , specific_input , output_shape , perm ):
150
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = perm , name = "trans" )
151
+ node2 = helper .make_node ("Split" , ["Y" ], ["Z" ], axis = 1 , split = [1 ], name = "split" )
152
+ node3 = helper .make_node ("Squeeze" , ["Z" ], ["B" ], name = "squeeze" )
153
+
154
+ graph = helper .make_graph (
155
+ [node1 , node2 , node3 ],
156
+ "test_transpose_with_split_dynamic_shape" ,
157
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape )],
158
+ [helper .make_tensor_value_info ("B" , TensorProto .FLOAT , output_shape )],
159
+ )
160
+
161
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
162
+ self .run_transpose_compare (["B" ], {"X" : np .random .randn (* specific_input ).astype (np .float32 )},
163
+ model_proto , remaining_transpose_num = 0 )
164
+
119
165
@parameterized .expand ([
120
166
((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
121
167
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
0 commit comments