
Commit 9cea907

Transpose optimization for Softmax and LogSoftmax (fixes #1716) (#1964)
In opsets 13 and higher, the axis of the operation is arbitrary and can simply be changed according to the permutation of the Transpose. In lower opsets, Softmax always coerces its inputs to a 2D tensor, making Transpose operations necessary if the permutation moves axes between the coerced batch and feature dimensions.

Signed-off-by: janbernloehr <[email protected]>
Co-authored-by: fthielke <[email protected]>
1 parent a8f78ac commit 9cea907
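
For the opset-13 path, a minimal numpy sketch (not part of this commit; `softmax` below is a hand-rolled stand-in for the ONNX op) shows why remapping the axis through the permutation is all the optimizer needs in order to pull the Transpose past the node:

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along a single axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4, 4, 4)).astype(np.float32)
perm, axis = [0, 2, 3, 1], 1

# Transpose(perm) -> Softmax(axis) computes the same tensor as
# Softmax(perm[axis]) -> Transpose(perm), which is the axis rewrite
# the handler applies for opset >= 13.
before = softmax(np.transpose(x, perm), axis=axis)
after = np.transpose(softmax(x, axis=perm[axis]), perm)
assert np.allclose(before, after, atol=1e-6)
```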

2 files changed: +148 −0 lines changed

tests/test_optimizers.py

Lines changed: 124 additions & 0 deletions
@@ -1369,6 +1369,130 @@ def test_transpose_argmax(self):
         self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
                                    model_proto, remaining_transpose_num=0)
 
+    @check_opset_max_version(
+        12, "Before opset 13, Softmax coerced its inputs to 2D and can thus only be optimized for certain permutations"
+    )
+    def test_transpose_softmax_valid_perm(self):
+        input_shape = [4, 4, 4, 4]
+        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
+        node1 = helper.make_node("Softmax", ["Y"], ["Z"], axis=1, name="softmax")
+        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "transpose-softmax-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(
+            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
+        )
+
+    @check_opset_max_version(
+        12, "Before opset 13, Softmax coerced its inputs to 2D and can thus only be optimized for certain permutations"
+    )
+    def test_transpose_softmax_invalid_perm(self):
+        input_shape = [4, 4, 4, 4]
+        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
+        node1 = helper.make_node("Softmax", ["Y"], ["Z"], axis=3, name="softmax")
+        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "transpose-softmax-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(
+            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=2
+        )
+
+    @check_opset_min_version(13, "Softmax can be optimized for all permutations since opset 13")
+    def test_transpose_softmax_13(self):
+        input_shape = [4, 4, 4, 4]
+        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
+        node1 = helper.make_node("Softmax", ["Y"], ["Z"], axis=3, name="softmax")
+        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "transpose-softmax-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(
+            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
+        )
+
+    @check_opset_max_version(
+        12,
+        "Before opset 13, LogSoftmax coerced its inputs to 2D and can thus only be optimized for certain permutations",
+    )
+    def test_transpose_logsoftmax_valid_perm(self):
+        input_shape = [4, 4, 4, 4]
+        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
+        node1 = helper.make_node("LogSoftmax", ["Y"], ["Z"], axis=1, name="logsoftmax")
+        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "transpose-logsoftmax-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(
+            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
+        )
+
+    @check_opset_max_version(
+        12,
+        "Before opset 13, LogSoftmax coerced its inputs to 2D and can thus only be optimized for certain permutations",
+    )
+    def test_transpose_logsoftmax_invalid_perm(self):
+        input_shape = [4, 4, 4, 4]
+        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
+        node1 = helper.make_node("LogSoftmax", ["Y"], ["Z"], axis=3, name="logsoftmax")
+        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "transpose-logsoftmax-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(
+            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=2
+        )
+
+    @check_opset_min_version(13, "LogSoftmax can be optimized for all permutations since opset 13")
+    def test_transpose_logsoftmax_13(self):
+        input_shape = [4, 4, 4, 4]
+        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
+        node1 = helper.make_node("LogSoftmax", ["Y"], ["Z"], axis=3, name="logsoftmax")
+        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "transpose-logsoftmax-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(
+            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
+        )
+
     def test_transpose_tile(self):
         input_shape = [1, 2, 3, 4]

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 24 additions & 0 deletions
@@ -205,6 +205,7 @@ def _initialize_handlers(self):
             "Identity": self._identity_handler,
             "LeakyRelu": self._simple_through_handler,
             "Log": self._simple_through_handler,
+            "LogSoftmax": self._softmax_handler,
             "Max": self._maxmin_handler,
             "Min": self._maxmin_handler,
             "Mul": self._mul_handler,
@@ -223,6 +224,7 @@ def _initialize_handlers(self):
             "Relu": self._simple_through_handler,
             "Shape": self._shape_handler,
             "Sigmoid": self._simple_through_handler,
+            "Softmax": self._softmax_handler,
             "Sum": self._sum_handler,
             "Slice": self._slice_handler,
             "Split": self._split_handler,
@@ -827,6 +829,28 @@ def permute_pads(pads):
     def _prelu_handler(self, trans, node):
         return self._handle_node_having_branches(trans, node)
 
+    def _softmax_handler(self, trans, node):
+        trans_rank = get_transpose_rank(trans)
+        perm = trans.get_attr("perm").ints
+
+        if self._g.opset >= 13:
+            # Softmax operates on an arbitrary axis since opset 13
+            axis = node.get_attr_value("axis", -1)
+            new_axis = perm[axis + trans_rank if axis < 0 else axis]
+            if not self._switch_transpose_and_node(node, trans):
+                return False
+            node.set_attr("axis", new_axis)
+            return True
+
+        # For older opsets, the "axis" attribute determines the coercion point for coercing the input tensor to 2D.
+        # We can safely switch transpose and node if the permutation does not make any axes cross that boundary.
+        coercion_axis = node.get_attr_value("axis", 1)
+        for from_axis, to_axis in enumerate(perm):
+            if (from_axis < coercion_axis <= to_axis) or (from_axis >= coercion_axis > to_axis):
+                return False
+
+        return self._switch_transpose_and_node(node, trans)
+
     def _arg_min_max_handler(self, trans, node):
         axis = node.get_attr_value("axis", 0)
         node.set_attr("axes", [axis])
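
The pre-opset-13 guard above can be read in isolation: an axis may move within the coerced batch block or within the coerced feature block, but never between them. A minimal standalone sketch (the helper name is ours, not tf2onnx API) reproducing the valid and invalid cases from the tests:

```python
def perm_crosses_coercion_boundary(perm, coercion_axis):
    # from_axis is the output position, to_axis the input axis that lands
    # there; the swap is unsafe if any axis hops across the 2D-coercion point.
    for from_axis, to_axis in enumerate(perm):
        if (from_axis < coercion_axis <= to_axis) or (from_axis >= coercion_axis > to_axis):
            return True
    return False

# Mirrors the tests: Softmax axis=1 under perm [0, 2, 3, 1] is optimizable,
# axis=3 is not (both Transpose nodes must stay in place).
assert not perm_crosses_coercion_boundary([0, 2, 3, 1], coercion_axis=1)
assert perm_crosses_coercion_boundary([0, 2, 3, 1], coercion_axis=3)
```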
