Skip to content

Commit 486ee01

Browse files
committed
Fix problem with adding more than one tf.newaxis at the same time
1 parent 71105c1 commit 486ee01

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

tests/test_backend.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5893,5 +5893,19 @@ def func(x):
58935893
x_val = make_xval([3, 4])
58945894
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
58955895

5896+
def test_addition_two_newaxis_simultaneously(self):
5897+
x_val = make_xval([2, 0])
5898+
def func(x):
5899+
op = x[..., tf.newaxis, tf.newaxis]
5900+
return tf.identity(op, name=_TFOUTPUT)
5901+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5902+
5903+
def test_addition_three_newaxis_simultaneously(self):
5904+
x_val = make_xval([2, 0])
5905+
def func(x):
5906+
op = x[..., tf.newaxis, tf.newaxis, tf.newaxis]
5907+
return tf.identity(op, name=_TFOUTPUT)
5908+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5909+
58965910
if __name__ == '__main__':
58975911
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,29 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
974974
begin_mask |= 1 << bit
975975
end_mask |= 1 << bit
976976

977+
if ellipsis_mask:
978+
unqueeze_at = []
979+
ellipsis_gap = 0
980+
num_new = 0
981+
end_mask = node.get_attr("end_mask")
982+
end_mask = end_mask.i if end_mask is not None else 0
983+
begin_mask = node.get_attr("begin_mask")
984+
begin_mask = begin_mask.i if begin_mask is not None else 0
985+
986+
for bit in range(32):
987+
new_axis_flag = (new_axis_mask >> bit) & 1
988+
ellipsis_flag = (ellipsis_mask >> bit) & 1
989+
num_new += not ellipsis_flag and new_axis_flag
990+
991+
for bit in range(32):
992+
if (ellipsis_mask >> bit) & 1:
993+
ellipsis_gap = len(ctx.get_shape(input_x)) - param_rank + num_new + 1
994+
elif (new_axis_mask >> bit) & 1:
995+
effective_bit = bit if not ellipsis_gap else bit + ellipsis_gap - 1
996+
unqueeze_at.append(effective_bit)
997+
begin_mask |= 1 << bit
998+
end_mask |= 1 << bit
999+
9771000
input_x = GraphBuilder(ctx).make_unsqueeze(
9781001
{'data': input_x, 'axes': unqueeze_at})
9791002

0 commit comments

Comments
 (0)