1
1
using System ;
2
2
using System . IO ;
3
+ using System . Runtime . CompilerServices ;
3
4
using JetBrains . Annotations ;
4
5
using NeuralNetworkNET . APIs . Enums ;
6
+ using NeuralNetworkNET . APIs . Interfaces ;
5
7
using NeuralNetworkNET . APIs . Structs ;
6
8
using NeuralNetworkNET . Extensions ;
7
9
using NeuralNetworkNET . Helpers ;
@@ -30,14 +32,20 @@ internal abstract class BatchNormalizationLayerBase : WeightedLayerBase
30
32
[ NotNull ]
31
33
public float [ ] Sigma2 { get ; }
32
34
33
- // The current iteration number (for the Cumulative Moving Average)
34
- private int _Iteration ;
35
+ /// <summary>
36
+ /// Gets the current iteration number (for the Cumulative Moving Average)
37
+ /// </summary>
38
+ public int Iteration { get ; private set ; }
35
39
36
40
/// <summary>
37
41
/// Gets the current CMA factor used to update the <see cref="Mu"/> and <see cref="Sigma2"/> tensors
38
42
/// </summary>
39
43
[ JsonProperty ( nameof ( CumulativeMovingAverageFactor ) , Order = 6 ) ]
40
- public float CumulativeMovingAverageFactor => 1f / ( 1 + _Iteration ) ;
44
+ public float CumulativeMovingAverageFactor
45
+ {
46
+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
47
+ get => 1f / ( 1 + Iteration ) ;
48
+ }
41
49
42
50
/// <inheritdoc/>
43
51
public override String Hash => Convert . ToBase64String ( Sha256 . Hash ( Weights , Biases , Mu , Sigma2 ) ) ;
@@ -74,24 +82,26 @@ protected BatchNormalizationLayerBase(in TensorInfo shape, NormalizationMode mod
74
82
NormalizationMode = mode ;
75
83
}
76
84
77
- protected BatchNormalizationLayerBase ( in TensorInfo shape , NormalizationMode mode , [ NotNull ] float [ ] w , [ NotNull ] float [ ] b , [ NotNull ] float [ ] mu , [ NotNull ] float [ ] sigma2 , ActivationType activation )
85
+ protected BatchNormalizationLayerBase ( in TensorInfo shape , NormalizationMode mode , [ NotNull ] float [ ] w , [ NotNull ] float [ ] b , int iteration , [ NotNull ] float [ ] mu , [ NotNull ] float [ ] sigma2 , ActivationType activation )
78
86
: base ( shape , shape , w , b , activation )
79
87
{
80
88
if ( w . Length != b . Length ) throw new ArgumentException ( "The size for both gamme and beta paarameters must be the same" ) ;
81
89
if ( mode == NormalizationMode . Spatial && w . Length != shape . Channels ||
82
90
mode == NormalizationMode . PerActivation && w . Length != shape . Size )
83
91
throw new ArgumentException ( "Invalid parameters size for the selected normalization mode" ) ;
92
+ if ( iteration < 0 ) throw new ArgumentOutOfRangeException ( nameof ( iteration ) , "The iteration value must be aat least equal to 0" ) ;
84
93
if ( mu . Length != w . Length || sigma2 . Length != w . Length )
85
94
throw new ArgumentException ( "The mu and sigma2 parameters must match the shape of the gamma and beta parameters" ) ;
86
95
NormalizationMode = mode ;
96
+ Iteration = iteration ;
87
97
Mu = mu ;
88
98
Sigma2 = sigma2 ;
89
99
}
90
100
91
101
/// <inheritdoc/>
92
102
public override void Forward ( in Tensor x , out Tensor z , out Tensor a )
93
103
{
94
- if ( NetworkTrainer . BackpropagationInProgress ) ForwardTraining ( 1f / ( 1 + _Iteration ++ ) , x , out z , out a ) ;
104
+ if ( NetworkTrainer . BackpropagationInProgress ) ForwardTraining ( 1f / ( 1 + Iteration ++ ) , x , out z , out a ) ;
95
105
else ForwardInference ( x , out z , out a ) ;
96
106
}
97
107
@@ -112,11 +122,22 @@ public override void Forward(in Tensor x, out Tensor z, out Tensor a)
112
122
/// <param name="a">The output activation on the current layer</param>
113
123
public abstract void ForwardTraining ( float factor , in Tensor x , out Tensor z , out Tensor a ) ;
114
124
125
+ /// <inheritdoc/>
126
+ public override bool Equals ( INetworkLayer other )
127
+ {
128
+ if ( ! base . Equals ( other ) ) return false ;
129
+ return other is BatchNormalizationLayerBase layer &&
130
+ Iteration == layer . Iteration &&
131
+ Mu . ContentEquals ( layer . Mu ) &&
132
+ Sigma2 . ContentEquals ( layer . Sigma2 ) ;
133
+ }
134
+
115
135
/// <inheritdoc/>
116
136
public override void Serialize ( Stream stream )
117
137
{
118
138
base . Serialize ( stream ) ;
119
139
stream . Write ( NormalizationMode ) ;
140
+ stream . Write ( Iteration ) ;
120
141
stream . Write ( Mu . Length ) ;
121
142
stream . WriteShuffled ( Mu ) ;
122
143
stream . Write ( Sigma2 . Length ) ;
0 commit comments