Commit 5d25c0b

[SPARK-3078][MLLIB] Make LRWithLBFGS API consistent with others
Should ask users to set parameters through the optimizer.

dbtsai

Author: Xiangrui Meng <[email protected]>

Closes #1973 from mengxr/lr-lbfgs and squashes the following commits:

e3efbb1 [Xiangrui Meng] fix tests
21b3579 [Xiangrui Meng] fix method name
641eea4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into lr-lbfgs
456ab7c [Xiangrui Meng] update LRWithLBFGS
1 parent cc36487 commit 5d25c0b
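
In practice this change means the iteration count, updater, and regularization for LogisticRegressionWithLBFGS are now configured on the exposed optimizer member rather than through setters on the algorithm itself, matching the other linear models. A minimal sketch of the new calling pattern, assuming a prepared training set; the helper name trainLR and the parameter values are illustrative, not defaults:

import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.optimization.SquaredL2Updater
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Hypothetical helper: train a logistic regression model with L-BFGS,
// setting all knobs through the optimizer, as in the updated example app below.
def trainLR(training: RDD[LabeledPoint]) = {
  val lr = new LogisticRegressionWithLBFGS()
  lr.optimizer
    .setNumIterations(100)
    .setRegParam(0.01)
    .setUpdater(new SquaredL2Updater)
  lr.run(training)
}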

File tree

5 files changed (+33, -53 lines)

examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala

Lines changed: 4 additions & 4 deletions
@@ -21,7 +21,7 @@ import org.apache.log4j.{Level, Logger}
 import scopt.OptionParser
 
 import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD}
+import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD}
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater}
@@ -66,7 +66,8 @@ object BinaryClassification {
         .text("number of iterations")
         .action((x, c) => c.copy(numIterations = x))
       opt[Double]("stepSize")
-        .text(s"initial step size, default: ${defaultParams.stepSize}")
+        .text("initial step size (ignored by logistic regression), " +
+          s"default: ${defaultParams.stepSize}")
         .action((x, c) => c.copy(stepSize = x))
       opt[String]("algorithm")
         .text(s"algorithm (${Algorithm.values.mkString(",")}), " +
@@ -125,10 +126,9 @@ object BinaryClassification {
 
     val model = params.algorithm match {
       case LR =>
-        val algorithm = new LogisticRegressionWithSGD()
+        val algorithm = new LogisticRegressionWithLBFGS()
         algorithm.optimizer
           .setNumIterations(params.numIterations)
-          .setStepSize(params.stepSize)
           .setUpdater(updater)
           .setRegParam(params.regParam)
        algorithm.run(training).clearThreshold()

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 5 additions & 35 deletions
@@ -73,6 +73,8 @@ class LogisticRegressionModel (
 /**
  * Train a classification model for Logistic Regression using Stochastic Gradient Descent.
  * NOTE: Labels used in Logistic Regression should be {0, 1}
+ *
+ * Using [[LogisticRegressionWithLBFGS]] is recommended over this.
  */
 class LogisticRegressionWithSGD private (
     private var stepSize: Double,
@@ -191,51 +193,19 @@ object LogisticRegressionWithSGD {
 
 /**
  * Train a classification model for Logistic Regression using Limited-memory BFGS.
+ * Standard feature scaling and L2 regularization are used by default.
  * NOTE: Labels used in Logistic Regression should be {0, 1}
  */
-class LogisticRegressionWithLBFGS private (
-    private var convergenceTol: Double,
-    private var maxNumIterations: Int,
-    private var regParam: Double)
+class LogisticRegressionWithLBFGS
   extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
 
-  /**
-   * Construct a LogisticRegression object with default parameters
-   */
-  def this() = this(1E-4, 100, 0.0)
-
   this.setFeatureScaling(true)
 
-  private val gradient = new LogisticGradient()
-  private val updater = new SimpleUpdater()
-  // Have to return new LBFGS object every time since users can reset the parameters anytime.
-  override def optimizer = new LBFGS(gradient, updater)
-    .setNumCorrections(10)
-    .setConvergenceTol(convergenceTol)
-    .setMaxNumIterations(maxNumIterations)
-    .setRegParam(regParam)
+  override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)
 
   override protected val validators = List(DataValidators.binaryLabelValidator)
 
-  /**
-   * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
-   * Smaller value will lead to higher accuracy with the cost of more iterations.
-   */
-  def setConvergenceTol(convergenceTol: Double): this.type = {
-    this.convergenceTol = convergenceTol
-    this
-  }
-
-  /**
-   * Set the maximal number of iterations for L-BFGS. Default 100.
-   */
-  def setNumIterations(numIterations: Int): this.type = {
-    this.maxNumIterations = numIterations
-    this
-  }
-
  override protected def createModel(weights: Vector, intercept: Double) = {
    new LogisticRegressionModel(weights, intercept)
  }
-
 }
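
A behavioral note on this diff: the old class rebuilt its LBFGS instance on every optimizer access and paired it with SimpleUpdater (no regularization), while the new class exposes one mutable LBFGS built with SquaredL2Updater, so existing callers migrate by moving their setter calls onto that object. A hedged before/after sketch, with training as a placeholder RDD[LabeledPoint] and illustrative values:

// Before this commit, setters lived on the algorithm itself (no longer compiles):
//   val model = new LogisticRegressionWithLBFGS()
//     .setConvergenceTol(1e-4)
//     .setNumIterations(100)
//     .run(training)

// After this commit, the same knobs are set on the shared LBFGS optimizer:
val lr = new LogisticRegressionWithLBFGS()
lr.optimizer
  .setConvergenceTol(1e-4)
  .setNumIterations(100)
val model = lr.run(training)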

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 9 additions & 0 deletions
@@ -69,8 +69,17 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
 
   /**
    * Set the maximal number of iterations for L-BFGS. Default 100.
+   * @deprecated use [[LBFGS#setNumIterations]] instead
    */
+  @deprecated("use setNumIterations instead", "1.1.0")
   def setMaxNumIterations(iters: Int): this.type = {
+    this.setNumIterations(iters)
+  }
+
+  /**
+   * Set the maximal number of iterations for L-BFGS. Default 100.
+   */
+  def setNumIterations(iters: Int): this.type = {
     this.maxNumIterations = iters
     this
   }
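
The old setter name is kept as a deprecated forwarder so existing code keeps compiling (with a warning), while the new setNumIterations aligns L-BFGS with the naming used by the other optimizers. A small usage sketch with illustrative values:

import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}

val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)
  .setNumCorrections(10)
  .setConvergenceTol(1e-6)
  .setNumIterations(50) // preferred name after this commit
  .setRegParam(0.1)
// optimizer.setMaxNumIterations(50) still compiles but emits a deprecation warning.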

mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala

Lines changed: 3 additions & 2 deletions
@@ -272,8 +272,9 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont
     }.cache()
     // If we serialize data directly in the task closure, the size of the serialized task would be
     // greater than 1MB and hence Spark would throw an error.
-    val model =
-      (new LogisticRegressionWithLBFGS().setIntercept(true).setNumIterations(2)).run(points)
+    val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
+    lr.optimizer.setNumIterations(2)
+    val model = lr.run(points)
 
     val predictions = model.predict(points.map(_.features))
 
mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala

Lines changed: 12 additions & 12 deletions
@@ -55,15 +55,15 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
 
     val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray)
     val convergenceTol = 1e-12
-    val maxNumIterations = 10
+    val numIterations = 10
 
     val (_, loss) = LBFGS.runLBFGS(
       dataRDD,
       gradient,
       simpleUpdater,
       numCorrections,
       convergenceTol,
-      maxNumIterations,
+      numIterations,
       regParam,
       initialWeightsWithIntercept)
 
@@ -99,15 +99,15 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
     // Prepare another non-zero weights to compare the loss in the first iteration.
     val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12)
     val convergenceTol = 1e-12
-    val maxNumIterations = 10
+    val numIterations = 10
 
     val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS(
       dataRDD,
       gradient,
       squaredL2Updater,
       numCorrections,
       convergenceTol,
-      maxNumIterations,
+      numIterations,
       regParam,
       initialWeightsWithIntercept)
 
@@ -140,10 +140,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
 
     /**
     * For the first run, we set the convergenceTol to 0.0, so that the algorithm will
-    * run up to the maxNumIterations which is 8 here.
+    * run up to the numIterations which is 8 here.
     */
     val initialWeightsWithIntercept = Vectors.dense(0.0, 0.0)
-    val maxNumIterations = 8
+    val numIterations = 8
     var convergenceTol = 0.0
 
     val (_, lossLBFGS1) = LBFGS.runLBFGS(
@@ -152,7 +152,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
       squaredL2Updater,
       numCorrections,
       convergenceTol,
-      maxNumIterations,
+      numIterations,
       regParam,
       initialWeightsWithIntercept)
 
@@ -167,7 +167,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
       squaredL2Updater,
       numCorrections,
       convergenceTol,
-      maxNumIterations,
+      numIterations,
       regParam,
       initialWeightsWithIntercept)
 
@@ -182,7 +182,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
       squaredL2Updater,
       numCorrections,
       convergenceTol,
-      maxNumIterations,
+      numIterations,
       regParam,
       initialWeightsWithIntercept)
 
@@ -200,12 +200,12 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
     // Prepare another non-zero weights to compare the loss in the first iteration.
     val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12)
     val convergenceTol = 1e-12
-    val maxNumIterations = 10
+    val numIterations = 10
 
     val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater)
       .setNumCorrections(numCorrections)
       .setConvergenceTol(convergenceTol)
-      .setMaxNumIterations(maxNumIterations)
+      .setNumIterations(numIterations)
       .setRegParam(regParam)
 
     val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept)
@@ -241,7 +241,7 @@ class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext {
     val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater)
       .setNumCorrections(1)
       .setConvergenceTol(1e-12)
-      .setMaxNumIterations(1)
+      .setNumIterations(1)
       .setRegParam(1.0)
     val random = new Random(0)
     // If we serialize data directly in the task closure, the size of the serialized task would be
