Skip to content

Commit d9d1089

Browse files
authored
Merge pull request #80 from Sergio0694/dev
Batch normalization, APIs adjustments
2 parents e9170bd + 5936f85 commit d9d1089

39 files changed

+2130
-308
lines changed

NeuralNetwork.NET/APIs/CuDnnNetworkLayers.cs

Lines changed: 10 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
using System;
2-
using System.Linq;
3-
using JetBrains.Annotations;
1+
using JetBrains.Annotations;
42
using NeuralNetworkNET.APIs.Delegates;
53
using NeuralNetworkNET.APIs.Enums;
64
using NeuralNetworkNET.APIs.Structs;
7-
using NeuralNetworkNET.Extensions;
5+
using NeuralNetworkNET.cuDNN;
86
using NeuralNetworkNET.Networks.Layers.Cuda;
97

108
namespace NeuralNetworkNET.APIs
@@ -17,22 +15,7 @@ public static class CuDnnNetworkLayers
1715
/// <summary>
1816
/// Gets whether or not the Cuda acceleration is supported on the current system
1917
/// </summary>
20-
public static bool IsCudaSupportAvailable
21-
{
22-
get
23-
{
24-
try
25-
{
26-
// Calling this directly would could a crash in the <Module> loader due to the missing .dll files
27-
return CuDnnSupportHelper.IsGpuAccelerationSupported();
28-
}
29-
catch (TypeInitializationException)
30-
{
31-
// Missing .dll file
32-
return false;
33-
}
34-
}
35-
}
18+
public static bool IsCudaSupportAvailable => CuDnnService.IsAvailable;
3619

3720
/// <summary>
3821
/// Creates a new fully connected layer with the specified number of input and output neurons, and the given activation function
@@ -132,41 +115,14 @@ public static LayerFactory Convolutional(
132115
public static LayerFactory Inception(InceptionInfo info, BiasInitializationMode biasMode = BiasInitializationMode.Zero)
133116
=> input => new CuDnnInceptionLayer(input, info, biasMode);
134117

135-
#region Feature helper
136-
137118
/// <summary>
138-
/// A private class that is used to create a new standalone type that contains the actual test method (decoupling is needed to &lt;Module&gt; loading crashes)
119+
/// Creates a new batch normalization layer
139120
/// </summary>
140-
private static class CuDnnSupportHelper
141-
{
142-
/// <summary>
143-
/// Checks whether or not the Cuda features are currently supported
144-
/// </summary>
145-
public static bool IsGpuAccelerationSupported()
146-
{
147-
try
148-
{
149-
// CUDA test
150-
Alea.Gpu gpu = Alea.Gpu.Default;
151-
if (gpu == null) return false;
152-
if (!Alea.cuDNN.Dnn.IsAvailable) return false; // cuDNN
153-
using (Alea.DeviceMemory<float> sample_gpu = gpu.AllocateDevice<float>(1024))
154-
{
155-
Alea.deviceptr<float> ptr = sample_gpu.Ptr;
156-
void Kernel(int i) => ptr[i] = i;
157-
Alea.Parallel.GpuExtension.For(gpu, 0, 1024, Kernel); // JIT test
158-
float[] sample = Alea.Gpu.CopyToHost(sample_gpu);
159-
return Enumerable.Range(0, 1024).Select<int, float>(i => i).ToArray().ContentEquals(sample);
160-
}
161-
}
162-
catch
163-
{
164-
// Missing .dll or other errors
165-
return false;
166-
}
167-
}
168-
}
169-
170-
#endregion
121+
/// <param name="mode">The normalization mode to use for the new layer</param>
122+
/// <param name="activation">The desired activation function to use in the network layer</param>
123+
[PublicAPI]
124+
[Pure, NotNull]
125+
public static LayerFactory BatchNormalization(NormalizationMode mode, ActivationType activation)
126+
=> input => new CuDnnBatchNormalizationLayer(input, mode, activation);
171127
}
172128
}

NeuralNetwork.NET/APIs/DatasetLoader.cs

Lines changed: 69 additions & 24 deletions
Large diffs are not rendered by default.

NeuralNetwork.NET/APIs/Datasets/Cifar10.cs

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
using NeuralNetworkNET.Extensions;
1010
using NeuralNetworkNET.Helpers;
1111
using NeuralNetworkNET.SupervisedLearning.Progress;
12+
using SixLabors.ImageSharp;
13+
using SixLabors.ImageSharp.Advanced;
14+
using SixLabors.ImageSharp.PixelFormats;
1215

1316
namespace NeuralNetworkNET.APIs.Datasets
1417
{
@@ -25,11 +28,14 @@ public static class Cifar10
2528
// 32*32 RGB images
2629
private const int SampleSize = 3072;
2730

31+
// A single 32*32 image
32+
private const int ImageSize = 1024;
33+
2834
private const String DatasetURL = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz";
2935

3036
[NotNull, ItemNotNull]
3137
private static readonly IReadOnlyList<String> TrainingBinFilenames = Enumerable.Range(1, 5).Select(i => $"data_batch_{i}.bin").ToArray();
32-
38+
3339
private const String TestBinFilename = "test_batch.bin";
3440

3541
#endregion
@@ -38,12 +44,13 @@ public static class Cifar10
3844
/// Downloads the CIFAR-10 training datasets and returns a new <see cref="ITestDataset"/> instance
3945
/// </summary>
4046
/// <param name="size">The desired dataset batch size</param>
47+
/// <param name="callback">The optional progress calback</param>
4148
/// <param name="token">An optional cancellation token for the operation</param>
4249
[PublicAPI]
4350
[Pure, ItemCanBeNull]
44-
public static async Task<ITrainingDataset> GetTrainingDatasetAsync(int size, CancellationToken token = default)
51+
public static async Task<ITrainingDataset> GetTrainingDatasetAsync(int size, [CanBeNull] IProgress<HttpProgress> callback = null, CancellationToken token = default)
4552
{
46-
IReadOnlyDictionary<String, Func<Stream>> map = await DatasetsDownloader.GetArchiveAsync(DatasetURL, token);
53+
IReadOnlyDictionary<String, Func<Stream>> map = await DatasetsDownloader.GetArchiveAsync(DatasetURL, callback, token);
4754
if (map == null) return null;
4855
IReadOnlyList<(float[], float[])>[] data = new IReadOnlyList<(float[], float[])>[TrainingBinFilenames.Count];
4956
Parallel.For(0, TrainingBinFilenames.Count, i => data[i] = ParseSamples(map[TrainingBinFilenames[i]], TrainingSamplesInBinFiles)).AssertCompleted();
@@ -54,25 +61,45 @@ public static async Task<ITrainingDataset> GetTrainingDatasetAsync(int size, Can
5461
/// Downloads the CIFAR-10 test datasets and returns a new <see cref="ITestDataset"/> instance
5562
/// </summary>
5663
/// <param name="progress">The optional progress callback to use</param>
64+
/// <param name="callback">The optional progress calback</param>
5765
/// <param name="token">An optional cancellation token for the operation</param>
5866
[PublicAPI]
5967
[Pure, ItemCanBeNull]
60-
public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<TrainingProgressEventArgs> progress = null, CancellationToken token = default)
68+
public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<TrainingProgressEventArgs> progress = null, [CanBeNull] IProgress<HttpProgress> callback = null, CancellationToken token = default)
6169
{
62-
IReadOnlyDictionary<String, Func<Stream>> map = await DatasetsDownloader.GetArchiveAsync(DatasetURL, token);
70+
IReadOnlyDictionary<String, Func<Stream>> map = await DatasetsDownloader.GetArchiveAsync(DatasetURL, callback, token);
6371
if (map == null) return null;
6472
IReadOnlyList<(float[], float[])> data = ParseSamples(map[TestBinFilename], TrainingSamplesInBinFiles);
6573
return DatasetLoader.Test(data, progress);
6674
}
6775

76+
/// <summary>
77+
/// Downloads and exports the full CIFAR-10 dataset (both training and test samples) to the target directory
78+
/// </summary>
79+
/// <param name="directory">The target directory</param>
80+
/// <param name="token">The cancellation token for the operation</param>
81+
[PublicAPI]
82+
public static async Task<bool> ExportDatasetAsync([NotNull] DirectoryInfo directory, CancellationToken token = default)
83+
{
84+
IReadOnlyDictionary<String, Func<Stream>> map = await DatasetsDownloader.GetArchiveAsync(DatasetURL, null, token);
85+
if (map == null) return false;
86+
if (!directory.Exists) directory.Create();
87+
ParallelLoopResult result = Parallel.ForEach(TrainingBinFilenames.Concat(new[] { TestBinFilename }), (name, state) =>
88+
{
89+
ExportSamples(directory, (name, map[name]), TrainingSamplesInBinFiles, token);
90+
if (token.IsCancellationRequested) state.Stop();
91+
});
92+
return result.IsCompleted && !token.IsCancellationRequested;
93+
}
94+
6895
#region Tools
6996

7097
/// <summary>
7198
/// Parses a CIFAR-10 .bin file
7299
/// </summary>
73100
/// <param name="factory">A <see cref="Func{TResult}"/> that returns the <see cref="Stream"/> to read</param>
74101
/// <param name="count">The number of samples to parse</param>
75-
private static unsafe IReadOnlyList<(float[], float[])> ParseSamples(Func<Stream> factory, int count)
102+
private static unsafe IReadOnlyList<(float[], float[])> ParseSamples([NotNull] Func<Stream> factory, int count)
76103
{
77104
using (Stream stream = factory())
78105
{
@@ -89,8 +116,12 @@ public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<Tr
89116
fixed (float* px = x)
90117
{
91118
stream.Read(temp, 0, SampleSize);
92-
for (int j = 0; j < SampleSize; j++)
119+
for (int j = 0; j < ImageSize; j++)
120+
{
93121
px[j] = ptemp[j] / 255f; // Normalized samples
122+
px[j] = ptemp[j + ImageSize] / 255f;
123+
px[j] = ptemp[j + 2 * ImageSize] / 255f;
124+
}
94125
}
95126
data[i] = (x, y);
96127
}
@@ -99,6 +130,38 @@ public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<Tr
99130
}
100131
}
101132

133+
/// <summary>
134+
/// Exports a CIFAR-10 .bin file
135+
/// </summary>
136+
/// <param name="folder">The target folder to use to save the images</param>
137+
/// <param name="source">The source filename and a <see cref="Func{TResult}"/> that returns the <see cref="Stream"/> to read</param>
138+
/// <param name="count">The number of samples to parse</param>
139+
/// <param name="token">A token for the operation</param>
140+
private static unsafe void ExportSamples([NotNull] DirectoryInfo folder, (String Name, Func<Stream> Factory) source, int count, CancellationToken token)
141+
{
142+
using (Stream stream = source.Factory())
143+
{
144+
byte[] temp = new byte[SampleSize];
145+
fixed (byte* ptemp = temp)
146+
{
147+
for (int i = 0; i < count; i++)
148+
{
149+
if (token.IsCancellationRequested) return;
150+
int label = stream.ReadByte();
151+
stream.Read(temp, 0, SampleSize);
152+
using (Image<Rgb24> image = new Image<Rgb24>(32, 32))
153+
fixed (Rgb24* p0 = &image.DangerousGetPinnableReferenceToPixelBuffer())
154+
{
155+
for (int j = 0; j < ImageSize; j++)
156+
p0[j] = new Rgb24(ptemp[j], ptemp[j + ImageSize], ptemp[j + 2 * ImageSize]);
157+
using (FileStream file = File.OpenWrite(Path.Combine(folder.FullName, $"[{source.Name}][{i}][{label}].bmp")))
158+
image.SaveAsBmp(file);
159+
}
160+
}
161+
}
162+
}
163+
}
164+
102165
#endregion
103166
}
104167
}

0 commit comments

Comments
 (0)