Skip to content

Commit 72d6460

Browse files
authored
Handle 'Neg' nodes in transpose optimzier (#1785)
Signed-off-by: Jan Haug <[email protected]>
1 parent d018aa6 commit 72d6460

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

tests/test_optimizers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,6 +1109,27 @@ def test_transpose_add_with_conv_2(self, input_shape, weights_shape, output_shap
11091109
"W": np.random.randn(*weights_shape).astype(np.float32)},
11101110
model_proto, remaining_transpose_num=0)
11111111

1112+
@parameterized.expand([
1113+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
1114+
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
1115+
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
1116+
])
1117+
def test_transpose_neg(self, shape, perm_input, perm_output):
1118+
node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans1")
1119+
node1 = helper.make_node("Neg", ["Y"], ["Z"], name="neg")
1120+
node2 = helper.make_node("Transpose", ["Z"], ["OUT"], perm=perm_output, name="trans2")
1121+
1122+
graph = helper.make_graph(
1123+
[node0, node1, node2],
1124+
"transpose-neg-test",
1125+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
1126+
[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, shape)],
1127+
)
1128+
1129+
model_proto = self.make_model(graph, producer_name="onnx-tests")
1130+
self.run_transpose_compare(["OUT"], {"X": np.random.randn(*shape).astype(np.float32)},
1131+
model_proto, remaining_transpose_num=0)
1132+
11121133
@parameterized.expand([
11131134
((3, 4, 5), (8, 4, 6), [1, 3, 0, 0, 2, 0], [2, 0, 1], [1, 2, 0]),
11141135
((1, 3, 4, 5), (2, 6, 4, 8), [1, 0, 1, 3, 0, 0, 2, 0], [0, 2, 3, 1], [0, 3, 1, 2]),

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def _initialize_handlers(self):
208208
"Max": self._maxmin_handler,
209209
"Min": self._maxmin_handler,
210210
"Mul": self._mul_handler,
211+
"Neg": self._simple_through_handler,
211212
"Pad": self._pad_handler,
212213
"PRelu": self._prelu_handler,
213214
"Reciprocal": self._simple_through_handler,

0 commit comments

Comments
 (0)