diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs index fd4f93fc1..45b236c7b 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs @@ -107,9 +107,15 @@ public unsafe static implicit operator double(NDArray nd) public static implicit operator NDArray(bool value) => new NDArray(value); + public static implicit operator NDArray(byte value) + => new NDArray(value); + public static implicit operator NDArray(int value) => new NDArray(value); + public static implicit operator NDArray(long value) + => new NDArray(value); + public static implicit operator NDArray(float value) => new NDArray(value); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index a0b47aace..24c392155 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -84,8 +84,13 @@ public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT // var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); Tensor zeros = dtype switch { + TF_DataType.TF_BOOL => constant(false), TF_DataType.TF_DOUBLE => constant(0d), TF_DataType.TF_FLOAT => constant(0f), + TF_DataType.TF_INT64 => constant(0L), + TF_DataType.TF_UINT64 => constant((ulong)0), + TF_DataType.TF_INT32 => constant(0), + TF_DataType.TF_UINT32 => constant((uint)0), TF_DataType.TF_INT8 => constant((sbyte)0), TF_DataType.TF_UINT8 => constant((byte)0), _ => constant(0) @@ -108,9 +113,15 @@ public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT return _constant_if_small(0.0F, shape, dtype, name); case TF_DataType.TF_INT64: return _constant_if_small(0L, shape, dtype, name); + case TF_DataType.TF_UINT64: + return _constant_if_small(0, shape, dtype, name); case TF_DataType.TF_INT32: return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_UINT32: + return _constant_if_small(0, shape, dtype, name); case TF_DataType.TF_INT8: + return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_UINT8: return _constant_if_small(0, shape, dtype, name); default: throw new TypeError("can't find type for zeros");