Skip to content

Commit 256dd5c

Browse files
authored
Support for more TensorScatter* operations (#2179)
* Support for more TensorScatter* operations Added support for TensorScatterMax, TensorScatterMin and TensorScatterSub. Added tests for all TensorScatter* operations. Signed-off-by: Javier Dehesa <[email protected]>
1 parent 14f4a5f commit 256dd5c

File tree

3 files changed

+79
-0
lines changed

3 files changed

+79
-0
lines changed

support_status.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,9 @@
254254
| TensorListSetItem | 7 ~ 17 |
255255
| TensorListStack | 7 ~ 17 |
256256
| TensorScatterAdd | 16 ~ 17 |
257+
| TensorScatterMax | 16 ~ 17 |
258+
| TensorScatterMin | 16 ~ 17 |
259+
| TensorScatterSub | 16 ~ 17 |
257260
| TensorScatterUpdate | 11 ~ 17 |
258261
| Tile | 1 ~ 17 |
259262
| TopKV2 | 1 ~ 17 |

tests/test_backend.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6102,5 +6102,50 @@ def func(x):
61026102
x_val = make_xval([2, 3])
61036103
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
61046104

6105+
@check_tf_min_version("2.3.0")
6106+
@check_opset_min_version(16, "ScatterND")
6107+
@skip_tfjs("not supported in tfjs")
6108+
def test_tensor_scatter_max(self):
6109+
def func(tensor, indices, updates):
6110+
op = tf.tensor_scatter_nd_max(tensor, indices, updates)
6111+
return tf.identity(op, name=_TFOUTPUT)
6112+
6113+
tensor_val = make_xval([3, 4, 5])
6114+
indices_val = np.array([[2, 3], [0, 1]], np.int32)
6115+
indices64_val = indices_val.astype(np.int64)
6116+
updates_val = make_xval([2, 5]) + 3
6117+
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices_val, _INPUT2: updates_val})
6118+
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices64_val, _INPUT2: updates_val})
6119+
6120+
@check_tf_min_version("2.3.0")
6121+
@check_opset_min_version(16, "ScatterND")
6122+
@skip_tfjs("not supported in tfjs")
6123+
def test_tensor_scatter_min(self):
6124+
def func(tensor, indices, updates):
6125+
op = tf.tensor_scatter_nd_min(tensor, indices, updates)
6126+
return tf.identity(op, name=_TFOUTPUT)
6127+
6128+
tensor_val = make_xval([3, 4, 5])
6129+
indices_val = np.array([[2, 3], [0, 1]], np.int32)
6130+
indices64_val = indices_val.astype(np.int64)
6131+
updates_val = make_xval([2, 5]) + 3
6132+
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices_val, _INPUT2: updates_val})
6133+
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices64_val, _INPUT2: updates_val})
6134+
6135+
@check_tf_min_version("1.12.1")
6136+
@check_opset_min_version(16, "ScatterND")
6137+
@skip_tfjs("not supported in tfjs")
6138+
def test_tensor_scatter_sub(self):
6139+
def func(tensor, indices, updates):
6140+
op = tf.tensor_scatter_nd_sub(tensor, indices, updates)
6141+
return tf.identity(op, name=_TFOUTPUT)
6142+
6143+
tensor_val = make_xval([3, 4, 5])
6144+
indices_val = np.array([[2, 3], [0, 1]], np.int32)
6145+
indices64_val = indices_val.astype(np.int64)
6146+
updates_val = make_xval([2, 5]) + 3
6147+
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices_val, _INPUT2: updates_val})
6148+
self._run_test_case(func, [_OUTPUT], {_INPUT: tensor_val, _INPUT1: indices64_val, _INPUT2: updates_val})
6149+
61056150
if __name__ == '__main__':
61066151
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,37 @@ def version_16(cls, ctx, node, **kwargs):
670670
node.set_attr("reduction", 'add')
671671

672672

673+
@tf_op("TensorScatterMax", onnx_op="ScatterND")
674+
class TensorScatterMax:
675+
@classmethod
676+
def version_16(cls, ctx, node, **kwargs):
677+
# indices input must be int64 in ONNX.
678+
if ctx.get_dtype(node.input[1]) != TensorProto.INT64:
679+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
680+
node.set_attr("reduction", 'max')
681+
682+
683+
@tf_op("TensorScatterMin", onnx_op="ScatterND")
684+
class TensorScatterMin:
685+
@classmethod
686+
def version_16(cls, ctx, node, **kwargs):
687+
# indices input must be int64 in ONNX.
688+
if ctx.get_dtype(node.input[1]) != TensorProto.INT64:
689+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
690+
node.set_attr("reduction", 'min')
691+
692+
693+
@tf_op("TensorScatterSub", onnx_op="ScatterND")
694+
class TensorScatterSub:
695+
@classmethod
696+
def version_16(cls, ctx, node, **kwargs):
697+
# indices input must be int64 in ONNX.
698+
if ctx.get_dtype(node.input[1]) != TensorProto.INT64:
699+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
700+
ctx.insert_new_node_on_input(node, "Neg", node.input[2])
701+
node.set_attr("reduction", 'add')
702+
703+
673704
@tf_op("TensorScatterUpdate", onnx_op="ScatterND")
674705
class TensorScatterUpdate:
675706
@classmethod

0 commit comments

Comments
 (0)