Skip to content

Commit 0da3c78

Browse files
hwangdeyufatcat-z
andauthored
support Rint op (#1833)
Signed-off-by: hwangdeyu <[email protected]> Co-authored-by: fatcat-z <[email protected]>
1 parent 59cfa79 commit 0da3c78

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

tests/test_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4609,6 +4609,14 @@ def func(x):
46094609
return tf.identity(x_, name=_TFOUTPUT)
46104610
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
46114611

4612+
@check_opset_min_version(11, "Round")
4613+
def test_rint(self):
4614+
x_val = np.array([-2.7, -1.5, -0.0, +0.0, 0.3, 0.5, 1.5, 2.5, 3.4, 3.5, float('nan')], dtype=np.float32)
4615+
def func(x):
4616+
x_ = tf.math.rint(x)
4617+
return tf.identity(x_, name=_TFOUTPUT)
4618+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
4619+
46124620
@check_opset_min_version(11, "Det")
46134621
@unittest.skip("unclear how this is called in tf-2, fix later")
46144622
def test_determinant(self):

tf2onnx/onnx_opset/math.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,15 @@ def version_11(cls, ctx, node, **kwargs):
559559
pass
560560

561561

562+
@tf_op("Rint", onnx_op="Round")
563+
class Rint:
564+
@classmethod
565+
def version_11(cls, ctx, node, **kwargs):
566+
# Same with tf round, two different people just happened to write the function.
567+
# https://github.com/tensorflow/tensorflow/issues/709
568+
pass
569+
570+
562571
@tf_op("MatrixDeterminant", onnx_op="Det")
563572
class Det:
564573
@classmethod

0 commit comments

Comments
 (0)