Skip to content

Commit c983cf4

Browse files
Fix problem with adding more than one tf.newaxis at the same time
Signed-off-by: southfreebird <[email protected]> Co-authored-by: southfreebird <[email protected]> Co-authored-by: iolkhovsky <[email protected]> Signed-off-by: southfreebird <[email protected]>
1 parent 71105c1 commit c983cf4

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

tests/test_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5893,5 +5893,21 @@ def func(x):
58935893
x_val = make_xval([3, 4])
58945894
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
58955895

5896+
@check_opset_min_version(10, "Slice")
5897+
def test_addition_two_newaxis_simultaneously(self):
5898+
x_val = make_xval([2, 3])
5899+
def func(x):
5900+
op = x[..., tf.newaxis, tf.newaxis]
5901+
return tf.identity(op, name=_TFOUTPUT)
5902+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5903+
5904+
@check_opset_min_version(10, "Slice")
5905+
def test_addition_three_newaxis_simultaneously(self):
5906+
x_val = make_xval([2, 3])
5907+
def func(x):
5908+
op = x[..., tf.newaxis, tf.newaxis, tf.newaxis]
5909+
return tf.identity(op, name=_TFOUTPUT)
5910+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5911+
58965912
if __name__ == '__main__':
58975913
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,29 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
974974
begin_mask |= 1 << bit
975975
end_mask |= 1 << bit
976976

977+
if ellipsis_mask:
978+
unqueeze_at = []
979+
ellipsis_gap = 0
980+
num_new = 0
981+
end_mask = node.get_attr("end_mask")
982+
end_mask = end_mask.i if end_mask is not None else 0
983+
begin_mask = node.get_attr("begin_mask")
984+
begin_mask = begin_mask.i if begin_mask is not None else 0
985+
986+
for bit in range(32):
987+
new_axis_flag = (new_axis_mask >> bit) & 1
988+
ellipsis_flag = (ellipsis_mask >> bit) & 1
989+
num_new += not ellipsis_flag and new_axis_flag
990+
991+
for bit in range(32):
992+
if (ellipsis_mask >> bit) & 1:
993+
ellipsis_gap = len(ctx.get_shape(input_x)) - param_rank + num_new + 1
994+
elif (new_axis_mask >> bit) & 1:
995+
effective_bit = bit if not ellipsis_gap else bit + ellipsis_gap - 1
996+
unqueeze_at.append(effective_bit)
997+
begin_mask |= 1 << bit
998+
end_mask |= 1 << bit
999+
9771000
input_x = GraphBuilder(ctx).make_unsqueeze(
9781001
{'data': input_x, 'axes': unqueeze_at})
9791002

0 commit comments

Comments
 (0)