From fcd10447abb20e50ed2d67e313c2f75566319649 Mon Sep 17 00:00:00 2001 From: lingbai-kong Date: Fri, 23 Jun 2023 13:39:36 +0800 Subject: [PATCH 1/3] add more type case for tensor.zeros --- src/TensorFlowNET.Core/Operations/array_ops.cs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index a0b47aace..24c392155 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -84,8 +84,13 @@ public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT // var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); Tensor zeros = dtype switch { + TF_DataType.TF_BOOL => constant(false), TF_DataType.TF_DOUBLE => constant(0d), TF_DataType.TF_FLOAT => constant(0f), + TF_DataType.TF_INT64 => constant(0L), + TF_DataType.TF_UINT64 => constant((ulong)0), + TF_DataType.TF_INT32 => constant(0), + TF_DataType.TF_UINT32 => constant((uint)0), TF_DataType.TF_INT8 => constant((sbyte)0), TF_DataType.TF_UINT8 => constant((byte)0), _ => constant(0) @@ -108,9 +113,15 @@ public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT return _constant_if_small(0.0F, shape, dtype, name); case TF_DataType.TF_INT64: return _constant_if_small(0L, shape, dtype, name); + case TF_DataType.TF_UINT64: + return _constant_if_small(0, shape, dtype, name); case TF_DataType.TF_INT32: return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_UINT32: + return _constant_if_small(0, shape, dtype, name); case TF_DataType.TF_INT8: + return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_UINT8: return _constant_if_small(0, shape, dtype, name); default: throw new TypeError("can't find type for zeros"); From e749aaeaae197464f817e1c7bfffe6f922d55b6a Mon Sep 17 00:00:00 2001 From: lingbai-kong Date: Fri, 23 Jun 2023 14:04:44 +0800 Subject: [PATCH 2/3] add more implicit operator for NDArray and UnitTest for `keras.datasets.imdb` --- src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs | 6 ++++++ .../TensorFlowNET.UnitTest/Dataset/DatasetTest.cs | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs index fd4f93fc1..45b236c7b 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs @@ -107,9 +107,15 @@ public unsafe static implicit operator double(NDArray nd) public static implicit operator NDArray(bool value) => new NDArray(value); + public static implicit operator NDArray(byte value) + => new NDArray(value); + public static implicit operator NDArray(int value) => new NDArray(value); + public static implicit operator NDArray(long value) + => new NDArray(value); + public static implicit operator NDArray(float value) => new NDArray(value); diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index 8317346ea..875e50019 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -2,6 +2,7 @@ using System; using System.Linq; using static Tensorflow.Binding; +using static Tensorflow.KerasApi; namespace TensorFlowNET.UnitTest.Dataset { @@ -195,5 +196,19 @@ public void Shuffle() Assert.IsFalse(allEqual); } + [TestMethod] + public void GetData() + { + var vocab_size = 20000; + var dataset = keras.datasets.imdb.load_data(num_words: vocab_size); + var x_train = dataset.Train.Item1; + Assert.AreEqual(x_train.dims[0], 25000); + var y_train = dataset.Train.Item2; + Assert.AreEqual(y_train.dims[0], 25000); + var x_val = dataset.Test.Item1; + Assert.AreEqual(x_val.dims[0], 25000); + var y_val = dataset.Test.Item2; + Assert.AreEqual(y_val.dims[0], 25000); + } } } From c23b24633fa1111d613deeedba5c9869ea463dd8 Mon Sep 17 00:00:00 2001 From: lingbai-kong Date: Fri, 23 Jun 2023 14:21:27 +0800 Subject: [PATCH 3/3] remove UnitTest for `keras.datasets.imdb` --- .../TensorFlowNET.UnitTest/Dataset/DatasetTest.cs | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index 875e50019..8317346ea 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -2,7 +2,6 @@ using System; using System.Linq; using static Tensorflow.Binding; -using static Tensorflow.KerasApi; namespace TensorFlowNET.UnitTest.Dataset { @@ -196,19 +195,5 @@ public void Shuffle() Assert.IsFalse(allEqual); } - [TestMethod] - public void GetData() - { - var vocab_size = 20000; - var dataset = keras.datasets.imdb.load_data(num_words: vocab_size); - var x_train = dataset.Train.Item1; - Assert.AreEqual(x_train.dims[0], 25000); - var y_train = dataset.Train.Item2; - Assert.AreEqual(y_train.dims[0], 25000); - var x_val = dataset.Test.Item1; - Assert.AreEqual(x_val.dims[0], 25000); - var y_val = dataset.Test.Item2; - Assert.AreEqual(y_val.dims[0], 25000); - } } }