Skip to content

Commit 5cd3b5b

Browse files
hwangdeyufatcat-z
andauthored
add unit32 unit64 type support (#1808)
fixes #1802 Signed-off-by: hwangdeyu <[email protected]> Co-authored-by: fatcat-z <[email protected]>
1 parent 6691850 commit 5cd3b5b

File tree

3 files changed

+12
-0
lines changed

3 files changed

+12
-0
lines changed

tests/test_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2353,6 +2353,12 @@ def func(x):
23532353
return tf.identity(x_, name=_TFOUTPUT)
23542354
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
23552355

2356+
x_val = np.array([1, 2, 3, 4], dtype=np.uint32).reshape((2, 2))
2357+
def func(x):
2358+
x_ = tf.cast(x, tf.uint64)
2359+
return tf.identity(x_, name=_TFOUTPUT)
2360+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
2361+
23562362
@check_opset_min_version(7, "sign")
23572363
def test_sign(self):
23582364
x_vals = [np.array([1.0, 2.0, 0.0, -1.0, 0.0, -2.0], dtype=np.float32).reshape((2, 3)),

tf2onnx/tf_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
types_pb2.DT_INT8: onnx_pb.TensorProto.INT8,
3535
types_pb2.DT_UINT8: onnx_pb.TensorProto.UINT8,
3636
types_pb2.DT_UINT16: onnx_pb.TensorProto.UINT16,
37+
types_pb2.DT_UINT32: onnx_pb.TensorProto.UINT32,
38+
types_pb2.DT_UINT64: onnx_pb.TensorProto.UINT64,
3739
types_pb2.DT_INT64: onnx_pb.TensorProto.INT64,
3840
types_pb2.DT_STRING: onnx_pb.TensorProto.STRING,
3941
types_pb2.DT_COMPLEX64: onnx_pb.TensorProto.COMPLEX64,

tf2onnx/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
onnx_pb.TensorProto.INT8: np.int8,
3939
onnx_pb.TensorProto.UINT8: np.uint8,
4040
onnx_pb.TensorProto.UINT16: np.uint16,
41+
onnx_pb.TensorProto.UINT32: np.uint32,
42+
onnx_pb.TensorProto.UINT64: np.uint64,
4143
onnx_pb.TensorProto.INT64: np.int64,
4244
onnx_pb.TensorProto.UINT64: np.uint64,
4345
onnx_pb.TensorProto.BOOL: np.bool,
@@ -58,6 +60,8 @@
5860
onnx_pb.TensorProto.INT8: "int8",
5961
onnx_pb.TensorProto.UINT8: "uint8",
6062
onnx_pb.TensorProto.UINT16: "uint16",
63+
onnx_pb.TensorProto.UINT32: "uint32",
64+
onnx_pb.TensorProto.UINT64: "uint64",
6165
onnx_pb.TensorProto.INT64: "int64",
6266
onnx_pb.TensorProto.STRING: "string",
6367
onnx_pb.TensorProto.BOOL: "bool",

0 commit comments

Comments
 (0)