
Commit 1d85289

Author: DB Tsai
Improve the convergence rate by minimizing the condition number in LOR with LBFGS
Parent: bad21ed

File tree: 3 files changed, +121 −4 lines

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 3 additions & 1 deletion
@@ -62,7 +62,7 @@ class LogisticRegressionModel (
   override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
       intercept: Double) = {
     val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
-    val score = 1.0/ (1.0 + math.exp(-margin))
+    val score = 1.0 / (1.0 + math.exp(-margin))
     threshold match {
       case Some(t) => if (score < t) 0.0 else 1.0
       case None => score
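
The one-character change above only normalizes spacing; the prediction rule itself is unchanged. As a minimal standalone sketch of that rule in plain Scala (no Spark; the object and parameter names here are illustrative, not MLlib API):

    // Logistic prediction: sigmoid of the linear margin, optionally thresholded.
    object LogisticPredictSketch {
      def predict(features: Array[Double], weights: Array[Double],
          intercept: Double, threshold: Option[Double]): Double = {
        val margin = features.zip(weights).map { case (x, w) => x * w }.sum + intercept
        val score = 1.0 / (1.0 + math.exp(-margin))   // logistic (sigmoid) of the margin
        threshold match {
          case Some(t) => if (score < t) 0.0 else 1.0 // hard 0/1 label when a threshold is set
          case None    => score                       // raw probability otherwise
        }
      }
    }

For example, predict(Array(1.0, 2.0), Array(0.5, -0.25), 0.1, None) yields the probability 1.0 / (1.0 + math.exp(-0.1)).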
@@ -204,6 +204,8 @@ class LogisticRegressionWithLBFGS private (
    */
   def this() = this(1E-4, 100, 0.0)
 
+  this.setFeatureScaling(true)
+
   private val gradient = new LogisticGradient()
   private val updater = new SimpleUpdater()
   // Have to return new LBFGS object every time since users can reset the parameters anytime.
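
A hedged usage sketch: after this change, feature scaling is on by default for LogisticRegressionWithLBFGS, so callers get the better-conditioned training path with no extra configuration. The helper name trainScaled and the training RDD are assumptions for illustration; the imports and calls are the existing MLlib API:

    import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    // Training happens in the scaled space internally; the returned weights are
    // already translated back to the original feature scale.
    def trainScaled(training: RDD[LabeledPoint]): LogisticRegressionModel =
      new LogisticRegressionWithLBFGS().setIntercept(true).run(training)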

mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 66 additions & 3 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.regression
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.feature.StandardScaler
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.optimization._
@@ -94,6 +95,22 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
 
   protected var validateData: Boolean = true
 
+  /**
+   * Whether to perform feature scaling before model training to reduce the condition number,
+   * which can significantly help the optimizer converge faster. The scaling correction is
+   * translated back into the resulting model weights, so it is transparent to users.
+   * Note: This technique is used in both the libsvm and glmnet packages. Default false.
+   */
+  private var useFeatureScaling = false
+
+  /**
+   * Set if the algorithm should use feature scaling to improve the convergence during
+   * optimization.
+   */
+  private[mllib] def setFeatureScaling(useFeatureScaling: Boolean): this.type = {
+    this.useFeatureScaling = useFeatureScaling
+    this
+  }
+
   /**
    * Create a model given the weights and intercept
    */
@@ -137,11 +154,45 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
       throw new SparkException("Input validation failed.")
     }
 
+    /**
+     * Scaling to minimize the condition number:
+     *
+     * During the optimization process, the convergence (rate) depends on the condition number
+     * of the training dataset. Scaling the variables often reduces this condition number,
+     * thus improving the convergence rate dramatically. Without reducing the condition number,
+     * some training datasets mixing columns of different scales may not converge at all.
+     *
+     * The GLMNET and LIBSVM packages perform this scaling to reduce the condition number, and
+     * return the weights in the original scale.
+     * See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
+     *
+     * Here, if useFeatureScaling is enabled, we standardize the training features by dividing
+     * each column by its standard deviation (without subtracting the mean), and train the
+     * model in the scaled space. Then we transform the coefficients from the scaled space
+     * back to the original scale, as GLMNET and LIBSVM do.
+     *
+     * Currently, it's only enabled in LogisticRegressionWithLBFGS.
+     */
+    val scaler = if (useFeatureScaling) {
+      (new StandardScaler).fit(input.map(x => x.features))
+    } else {
+      null
+    }
+
     // Prepend an extra variable consisting of all 1.0's for the intercept.
     val data = if (addIntercept) {
-      input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
+      if (useFeatureScaling) {
+        input.map(labeledPoint =>
+          (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features))))
+      } else {
+        input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
+      }
     } else {
-      input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
+      if (useFeatureScaling) {
+        input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features)))
+      } else {
+        input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
+      }
     }
 
     val initialWeightsWithIntercept = if (addIntercept) {
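
A minimal sketch of the scaling step described in the comment above, written over plain in-memory arrays rather than RDDs (StandardScaler does this distributedly; this local version and its name are illustrative only):

    // Divide each column by its sample standard deviation; the mean is NOT subtracted,
    // so sparsity and the intercept are preserved.
    def scaleColumns(rows: Array[Array[Double]]): Array[Array[Double]] = {
      val n = rows.length.toDouble
      val dims = rows.head.length
      val mean = Array.tabulate(dims)(j => rows.map(_(j)).sum / n)
      val std = Array.tabulate(dims) { j =>
        math.sqrt(rows.map(r => math.pow(r(j) - mean(j), 2)).sum / (n - 1))
      }
      rows.map(_.zipWithIndex.map { case (v, j) => if (std(j) == 0.0) v else v / std(j) })
    }

Columns with wildly different magnitudes end up with comparable ranges, which is what shrinks the condition number of the optimization problem.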
@@ -153,13 +204,25 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
 
     val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
 
     val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0
-    val weights =
+    var weights =
       if (addIntercept) {
         Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
       } else {
         weightsWithIntercept
       }
 
+    /**
+     * The weights and intercept are trained in the scaled space; we convert them back to the
+     * original scale here.
+     *
+     * The math shows that if we only perform standardization without subtracting the means,
+     * the intercept does not change: w_i = w_i' / v_i, where w_i' is the coefficient in the
+     * scaled space, w_i is the coefficient in the original space, and v_i is the standard
+     * deviation of column i.
+     */
+    if (useFeatureScaling) {
+      weights = scaler.transform(weights)
+    }
+
     createModel(weights, intercept)
   }
 }
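
The back-transform in the hunk above is pure algebra, so it can be checked numerically in isolation. A small plain-Scala check (all values invented for the example): if x' = x / sigma and w' is trained against x', then w = w' / sigma yields the same margin on the original x, which is why the intercept needs no correction:

    object BackTransformCheck extends App {
      val x       = Array(3.0, 50.0, 0.2)   // a point in the original scale
      val sigma   = Array(1.5, 25.0, 0.1)   // per-column standard deviations
      val wScaled = Array(0.4, -1.2, 2.0)   // weights learned in the scaled space
      val b       = 0.7                     // intercept, shared by both spaces

      val xScaled = x.zip(sigma).map { case (v, s) => v / s }
      val wOrig   = wScaled.zip(sigma).map { case (w, s) => w / s }

      def margin(f: Array[Double], w: Array[Double]) =
        f.zip(w).map { case (a, c) => a * c }.sum + b

      // Identical margins => identical predictions in both spaces.
      assert(math.abs(margin(xScaled, wScaled) - margin(x, wOrig)) < 1e-12)
      println("scaled-space and original-space margins agree")
    }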

mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala

Lines changed: 52 additions & 0 deletions
@@ -185,6 +185,58 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
     // Test prediction on Array.
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
+
+  test("numerical stability of scaling features using logistic regression with LBFGS") {
+    /**
+     * If we rescale the features, the condition number changes, so the convergence rate
+     * changes as well, and the solution will not equal the original solution multiplied by
+     * the scaling factor, which it should.
+     *
+     * However, since LogisticRegressionWithLBFGS standardizes the training dataset first, no
+     * matter what scaling factor we multiply into the dataset, the convergence rate should be
+     * the same, and the solution should equal the original solution multiplied by the scaling
+     * factor.
+     */
+
+    val nPoints = 10000
+    val A = 2.0
+    val B = -1.5
+
+    val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
+
+    val initialWeights = Vectors.dense(0.0)
+
+    val testRDD1 = sc.parallelize(testData, 2)
+
+    val testRDD2 = sc.parallelize(
+      testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E3))), 2)
+
+    val testRDD3 = sc.parallelize(
+      testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E6))), 2)
+
+    testRDD1.cache()
+    testRDD2.cache()
+    testRDD3.cache()
+
+    val lrA = new LogisticRegressionWithLBFGS().setIntercept(true)
+    val lrB = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false)
+
+    val modelA1 = lrA.run(testRDD1, initialWeights)
+    val modelA2 = lrA.run(testRDD2, initialWeights)
+    val modelA3 = lrA.run(testRDD3, initialWeights)
+
+    val modelB1 = lrB.run(testRDD1, initialWeights)
+    val modelB2 = lrB.run(testRDD2, initialWeights)
+    val modelB3 = lrB.run(testRDD3, initialWeights)
+
+    // Test the weights
+    assert(modelA1.weights(0) ~== modelA2.weights(0) * 1.0E3 absTol 0.01)
+    assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01)
+
+    assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1)
+    assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1)
+  }
+
 }
 
 class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
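
The property the new test asserts is scale equivariance: multiplying every feature by a constant c should divide the learned weight by exactly c. A toy plain-Scala illustration of that identity, using a closed-form one-dimensional least-squares slope as a stand-in "trainer" (not the LBFGS path; everything here is invented for illustration, and the point is only the w -> w / c relationship the assertions above check numerically):

    object ScaleInvarianceSketch extends App {
      val xs = Array(1.0, 2.0, 3.0, 4.0)
      val ys = Array(2.1, 3.9, 6.2, 7.8)

      // Closed-form slope of y = w * x (no intercept): sum(x*y) / sum(x*x).
      def fit(x: Array[Double]): Double =
        x.zip(ys).map { case (a, b) => a * b }.sum / x.map(a => a * a).sum

      val w1 = fit(xs)                 // original scale
      val w2 = fit(xs.map(_ * 1.0E3))  // features multiplied by 1e3

      // The weight shrinks by exactly the scaling factor, mirroring
      // modelA1.weights(0) ~== modelA2.weights(0) * 1.0E3 in the test above.
      assert(math.abs(w1 - w2 * 1.0E3) < 1e-9)
      println(s"w1 = $w1, w2 * 1e3 = ${w2 * 1.0E3}")
    }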
