@@ -1369,6 +1369,130 @@ def test_transpose_argmax(self):
1369
1369
self .run_transpose_compare (["res" ], {"X" : np .random .randn (* input_shape ).astype (np .float32 )},
1370
1370
model_proto , remaining_transpose_num = 0 )
1371
1371
1372
+ @check_opset_max_version (
1373
+ 12 , "Before opset 13, Softmax coerced its inputs to 2D and can thus only be optimized for certain permutations"
1374
+ )
1375
+ def test_transpose_softmax_valid_perm (self ):
1376
+ input_shape = [4 , 4 , 4 , 4 ]
1377
+ node0 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
1378
+ node1 = helper .make_node ("Softmax" , ["Y" ], ["Z" ], axis = 1 , name = "softmax" )
1379
+ node2 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_2" )
1380
+
1381
+ graph = helper .make_graph (
1382
+ [node0 , node1 , node2 ],
1383
+ "transpose-softmax-test" ,
1384
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape )],
1385
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , input_shape )],
1386
+ )
1387
+
1388
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
1389
+ self .run_transpose_compare (
1390
+ ["res" ], {"X" : np .random .randn (* input_shape ).astype (np .float32 )}, model_proto , remaining_transpose_num = 0
1391
+ )
1392
+
1393
+ @check_opset_max_version (
1394
+ 12 , "Before opset 13, Softmax coerced its inputs to 2D and can thus only be optimized for certain permutations"
1395
+ )
1396
+ def test_transpose_softmax_invalid_perm (self ):
1397
+ input_shape = [4 , 4 , 4 , 4 ]
1398
+ node0 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
1399
+ node1 = helper .make_node ("Softmax" , ["Y" ], ["Z" ], axis = 3 , name = "softmax" )
1400
+ node2 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_2" )
1401
+
1402
+ graph = helper .make_graph (
1403
+ [node0 , node1 , node2 ],
1404
+ "transpose-softmax-test" ,
1405
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape )],
1406
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , input_shape )],
1407
+ )
1408
+
1409
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
1410
+ self .run_transpose_compare (
1411
+ ["res" ], {"X" : np .random .randn (* input_shape ).astype (np .float32 )}, model_proto , remaining_transpose_num = 2
1412
+ )
1413
+
1414
+ @check_opset_min_version (13 , "Softmax can be optimized for all permutations since opset 13" )
1415
+ def test_transpose_softmax_13 (self ):
1416
+ input_shape = [4 , 4 , 4 , 4 ]
1417
+ node0 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
1418
+ node1 = helper .make_node ("Softmax" , ["Y" ], ["Z" ], axis = 3 , name = "softmax" )
1419
+ node2 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_2" )
1420
+
1421
+ graph = helper .make_graph (
1422
+ [node0 , node1 , node2 ],
1423
+ "transpose-softmax-test" ,
1424
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape )],
1425
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , input_shape )],
1426
+ )
1427
+
1428
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
1429
+ self .run_transpose_compare (
1430
+ ["res" ], {"X" : np .random .randn (* input_shape ).astype (np .float32 )}, model_proto , remaining_transpose_num = 0
1431
+ )
1432
+
1433
+ @check_opset_max_version (
1434
+ 12 ,
1435
+ "Before opset 13, LogSoftmax coerced its inputs to 2D and can thus only be optimized for certain permutations" ,
1436
+ )
1437
+ def test_transpose_logsoftmax_valid_perm (self ):
1438
+ input_shape = [4 , 4 , 4 , 4 ]
1439
+ node0 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
1440
+ node1 = helper .make_node ("LogSoftmax" , ["Y" ], ["Z" ], axis = 1 , name = "logsoftmax" )
1441
+ node2 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_2" )
1442
+
1443
+ graph = helper .make_graph (
1444
+ [node0 , node1 , node2 ],
1445
+ "transpose-logsoftmax-test" ,
1446
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape )],
1447
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , input_shape )],
1448
+ )
1449
+
1450
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
1451
+ self .run_transpose_compare (
1452
+ ["res" ], {"X" : np .random .randn (* input_shape ).astype (np .float32 )}, model_proto , remaining_transpose_num = 0
1453
+ )
1454
+
1455
+ @check_opset_max_version (
1456
+ 12 ,
1457
+ "Before opset 13, LogSoftmax coerced its inputs to 2D and can thus only be optimized for certain permutations" ,
1458
+ )
1459
+ def test_transpose_logsoftmax_invalid_perm (self ):
1460
+ input_shape = [4 , 4 , 4 , 4 ]
1461
+ node0 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
1462
+ node1 = helper .make_node ("LogSoftmax" , ["Y" ], ["Z" ], axis = 3 , name = "logsoftmax" )
1463
+ node2 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_2" )
1464
+
1465
+ graph = helper .make_graph (
1466
+ [node0 , node1 , node2 ],
1467
+ "transpose-logsoftmax-test" ,
1468
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape )],
1469
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , input_shape )],
1470
+ )
1471
+
1472
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
1473
+ self .run_transpose_compare (
1474
+ ["res" ], {"X" : np .random .randn (* input_shape ).astype (np .float32 )}, model_proto , remaining_transpose_num = 2
1475
+ )
1476
+
1477
+ @check_opset_min_version (13 , "LogSoftmax can be optimized for all permutations since opset 13" )
1478
+ def test_transpose_logsoftmax_13 (self ):
1479
+ input_shape = [4 , 4 , 4 , 4 ]
1480
+ node0 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
1481
+ node1 = helper .make_node ("LogSoftmax" , ["Y" ], ["Z" ], axis = 3 , name = "logsoftmax" )
1482
+ node2 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_2" )
1483
+
1484
+ graph = helper .make_graph (
1485
+ [node0 , node1 , node2 ],
1486
+ "transpose-logsoftmax-test" ,
1487
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape )],
1488
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , input_shape )],
1489
+ )
1490
+
1491
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
1492
+ self .run_transpose_compare (
1493
+ ["res" ], {"X" : np .random .randn (* input_shape ).astype (np .float32 )}, model_proto , remaining_transpose_num = 0
1494
+ )
1495
+
1372
1496
def test_transpose_tile (self ):
1373
1497
input_shape = [1 , 2 , 3 , 4 ]
1374
1498
0 commit comments