9
9
using NeuralNetworkNET . Extensions ;
10
10
using NeuralNetworkNET . Helpers ;
11
11
using NeuralNetworkNET . SupervisedLearning . Progress ;
12
+ using SixLabors . ImageSharp ;
13
+ using SixLabors . ImageSharp . Advanced ;
14
+ using SixLabors . ImageSharp . PixelFormats ;
12
15
13
16
namespace NeuralNetworkNET . APIs . Datasets
14
17
{
@@ -25,11 +28,14 @@ public static class Cifar10
25
28
// 32*32 RGB images
26
29
private const int SampleSize = 3072 ;
27
30
31
+ // A single 32*32 image
32
+ private const int ImageSize = 1024 ;
33
+
28
34
private const String DatasetURL = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz" ;
29
35
30
36
[ NotNull , ItemNotNull ]
31
37
private static readonly IReadOnlyList < String > TrainingBinFilenames = Enumerable . Range ( 1 , 5 ) . Select ( i => $ "data_batch_{ i } .bin") . ToArray ( ) ;
32
-
38
+
33
39
private const String TestBinFilename = "test_batch.bin" ;
34
40
35
41
#endregion
@@ -38,12 +44,13 @@ public static class Cifar10
38
44
/// Downloads the CIFAR-10 training datasets and returns a new <see cref="ITestDataset"/> instance
39
45
/// </summary>
40
46
/// <param name="size">The desired dataset batch size</param>
47
+ /// <param name="callback">The optional progress calback</param>
41
48
/// <param name="token">An optional cancellation token for the operation</param>
42
49
[ PublicAPI ]
43
50
[ Pure , ItemCanBeNull ]
44
- public static async Task < ITrainingDataset > GetTrainingDatasetAsync ( int size , CancellationToken token = default )
51
+ public static async Task < ITrainingDataset > GetTrainingDatasetAsync ( int size , [ CanBeNull ] IProgress < HttpProgress > callback = null , CancellationToken token = default )
45
52
{
46
- IReadOnlyDictionary < String , Func < Stream > > map = await DatasetsDownloader . GetArchiveAsync ( DatasetURL , token ) ;
53
+ IReadOnlyDictionary < String , Func < Stream > > map = await DatasetsDownloader . GetArchiveAsync ( DatasetURL , callback , token ) ;
47
54
if ( map == null ) return null ;
48
55
IReadOnlyList < ( float [ ] , float [ ] ) > [ ] data = new IReadOnlyList < ( float [ ] , float [ ] ) > [ TrainingBinFilenames . Count ] ;
49
56
Parallel . For ( 0 , TrainingBinFilenames . Count , i => data [ i ] = ParseSamples ( map [ TrainingBinFilenames [ i ] ] , TrainingSamplesInBinFiles ) ) . AssertCompleted ( ) ;
@@ -54,25 +61,45 @@ public static async Task<ITrainingDataset> GetTrainingDatasetAsync(int size, Can
54
61
/// Downloads the CIFAR-10 test datasets and returns a new <see cref="ITestDataset"/> instance
55
62
/// </summary>
56
63
/// <param name="progress">The optional progress callback to use</param>
64
+ /// <param name="callback">The optional progress calback</param>
57
65
/// <param name="token">An optional cancellation token for the operation</param>
58
66
[ PublicAPI ]
59
67
[ Pure , ItemCanBeNull ]
60
- public static async Task < ITestDataset > GetTestDatasetAsync ( [ CanBeNull ] Action < TrainingProgressEventArgs > progress = null , CancellationToken token = default )
68
+ public static async Task < ITestDataset > GetTestDatasetAsync ( [ CanBeNull ] Action < TrainingProgressEventArgs > progress = null , [ CanBeNull ] IProgress < HttpProgress > callback = null , CancellationToken token = default )
61
69
{
62
- IReadOnlyDictionary < String , Func < Stream > > map = await DatasetsDownloader . GetArchiveAsync ( DatasetURL , token ) ;
70
+ IReadOnlyDictionary < String , Func < Stream > > map = await DatasetsDownloader . GetArchiveAsync ( DatasetURL , callback , token ) ;
63
71
if ( map == null ) return null ;
64
72
IReadOnlyList < ( float [ ] , float [ ] ) > data = ParseSamples ( map [ TestBinFilename ] , TrainingSamplesInBinFiles ) ;
65
73
return DatasetLoader . Test ( data , progress ) ;
66
74
}
67
75
76
+ /// <summary>
77
+ /// Downloads and exports the full CIFAR-10 dataset (both training and test samples) to the target directory
78
+ /// </summary>
79
+ /// <param name="directory">The target directory</param>
80
+ /// <param name="token">The cancellation token for the operation</param>
81
+ [ PublicAPI ]
82
+ public static async Task < bool > ExportDatasetAsync ( [ NotNull ] DirectoryInfo directory , CancellationToken token = default )
83
+ {
84
+ IReadOnlyDictionary < String , Func < Stream > > map = await DatasetsDownloader . GetArchiveAsync ( DatasetURL , null , token ) ;
85
+ if ( map == null ) return false ;
86
+ if ( ! directory . Exists ) directory . Create ( ) ;
87
+ ParallelLoopResult result = Parallel . ForEach ( TrainingBinFilenames . Concat ( new [ ] { TestBinFilename } ) , ( name , state ) =>
88
+ {
89
+ ExportSamples ( directory , ( name , map [ name ] ) , TrainingSamplesInBinFiles , token ) ;
90
+ if ( token . IsCancellationRequested ) state . Stop ( ) ;
91
+ } ) ;
92
+ return result . IsCompleted && ! token . IsCancellationRequested ;
93
+ }
94
+
68
95
#region Tools
69
96
70
97
/// <summary>
71
98
/// Parses a CIFAR-10 .bin file
72
99
/// </summary>
73
100
/// <param name="factory">A <see cref="Func{TResult}"/> that returns the <see cref="Stream"/> to read</param>
74
101
/// <param name="count">The number of samples to parse</param>
75
- private static unsafe IReadOnlyList < ( float [ ] , float [ ] ) > ParseSamples ( Func < Stream > factory , int count )
102
+ private static unsafe IReadOnlyList < ( float [ ] , float [ ] ) > ParseSamples ( [ NotNull ] Func < Stream > factory , int count )
76
103
{
77
104
using ( Stream stream = factory ( ) )
78
105
{
@@ -89,8 +116,12 @@ public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<Tr
89
116
fixed ( float * px = x )
90
117
{
91
118
stream . Read ( temp , 0 , SampleSize ) ;
92
- for ( int j = 0 ; j < SampleSize ; j ++ )
119
+ for ( int j = 0 ; j < ImageSize ; j ++ )
120
+ {
93
121
px [ j ] = ptemp [ j ] / 255f ; // Normalized samples
122
+ px [ j ] = ptemp [ j + ImageSize ] / 255f ;
123
+ px [ j ] = ptemp [ j + 2 * ImageSize ] / 255f ;
124
+ }
94
125
}
95
126
data [ i ] = ( x , y ) ;
96
127
}
@@ -99,6 +130,38 @@ public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<Tr
99
130
}
100
131
}
101
132
133
+ /// <summary>
134
+ /// Exports a CIFAR-10 .bin file
135
+ /// </summary>
136
+ /// <param name="folder">The target folder to use to save the images</param>
137
+ /// <param name="source">The source filename and a <see cref="Func{TResult}"/> that returns the <see cref="Stream"/> to read</param>
138
+ /// <param name="count">The number of samples to parse</param>
139
+ /// <param name="token">A token for the operation</param>
140
+ private static unsafe void ExportSamples ( [ NotNull ] DirectoryInfo folder , ( String Name , Func < Stream > Factory ) source , int count , CancellationToken token )
141
+ {
142
+ using ( Stream stream = source . Factory ( ) )
143
+ {
144
+ byte [ ] temp = new byte [ SampleSize ] ;
145
+ fixed ( byte * ptemp = temp )
146
+ {
147
+ for ( int i = 0 ; i < count ; i ++ )
148
+ {
149
+ if ( token . IsCancellationRequested ) return ;
150
+ int label = stream . ReadByte ( ) ;
151
+ stream . Read ( temp , 0 , SampleSize ) ;
152
+ using ( Image < Rgb24 > image = new Image < Rgb24 > ( 32 , 32 ) )
153
+ fixed ( Rgb24 * p0 = & image . DangerousGetPinnableReferenceToPixelBuffer ( ) )
154
+ {
155
+ for ( int j = 0 ; j < ImageSize ; j ++ )
156
+ p0 [ j ] = new Rgb24 ( ptemp [ j ] , ptemp [ j + ImageSize ] , ptemp [ j + 2 * ImageSize ] ) ;
157
+ using ( FileStream file = File . OpenWrite ( Path . Combine ( folder . FullName , $ "[{ source . Name } ][{ i } ][{ label } ].bmp") ) )
158
+ image . SaveAsBmp ( file ) ;
159
+ }
160
+ }
161
+ }
162
+ }
163
+ }
164
+
102
165
#endregion
103
166
}
104
167
}
0 commit comments