
Commit 3452997

DB Tsai authored and pwendell committed
[SPARK-1157][MLlib] Bug fix: lossHistory should exclude rejection steps, and remove miniBatch
Get the lossHistory from Breeze's API, which already excludes the rejected steps in the line search. Also remove the mini-batch option from LBFGS: quasi-Newton methods approximate the inverse Hessian from successive gradients, so it does not make sense for those gradients to be computed from a varying (sampled) objective.

Author: DB Tsai <[email protected]>

Closes #582 from dbtsai/dbtsai-lbfgs-bug and squashes the following commits:

9cc6cf9 [DB Tsai] Removed the miniBatch in LBFGS.
1ba6a33 [DB Tsai] Formatting the code.
d72c679 [DB Tsai] Using Breeze's states to get the loss.

(cherry picked from commit 910a13b)
Signed-off-by: Patrick Wendell <[email protected]>
1 parent d81d626 commit 3452997
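
The core of the fix is to drive Breeze's L-BFGS through its iterations API and record one loss per returned state, instead of logging inside the objective, where evaluations rejected by the line search also show up. Below is a minimal, self-contained Scala sketch of that pattern against a toy objective; the object name LossHistorySketch, the toy quadratic, and the concrete settings (10, 4, 1e-9) are illustrative assumptions, not part of this commit:

import scala.collection.mutable.ArrayBuffer

import breeze.linalg.DenseVector
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}

object LossHistorySketch {
  def main(args: Array[String]): Unit = {
    // Toy deterministic objective: f(w) = ||w - 3||^2 with gradient 2 * (w - 3).
    val costFun = new DiffFunction[DenseVector[Double]] {
      def calculate(w: DenseVector[Double]): (Double, DenseVector[Double]) = {
        val diff = w - 3.0
        (diff.dot(diff), diff * 2.0)
      }
    }

    // Arguments are (maxIter, numCorrections, convergenceTol), as in the Spark code below.
    val lbfgs = new BreezeLBFGS[DenseVector[Double]](10, 4, 1e-9)

    // Walking the optimizer's states yields exactly one value per accepted iteration;
    // objective evaluations rejected by the line search never appear here.
    val lossHistory = new ArrayBuffer[Double]
    val states = lbfgs.iterations(new CachedDiffFunction(costFun), DenseVector.zeros[Double](2))
    var state = states.next()
    while (states.hasNext) {
      lossHistory += state.value
      state = states.next()
    }
    lossHistory += state.value

    println(s"weights = ${state.x}, losses = ${lossHistory.mkString(", ")}")
  }
}

Collecting the losses this way also removes the need for the cost function to carry a loss buffer of its own, which is exactly the CostFun signature change in the diff below.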

File tree

2 files changed: +30 / -48 lines


mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 25 additions & 38 deletions
@@ -42,7 +42,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
   private var convergenceTol = 1E-4
   private var maxNumIterations = 100
   private var regParam = 0.0
-  private var miniBatchFraction = 1.0

   /**
    * Set the number of corrections used in the LBFGS update. Default 10.
@@ -57,14 +56,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
     this
   }

-  /**
-   * Set fraction of data to be used for each L-BFGS iteration. Default 1.0.
-   */
-  def setMiniBatchFraction(fraction: Double): this.type = {
-    this.miniBatchFraction = fraction
-    this
-  }
-
   /**
    * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
    * Smaller value will lead to higher accuracy with the cost of more iterations.
@@ -110,15 +101,14 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
   }

   override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
-    val (weights, _) = LBFGS.runMiniBatchLBFGS(
+    val (weights, _) = LBFGS.runLBFGS(
       data,
       gradient,
       updater,
       numCorrections,
       convergenceTol,
       maxNumIterations,
       regParam,
-      miniBatchFraction,
       initialWeights)
     weights
   }
@@ -132,10 +122,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
 @DeveloperApi
 object LBFGS extends Logging {
   /**
-   * Run Limited-memory BFGS (L-BFGS) in parallel using mini batches.
-   * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data
-   * in order to compute a gradient estimate.
-   * Sampling, and averaging the subgradients over this subset is performed using one standard
+   * Run Limited-memory BFGS (L-BFGS) in parallel.
+   * Averaging the subgradients over different partitions is performed using one standard
    * spark map-reduce in each iteration.
    *
    * @param data - Input data for L-BFGS. RDD of the set of data examples, each of
@@ -147,38 +135,46 @@ object LBFGS extends Logging {
   * @param convergenceTol - The convergence tolerance of iterations for L-BFGS
   * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run.
   * @param regParam - Regularization parameter
-  * @param miniBatchFraction - Fraction of the input data set that should be used for
-  *                            one iteration of L-BFGS. Default value 1.0.
   *
   * @return A tuple containing two elements. The first element is a column matrix containing
   *         weights for every feature, and the second element is an array containing the loss
   *         computed for every iteration.
   */
-  def runMiniBatchLBFGS(
+  def runLBFGS(
      data: RDD[(Double, Vector)],
      gradient: Gradient,
      updater: Updater,
      numCorrections: Int,
      convergenceTol: Double,
      maxNumIterations: Int,
      regParam: Double,
-     miniBatchFraction: Double,
      initialWeights: Vector): (Vector, Array[Double]) = {

    val lossHistory = new ArrayBuffer[Double](maxNumIterations)

    val numExamples = data.count()
-   val miniBatchSize = numExamples * miniBatchFraction

    val costFun =
-     new CostFun(data, gradient, updater, regParam, miniBatchFraction, lossHistory, miniBatchSize)
+     new CostFun(data, gradient, updater, regParam, numExamples)

    val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)

-   val weights = Vectors.fromBreeze(
-     lbfgs.minimize(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector))
+   val states =
+     lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
+
+   /**
+    * NOTE: lossSum and loss is computed using the weights from the previous iteration
+    * and regVal is the regularization value computed in the previous iteration as well.
+    */
+   var state = states.next()
+   while(states.hasNext) {
+     lossHistory.append(state.value)
+     state = states.next()
+   }
+   lossHistory.append(state.value)
+   val weights = Vectors.fromBreeze(state.x)

-   logInfo("LBFGS.runMiniBatchSGD finished. Last 10 losses %s".format(
+   logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
      lossHistory.takeRight(10).mkString(", ")))

    (weights, lossHistory.toArray)
@@ -193,9 +189,7 @@ object LBFGS extends Logging {
      gradient: Gradient,
      updater: Updater,
      regParam: Double,
-     miniBatchFraction: Double,
-     lossHistory: ArrayBuffer[Double],
-     miniBatchSize: Double) extends DiffFunction[BDV[Double]] {
+     numExamples: Long) extends DiffFunction[BDV[Double]] {

    private var i = 0

@@ -204,8 +198,7 @@ object LBFGS extends Logging {
      val localData = data
      val localGradient = gradient

-     val (gradientSum, lossSum) = localData.sample(false, miniBatchFraction, 42 + i)
-       .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+     val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
        seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
          val l = localGradient.compute(
            features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
@@ -223,7 +216,7 @@
        Vectors.fromBreeze(weights),
        Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2

-     val loss = lossSum / miniBatchSize + regVal
+     val loss = lossSum / numExamples + regVal
      /**
       * It will return the gradient part of regularization using updater.
       *
@@ -245,14 +238,8 @@
        Vectors.fromBreeze(weights),
        Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze

-     // gradientTotal = gradientSum / miniBatchSize + gradientTotal
-     axpy(1.0 / miniBatchSize, gradientSum, gradientTotal)
-
-     /**
-      * NOTE: lossSum and loss is computed using the weights from the previous iteration
-      * and regVal is the regularization value computed in the previous iteration as well.
-      */
-     lossHistory.append(loss)
+     // gradientTotal = gradientSum / numExamples + gradientTotal
+     axpy(1.0 / numExamples, gradientSum, gradientTotal)

      i += 1

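The mini-batch removal follows from how quasi-Newton methods work: the inverse-Hessian approximation is built from differences of successive gradients, and the line search compares objective values across evaluations, so every evaluation must come from the same deterministic objective; a freshly sampled mini batch per evaluation breaks both. A minimal standalone sketch of the deterministic full-data loss and gradient that CostFun now computes with a single aggregate (the object name FullDataObjectiveSketch, the local master, and the toy least-squares data are assumptions for illustration, not part of this commit):

import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.{SparkConf, SparkContext}

object FullDataObjectiveSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))

    // Toy least-squares data as (label, features) pairs.
    val data = sc.parallelize(Seq((1.0, BDV(1.0, 0.0)), (2.0, BDV(0.0, 1.0))))
    val numExamples = data.count()
    val weights = BDV(0.5, 0.5)

    // One pass over the *full* dataset: seqOp accumulates per-example gradient and loss,
    // combOp merges the partial sums from different partitions. Every evaluation of this
    // objective sees the same data, so successive gradients are comparable.
    val (gradSum, lossSum) = data.aggregate((BDV.zeros[Double](weights.length), 0.0))(
      seqOp = { case ((grad, loss), (label, features)) =>
        val err = features.dot(weights) - label
        (grad + features * err, loss + 0.5 * err * err)
      },
      combOp = { case ((g1, l1), (g2, l2)) => (g1 + g2, l1 + l2) })

    println(s"loss = ${lossSum / numExamples}, gradient = ${gradSum / numExamples.toDouble}")
    sc.stop()
  }
}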
mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala

Lines changed: 5 additions & 10 deletions
@@ -59,15 +59,14 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
    val convergenceTol = 1e-12
    val maxNumIterations = 10

-   val (_, loss) = LBFGS.runMiniBatchLBFGS(
+   val (_, loss) = LBFGS.runLBFGS(
      dataRDD,
      gradient,
      simpleUpdater,
      numCorrections,
      convergenceTol,
      maxNumIterations,
      regParam,
-     miniBatchFrac,
      initialWeightsWithIntercept)

    // Since the cost function is convex, the loss is guaranteed to be monotonically decreasing
@@ -104,15 +103,14 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
    val convergenceTol = 1e-12
    val maxNumIterations = 10

-   val (weightLBFGS, lossLBFGS) = LBFGS.runMiniBatchLBFGS(
+   val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS(
      dataRDD,
      gradient,
      squaredL2Updater,
      numCorrections,
      convergenceTol,
      maxNumIterations,
      regParam,
-     miniBatchFrac,
      initialWeightsWithIntercept)

    val numGDIterations = 50
@@ -150,47 +148,44 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
    val maxNumIterations = 8
    var convergenceTol = 0.0

-   val (_, lossLBFGS1) = LBFGS.runMiniBatchLBFGS(
+   val (_, lossLBFGS1) = LBFGS.runLBFGS(
      dataRDD,
      gradient,
      squaredL2Updater,
      numCorrections,
      convergenceTol,
      maxNumIterations,
      regParam,
-     miniBatchFrac,
      initialWeightsWithIntercept)

    // Note that the first loss is computed with initial weights,
    // so the total numbers of loss will be numbers of iterations + 1
    assert(lossLBFGS1.length == 9)

    convergenceTol = 0.1
-   val (_, lossLBFGS2) = LBFGS.runMiniBatchLBFGS(
+   val (_, lossLBFGS2) = LBFGS.runLBFGS(
      dataRDD,
      gradient,
      squaredL2Updater,
      numCorrections,
      convergenceTol,
      maxNumIterations,
      regParam,
-     miniBatchFrac,
      initialWeightsWithIntercept)

    // Based on observation, lossLBFGS2 runs 3 iterations, no theoretically guaranteed.
    assert(lossLBFGS2.length == 4)
    assert((lossLBFGS2(2) - lossLBFGS2(3)) / lossLBFGS2(2) < convergenceTol)

    convergenceTol = 0.01
-   val (_, lossLBFGS3) = LBFGS.runMiniBatchLBFGS(
+   val (_, lossLBFGS3) = LBFGS.runLBFGS(
      dataRDD,
      gradient,
      squaredL2Updater,
      numCorrections,
      convergenceTol,
      maxNumIterations,
      regParam,
-     miniBatchFrac,
      initialWeightsWithIntercept)

    // With smaller convergenceTol, it takes more steps.
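
As the suite's comment above notes, the cost function is convex, so the returned loss history should decrease monotonically; with the history now taken from Breeze's accepted states, that holds without any filtering of rejected steps. A hedged sketch of such a check, reusing the suite's values above (the assertion itself is illustrative, not part of this commit):

    val (_, losses) = LBFGS.runLBFGS(
      dataRDD,
      gradient,
      squaredL2Updater,
      numCorrections,
      convergenceTol,
      maxNumIterations,
      regParam,
      initialWeightsWithIntercept)
    // Each accepted L-BFGS step should not increase the loss on a convex objective.
    assert(losses.zip(losses.drop(1)).forall { case (prev, next) => prev >= next })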
