@@ -133,15 +133,34 @@ def test_transpose_with_split(self, input_shape, perm, inner_perm):
133
133
graph = helper .make_graph (
134
134
[node1 , node2 , node3 ],
135
135
"test_transpose_with_split" ,
136
- [helper .make_tensor_value_info ("X" , TensorProto .INT64 , input_shape_with_trans )],
137
- [helper .make_tensor_value_info ("res" , TensorProto .INT64 , output_shape )],
136
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape_with_trans )],
137
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , output_shape )],
138
138
)
139
139
140
140
model_proto = self .make_model (graph , producer_name = "onnx-tests" )
141
-
142
- feed_dict = {"X" : np .random .randn (* input_shape_with_trans ).astype (np .int64 )}
141
+ feed_dict = {"X" : np .random .randn (* input_shape_with_trans ).astype (np .float32 )}
143
142
self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 0 )
144
143
144
+ @parameterized .expand ([
145
+ ((1 , - 1 ), (1 , 1710 ), (1710 ,), [1 , 0 ]),
146
+ ((3 , 1 , 1 , 5 , - 1 ), (3 , 1 , 1 , 5 , 6 ), (3 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ]),
147
+ ])
148
+ def test_transpose_with_split_dynamic_shape (self , input_shape , specific_input , output_shape , perm ):
149
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = perm , name = "trans" )
150
+ node2 = helper .make_node ("Split" , ["Y" ], ["Z" ], axis = 1 , split = [1 ], name = "split" )
151
+ node3 = helper .make_node ("Squeeze" , ["Z" ], ["B" ], name = "squeeze" )
152
+
153
+ graph = helper .make_graph (
154
+ [node1 , node2 , node3 ],
155
+ "test_transpose_with_split_dynamic_shape" ,
156
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape )],
157
+ [helper .make_tensor_value_info ("B" , TensorProto .FLOAT , output_shape )],
158
+ )
159
+
160
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
161
+ self .run_transpose_compare (["B" ], {"X" : np .random .randn (* specific_input ).astype (np .float32 )},
162
+ model_proto , remaining_transpose_num = 0 )
163
+
145
164
@parameterized .expand ([
146
165
((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
147
166
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
0 commit comments