Skip to content

Commit 6764628

Browse files
authored
Merge pull request #59 from Sergio0694/dev
New APIs, minor improvements
2 parents b92b399 + 144a970 commit 6764628

File tree

20 files changed

+361
-421
lines changed

20 files changed

+361
-421
lines changed

NeuralNetwork.NET/APIs/DatasetLoader.cs

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
using JetBrains.Annotations;
55
using NeuralNetworkNET.APIs.Interfaces.Data;
66
using NeuralNetworkNET.Extensions;
7+
using NeuralNetworkNET.Helpers;
78
using NeuralNetworkNET.SupervisedLearning.Data;
89
using NeuralNetworkNET.SupervisedLearning.Optimization.Parameters;
910
using NeuralNetworkNET.SupervisedLearning.Optimization.Progress;
11+
using SixLabors.ImageSharp;
12+
using SixLabors.ImageSharp.PixelFormats;
1013

1114
namespace NeuralNetworkNET.APIs
1215
{
@@ -30,12 +33,12 @@ public static class DatasetLoader
3033
/// <summary>
3134
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input collection, with the specified batch size
3235
/// </summary>
33-
/// <param name="data">The source collection to use to build the training dataset</param>
36+
/// <param name="data">The source collection to use to build the training dataset, where the samples will be extracted from the input <see cref="Func{TResult}"/> instances in parallel</param>
3437
/// <param name="size">The desired dataset batch size</param>
3538
[PublicAPI]
3639
[Pure, NotNull]
3740
[CollectionAccess(CollectionAccessType.Read)]
38-
public static ITrainingDataset Training([NotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, int size) => BatchesCollection.From(data, size);
41+
public static ITrainingDataset Training([NotNull, ItemNotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, int size) => BatchesCollection.From(data, size);
3942

4043
/// <summary>
4144
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input matrices, with the specified batch size
@@ -47,6 +50,34 @@ public static class DatasetLoader
4750
[CollectionAccess(CollectionAccessType.Read)]
4851
public static ITrainingDataset Training((float[,] X, float[,] Y) data, int size) => BatchesCollection.From(data, size);
4952

53+
/// <summary>
54+
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input data, where each input sample is an image in a specified format
55+
/// </summary>
56+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
57+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a vector with the expected outputs</param>
58+
/// <param name="size">The desired dataset batch size</param>
59+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
60+
[PublicAPI]
61+
[Pure, NotNull]
62+
[CollectionAccess(CollectionAccessType.Read)]
63+
public static ITrainingDataset Training<TPixel>([NotNull] IEnumerable<(String X, float[] Y)> data, int size, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
64+
where TPixel : struct, IPixel<TPixel>
65+
=> BatchesCollection.From(data.Select<(String X, float[] Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y)), size);
66+
67+
/// <summary>
68+
/// Creates a new <see cref="ITrainingDataset"/> instance to train a network from the input data, where each input sample is an image in a specified format
69+
/// </summary>
70+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
71+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a <see cref="Func{TResult}"/> returning a vector with the expected outputs</param>
72+
/// <param name="size">The desired dataset batch size</param>
73+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
74+
[PublicAPI]
75+
[Pure, NotNull]
76+
[CollectionAccess(CollectionAccessType.Read)]
77+
public static ITrainingDataset Training<TPixel>([NotNull] IEnumerable<(String X, Func<float[]> Y)> data, int size, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
78+
where TPixel : struct, IPixel<TPixel>
79+
=> BatchesCollection.From(data.Select<(String X, Func<float[]> Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y())), size);
80+
5081
#endregion
5182

5283
#region Validation
@@ -66,13 +97,13 @@ public static IValidationDataset Validation([NotNull] IEnumerable<(float[] X, fl
6697
/// <summary>
6798
/// Creates a new <see cref="IValidationDataset"/> instance to validate a network accuracy from the input collection
6899
/// </summary>
69-
/// <param name="data">The source collection to use to build the validation dataset</param>
100+
/// <param name="data">The source collection to use to build the validation dataset, where the samples will be extracted from the input <see cref="Func{TResult}"/> instances in parallel</param>
70101
/// <param name="tolerance">The desired tolerance to test the network for convergence</param>
71102
/// <param name="epochs">The epochs interval to consider when testing the network for convergence</param>
72103
[PublicAPI]
73104
[Pure, NotNull]
74105
[CollectionAccess(CollectionAccessType.Read)]
75-
public static IValidationDataset Validation([NotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, float tolerance = 1e-2f, int epochs = 5)
106+
public static IValidationDataset Validation([NotNull, ItemNotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, float tolerance = 1e-2f, int epochs = 5)
76107
=> Validation(data.AsParallel().Select(f => f()), tolerance, epochs);
77108

78109
/// <summary>
@@ -86,6 +117,36 @@ public static IValidationDataset Validation([NotNull] IEnumerable<Func<(float[]
86117
[CollectionAccess(CollectionAccessType.Read)]
87118
public static IValidationDataset Validation((float[,] X, float[,] Y) data, float tolerance = 1e-2f, int epochs = 5) => new ValidationDataset(data, tolerance, epochs);
88119

120+
/// <summary>
121+
/// Creates a new <see cref="IValidationDataset"/> instance to validate a network accuracy from the input collection
122+
/// </summary>
123+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
124+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a vector with the expected outputs</param>
125+
/// <param name="tolerance">The desired tolerance to test the network for convergence</param>
126+
/// <param name="epochs">The epochs interval to consider when testing the network for convergence</param>
127+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
128+
[PublicAPI]
129+
[Pure, NotNull]
130+
[CollectionAccess(CollectionAccessType.Read)]
131+
public static IValidationDataset Validation<TPixel>([NotNull] IEnumerable<(String X, float[] Y)> data, float tolerance = 1e-2f, int epochs = 5, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
132+
where TPixel : struct, IPixel<TPixel>
133+
=> Validation(data.Select<(String X, float[] Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y)).AsParallel(), tolerance, epochs);
134+
135+
/// <summary>
136+
/// Creates a new <see cref="IValidationDataset"/> instance to validate a network accuracy from the input collection
137+
/// </summary>
138+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
139+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a <see cref="Func{TResult}"/> returning a vector with the expected outputs</param>
140+
/// <param name="tolerance">The desired tolerance to test the network for convergence</param>
141+
/// <param name="epochs">The epochs interval to consider when testing the network for convergence</param>
142+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
143+
[PublicAPI]
144+
[Pure, NotNull]
145+
[CollectionAccess(CollectionAccessType.Read)]
146+
public static IValidationDataset Validation<TPixel>([NotNull] IEnumerable<(String X, Func<float[]> Y)> data, float tolerance = 1e-2f, int epochs = 5, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
147+
where TPixel : struct, IPixel<TPixel>
148+
=> Validation(data.Select<(String X, Func<float[]> Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y())).AsParallel(), tolerance, epochs);
149+
89150
#endregion
90151

91152
#region Test
@@ -104,12 +165,12 @@ public static ITestDataset Test([NotNull] IEnumerable<(float[] X, float[] Y)> da
104165
/// <summary>
105166
/// Creates a new <see cref="ITestDataset"/> instance to test a network from the input collection
106167
/// </summary>
107-
/// <param name="data">The source collection to use to build the test dataset</param>
168+
/// <param name="data">The source collection to use to build the test dataset, where the samples will be extracted from the input <see cref="Func{TResult}"/> instances in parallel</param>
108169
/// <param name="progress">The optional progress callback to use</param>
109170
[PublicAPI]
110171
[Pure, NotNull]
111172
[CollectionAccess(CollectionAccessType.Read)]
112-
public static ITestDataset Test([NotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null)
173+
public static ITestDataset Test([NotNull, ItemNotNull] IEnumerable<Func<(float[] X, float[] Y)>> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null)
113174
=> Test(data.AsParallel().Select(f => f()), progress);
114175

115176
/// <summary>
@@ -122,6 +183,34 @@ public static ITestDataset Test([NotNull] IEnumerable<Func<(float[] X, float[] Y
122183
[CollectionAccess(CollectionAccessType.Read)]
123184
public static ITestDataset Test((float[,] X, float[,] Y) data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null) => new TestDataset(data, progress);
124185

186+
/// <summary>
187+
/// Creates a new <see cref="ITestDataset"/> instance to test a network from the input collection
188+
/// </summary>
189+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
190+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a vector with the expected outputs</param>
191+
/// <param name="progress">The optional progress callback to use</param>
192+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
193+
[PublicAPI]
194+
[Pure, NotNull]
195+
[CollectionAccess(CollectionAccessType.Read)]
196+
public static ITestDataset Test<TPixel>([NotNull] IEnumerable<(String X, float[] Y)> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
197+
where TPixel : struct, IPixel<TPixel>
198+
=> Test(data.Select<(String X, float[] Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y)).AsParallel(), progress);
199+
200+
/// <summary>
201+
/// Creates a new <see cref="ITestDataset"/> instance to test a network from the input collection
202+
/// </summary>
203+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
204+
/// <param name="data">A list of <see cref="ValueTuple{T1, T2}"/> items, where the first element is the image path and the second is a <see cref="Func{TResult}"/> returning a vector with the expected outputs</param>
205+
/// <param name="progress">The optional progress callback to use</param>
206+
/// <param name="modify">An optional <see cref="Action{T}"/> to modify each sample image when loading the dataset</param>
207+
[PublicAPI]
208+
[Pure, NotNull]
209+
[CollectionAccess(CollectionAccessType.Read)]
210+
public static ITestDataset Test<TPixel>([NotNull] IEnumerable<(String X, Func<float[]> Y)> data, [CanBeNull] IProgress<TrainingProgressEventArgs> progress = null, [CanBeNull] Action<IImageProcessingContext<TPixel>> modify = null)
211+
where TPixel : struct, IPixel<TPixel>
212+
=> Test(data.Select<(String X, Func<float[]> Y), Func<(float[], float[])>>(xy => () => (ImageLoader.Load(xy.X, modify), xy.Y())).AsParallel(), progress);
213+
125214
#endregion
126215
}
127216
}

NeuralNetwork.NET/APIs/Structs/TensorInfo.cs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System;
44
using System.Diagnostics;
55
using System.Runtime.CompilerServices;
6+
using SixLabors.ImageSharp.PixelFormats;
67

78
namespace NeuralNetworkNET.APIs.Structs
89
{
@@ -67,30 +68,38 @@ internal TensorInfo(int height, int width, int channels)
6768
}
6869

6970
/// <summary>
70-
/// Creates a new <see cref="TensorInfo"/> instance for an RGB image
71+
/// Creates a new <see cref="TensorInfo"/> instance for a linear network layer, without keeping track of spatial info
7172
/// </summary>
72-
/// <param name="height">The height of the input image</param>
73-
/// <param name="width">The width of the input image</param>
73+
/// <param name="size">The input size</param>
7474
[PublicAPI]
7575
[Pure]
76-
public static TensorInfo CreateForRgbImage(int height, int width) => new TensorInfo(height, width, 3);
76+
public static TensorInfo Linear(int size) => new TensorInfo(1, 1, size);
7777

7878
/// <summary>
79-
/// Creates a new <see cref="TensorInfo"/> instance for a grayscale image
79+
/// Creates a new <see cref="TensorInfo"/> instance for an image with a user-defined pixel type
8080
/// </summary>
81+
/// <typeparam name="TPixel">The type of image pixels. It must be either <see cref="Alpha8"/>, <see cref="Rgb24"/> or <see cref="Argb32"/></typeparam>
8182
/// <param name="height">The height of the input image</param>
8283
/// <param name="width">The width of the input image</param>
8384
[PublicAPI]
8485
[Pure]
85-
public static TensorInfo CreateForGrayscaleImage(int height, int width) => new TensorInfo(height, width, 1);
86+
public static TensorInfo Image<TPixel>(int height, int width) where TPixel : struct, IPixel<TPixel>
87+
{
88+
if (typeof(TPixel) == typeof(Alpha8)) return new TensorInfo(height, width, 1);
89+
if (typeof(TPixel) == typeof(Rgb24)) return new TensorInfo(height, width, 3);
90+
if (typeof(TPixel) == typeof(Argb32)) return new TensorInfo(height, width, 4);
91+
throw new InvalidOperationException($"The {typeof(TPixel).Name} pixel format isn't currently supported");
92+
}
8693

8794
/// <summary>
88-
/// Creates a new <see cref="TensorInfo"/> instance for a linear network layer, without keeping track of spatial info
95+
/// Creates a new <see cref="TensorInfo"/> instance for with a custom 3D shape
8996
/// </summary>
90-
/// <param name="size">The input size</param>
97+
/// <param name="height">The input volume height</param>
98+
/// <param name="width">The input volume width</param>
99+
/// <param name="channels">The number of channels in the input volume</param>
91100
[PublicAPI]
92101
[Pure]
93-
public static TensorInfo CreateLinear(int size) => new TensorInfo(1, 1, size);
102+
public static TensorInfo Volume(int height, int width, int channels) => new TensorInfo(height, width, channels);
94103

95104
#endregion
96105

NeuralNetwork.NET/Extensions/MiscExtensions.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Runtime.CompilerServices;
5+
using System.Text;
56
using System.Threading.Tasks;
67
using JetBrains.Annotations;
78

@@ -20,8 +21,10 @@ public static class MiscExtensions
2021
/// <param name="item">The item to cast</param>
2122
[Pure, NotNull]
2223
[MethodImpl(MethodImplOptions.AggressiveInlining)]
23-
public static TOut To<TIn, TOut>([NotNull] this TIn item) where TOut : class, TIn => item as TOut
24-
?? throw new InvalidOperationException($"The item of type {typeof(TIn)} is a {item.GetType()} instance and can't be cast to {typeof(TOut)}");
24+
public static TOut To<TIn, TOut>([NotNull] this TIn item)
25+
where TIn : class
26+
where TOut : TIn
27+
=> (TOut)item;
2528

2629
/// <summary>
2730
/// Returns the maximum value between two numbers
@@ -138,5 +141,20 @@ public static void AssertCompleted(in this ParallelLoopResult result)
138141
{
139142
if (!result.IsCompleted) throw new InvalidOperationException("Error while performing the parallel loop");
140143
}
144+
145+
/// <summary>
146+
/// Removes the left spaces from the input verbatim string
147+
/// </summary>
148+
/// <param name="text">The string to trim</param>
149+
[Pure, NotNull]
150+
public static String TrimVerbatim([NotNull] this String text)
151+
{
152+
String[] lines = text.Split(new[] { Environment.NewLine }, StringSplitOptions.None);
153+
return lines.Aggregate(new StringBuilder(), (b, s) =>
154+
{
155+
b.AppendLine(s.Trim());
156+
return b;
157+
}).ToString();
158+
}
141159
}
142160
}

0 commit comments

Comments
 (0)