Skip to content

Commit 68631fd

Browse files
committed
TensorInfo initialization improved, more checks added
1 parent b031b47 commit 68631fd

File tree

4 files changed

+17
-6
lines changed

4 files changed

+17
-6
lines changed

NeuralNetwork.NET/APIs/Structs/ConvolutionInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ internal TensorInfo GetForwardOutputTensorInfo(in TensorInfo input, (int X, int
9292
int
9393
h = (input.Height - field.X + 2 * VerticalPadding) / VerticalStride + 1,
9494
w = (input.Width - field.Y + 2 * HorizontalPadding) / HorizontalStride + 1;
95+
if (h <= 0 || w <= 0) throw new InvalidOperationException("The input convolution kernels can't be applied to the input tensor shape");
9596
return new TensorInfo(h, w, kernels);
9697
}
9798

NeuralNetwork.NET/APIs/Structs/PoolingInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ internal TensorInfo GetForwardOutputTensorInfo(in TensorInfo input)
9999
int
100100
h = (input.Height - WindowHeight + 2 * VerticalPadding) / VerticalStride + 1,
101101
w = (input.Width - WindowWidth + 2 * HorizontalPadding) / HorizontalStride + 1;
102+
if (h <= 0 || w <= 0) throw new InvalidOperationException("The input tensor shape is not valid to apply the current pooling operation");
102103
return new TensorInfo(h, w, input.Channels);
103104
}
104105

NeuralNetwork.NET/APIs/Structs/TensorInfo.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ namespace NeuralNetworkNET.APIs.Structs
4040
[JsonProperty(nameof(Size), Order = 4)]
4141
public int Size
4242
{
43-
[Pure]
4443
[MethodImpl(MethodImplOptions.AggressiveInlining)]
4544
get => Height * Width * Channels;
4645
}
@@ -50,18 +49,27 @@ public int Size
5049
/// </summary>
5150
public int SliceSize
5251
{
53-
[Pure]
5452
[MethodImpl(MethodImplOptions.AggressiveInlining)]
5553
get => Height * Width;
5654
}
5755

56+
/// <summary>
57+
/// Gets whether the current <see cref="Tensor"/> instance is invalid (empty or with invalid parameters)
58+
/// </summary>
59+
public bool IsEmptyOrInvalid
60+
{
61+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
62+
get => Height <= 0 || Width <= 0 || Channels <= 0;
63+
}
64+
5865
#endregion
5966

6067
#region Constructors
6168

6269
internal TensorInfo(int height, int width, int channels)
6370
{
64-
if (height * width <= 0) throw new ArgumentException("The height and width of the kernels must be positive values");
71+
if (height <= 0 || width <= 0) throw new ArgumentException("The height and width of the kernels must be positive values");
72+
if (channels <= 0) throw new ArgumentException("The number of channels must be positive");
6573
Height = height;
6674
Width = width;
6775
Channels = channels >= 1 ? channels : throw new ArgumentOutOfRangeException(nameof(channels), "The number of channels must be at least equal to 1");

NeuralNetwork.NET/Networks/Layers/Abstract/NetworkLayerBase.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.IO;
1+
using System;
2+
using System.IO;
23
using System.Runtime.CompilerServices;
34
using JetBrains.Annotations;
45
using NeuralNetworkNET.APIs.Enums;
@@ -58,8 +59,8 @@ public ref readonly TensorInfo OutputInfo
5859

5960
protected NetworkLayerBase(in TensorInfo input, in TensorInfo output, ActivationFunctionType activation)
6061
{
61-
_InputInfo = input;
62-
_OutputInfo = output;
62+
_InputInfo = input.IsEmptyOrInvalid ? throw new ArgumentException("The layer input info is not valid", nameof(input)) : input;
63+
_OutputInfo = output.IsEmptyOrInvalid ? throw new ArgumentException("The layer output info is not valid", nameof(output)) : output;
6364
ActivationFunctionType = activation;
6465
ActivationFunctions = ActivationFunctionProvider.GetActivations(activation);
6566
}

0 commit comments

Comments
 (0)