@@ -6102,5 +6102,50 @@ def func(x):
6102
6102
x_val = make_xval ([2 , 3 ])
6103
6103
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
6104
6104
6105
+ @check_tf_min_version ("2.3.0" )
6106
+ @check_opset_min_version (16 , "ScatterND" )
6107
+ @skip_tfjs ("not supported in tfjs" )
6108
+ def test_tensor_scatter_max (self ):
6109
+ def func (tensor , indices , updates ):
6110
+ op = tf .tensor_scatter_nd_max (tensor , indices , updates )
6111
+ return tf .identity (op , name = _TFOUTPUT )
6112
+
6113
+ tensor_val = make_xval ([3 , 4 , 5 ])
6114
+ indices_val = np .array ([[2 , 3 ], [0 , 1 ]], np .int32 )
6115
+ indices64_val = indices_val .astype (np .int64 )
6116
+ updates_val = make_xval ([2 , 5 ]) + 3
6117
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : tensor_val , _INPUT1 : indices_val , _INPUT2 : updates_val })
6118
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : tensor_val , _INPUT1 : indices64_val , _INPUT2 : updates_val })
6119
+
6120
+ @check_tf_min_version ("2.3.0" )
6121
+ @check_opset_min_version (16 , "ScatterND" )
6122
+ @skip_tfjs ("not supported in tfjs" )
6123
+ def test_tensor_scatter_min (self ):
6124
+ def func (tensor , indices , updates ):
6125
+ op = tf .tensor_scatter_nd_min (tensor , indices , updates )
6126
+ return tf .identity (op , name = _TFOUTPUT )
6127
+
6128
+ tensor_val = make_xval ([3 , 4 , 5 ])
6129
+ indices_val = np .array ([[2 , 3 ], [0 , 1 ]], np .int32 )
6130
+ indices64_val = indices_val .astype (np .int64 )
6131
+ updates_val = make_xval ([2 , 5 ]) + 3
6132
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : tensor_val , _INPUT1 : indices_val , _INPUT2 : updates_val })
6133
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : tensor_val , _INPUT1 : indices64_val , _INPUT2 : updates_val })
6134
+
6135
+ @check_tf_min_version ("1.12.1" )
6136
+ @check_opset_min_version (16 , "ScatterND" )
6137
+ @skip_tfjs ("not supported in tfjs" )
6138
+ def test_tensor_scatter_sub (self ):
6139
+ def func (tensor , indices , updates ):
6140
+ op = tf .tensor_scatter_nd_sub (tensor , indices , updates )
6141
+ return tf .identity (op , name = _TFOUTPUT )
6142
+
6143
+ tensor_val = make_xval ([3 , 4 , 5 ])
6144
+ indices_val = np .array ([[2 , 3 ], [0 , 1 ]], np .int32 )
6145
+ indices64_val = indices_val .astype (np .int64 )
6146
+ updates_val = make_xval ([2 , 5 ]) + 3
6147
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : tensor_val , _INPUT1 : indices_val , _INPUT2 : updates_val })
6148
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : tensor_val , _INPUT1 : indices64_val , _INPUT2 : updates_val })
6149
+
6105
6150
if __name__ == '__main__' :
6106
6151
unittest_main ()
0 commit comments