@@ -42,7 +42,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
   private var convergenceTol = 1E-4
   private var maxNumIterations = 100
   private var regParam = 0.0
-  private var miniBatchFraction = 1.0
 
   /**
    * Set the number of corrections used in the LBFGS update. Default 10.
@@ -57,14 +56,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
     this
   }
 
-  /**
-   * Set fraction of data to be used for each L-BFGS iteration. Default 1.0.
-   */
-  def setMiniBatchFraction(fraction: Double): this.type = {
-    this.miniBatchFraction = fraction
-    this
-  }
-
   /**
    * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
    * Smaller value will lead to higher accuracy with the cost of more iterations.
@@ -110,15 +101,14 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
   }
 
   override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
-    val (weights, _) = LBFGS.runMiniBatchLBFGS(
+    val (weights, _) = LBFGS.runLBFGS(
       data,
       gradient,
       updater,
       numCorrections,
       convergenceTol,
       maxNumIterations,
       regParam,
-      miniBatchFraction,
       initialWeights)
     weights
   }
@@ -132,10 +122,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
 @DeveloperApi
 object LBFGS extends Logging {
   /**
-   * Run Limited-memory BFGS (L-BFGS) in parallel using mini batches.
-   * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data
-   * in order to compute a gradient estimate.
-   * Sampling, and averaging the subgradients over this subset is performed using one standard
+   * Run Limited-memory BFGS (L-BFGS) in parallel.
+   * Averaging the subgradients over different partitions is performed using one standard
    * spark map-reduce in each iteration.
    *
    * @param data - Input data for L-BFGS. RDD of the set of data examples, each of
@@ -147,38 +135,46 @@ object LBFGS extends Logging {
    * @param convergenceTol - The convergence tolerance of iterations for L-BFGS
    * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run.
    * @param regParam - Regularization parameter
-   * @param miniBatchFraction - Fraction of the input data set that should be used for
-   *                            one iteration of L-BFGS. Default value 1.0.
    *
    * @return A tuple containing two elements. The first element is a column matrix containing
    *         weights for every feature, and the second element is an array containing the loss
    *         computed for every iteration.
    */
-  def runMiniBatchLBFGS(
+  def runLBFGS(
       data: RDD[(Double, Vector)],
       gradient: Gradient,
       updater: Updater,
       numCorrections: Int,
       convergenceTol: Double,
       maxNumIterations: Int,
       regParam: Double,
-      miniBatchFraction: Double,
       initialWeights: Vector): (Vector, Array[Double]) = {
 
     val lossHistory = new ArrayBuffer[Double](maxNumIterations)
 
     val numExamples = data.count()
-    val miniBatchSize = numExamples * miniBatchFraction
 
     val costFun =
-      new CostFun(data, gradient, updater, regParam, miniBatchFraction, lossHistory, miniBatchSize)
+      new CostFun(data, gradient, updater, regParam, numExamples)
 
     val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
 
-    val weights = Vectors.fromBreeze(
-      lbfgs.minimize(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector))
+    val states =
+      lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
+
+    /**
+     * NOTE: lossSum and loss are computed using the weights from the previous iteration,
+     * and regVal is the regularization value computed in the previous iteration as well.
+     */
+    var state = states.next()
+    while (states.hasNext) {
+      lossHistory.append(state.value)
+      state = states.next()
+    }
+    lossHistory.append(state.value)
+    val weights = Vectors.fromBreeze(state.x)
 
-    logInfo("LBFGS.runMiniBatchSGD finished. Last 10 losses %s".format(
+    logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
       lossHistory.takeRight(10).mkString(", ")))
 
     (weights, lossHistory.toArray)
@@ -193,9 +189,7 @@ object LBFGS extends Logging {
       gradient: Gradient,
       updater: Updater,
       regParam: Double,
-      miniBatchFraction: Double,
-      lossHistory: ArrayBuffer[Double],
-      miniBatchSize: Double) extends DiffFunction[BDV[Double]] {
+      numExamples: Long) extends DiffFunction[BDV[Double]] {
 
     private var i = 0
 
@@ -204,8 +198,7 @@ object LBFGS extends Logging {
       val localData = data
       val localGradient = gradient
 
-      val (gradientSum, lossSum) = localData.sample(false, miniBatchFraction, 42 + i)
-        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+      val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
             val l = localGradient.compute(
               features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
@@ -223,7 +216,7 @@ object LBFGS extends Logging {
         Vectors.fromBreeze(weights),
         Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
 
-      val loss = lossSum / miniBatchSize + regVal
+      val loss = lossSum / numExamples + regVal
       /**
        * It will return the gradient part of regularization using updater.
        *
@@ -245,14 +238,8 @@ object LBFGS extends Logging {
         Vectors.fromBreeze(weights),
         Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze
 
-      // gradientTotal = gradientSum / miniBatchSize + gradientTotal
-      axpy(1.0 / miniBatchSize, gradientSum, gradientTotal)
-
-      /**
-       * NOTE: lossSum and loss is computed using the weights from the previous iteration
-       * and regVal is the regularization value computed in the previous iteration as well.
-       */
-      lossHistory.append(loss)
+      // gradientTotal = gradientSum / numExamples + gradientTotal
+      axpy(1.0 / numExamples, gradientSum, gradientTotal)
 
       i += 1
 
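For reference, a minimal driver-side sketch of how the simplified entry point could be called after this patch. The LBFGS.runLBFGS signature matches the diff above; the SparkContext setup, the toy dataset, the LogisticGradient/SquaredL2Updater choice, and the parameter values are illustrative assumptions, not part of the change.

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}

    // Hypothetical example: binary labels with two features per example.
    val sc = new SparkContext("local", "lbfgs-example")
    val data = sc.parallelize(Seq(
      (1.0, Vectors.dense(1.2, 0.4)),
      (0.0, Vectors.dense(-0.7, 1.1)),
      (1.0, Vectors.dense(0.9, -0.3))))

    // After this change the full dataset is used in every iteration,
    // so no miniBatchFraction argument is passed.
    val (weights, lossHistory) = LBFGS.runLBFGS(
      data,
      new LogisticGradient(),
      new SquaredL2Updater(),
      10,     // numCorrections
      1e-4,   // convergenceTol
      100,    // maxNumIterations
      0.1,    // regParam
      Vectors.dense(0.0, 0.0))

    // lossHistory now comes from Breeze's iteration states on the driver,
    // one entry per L-BFGS iteration, rather than from inside CostFun.
    println(s"Final weights: $weights, iterations: ${lossHistory.length}")
    sc.stop()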