Skip to content

Commit 456ab7c

Browse files
committed
update LRWithLBFGS
1 parent a75bc7a commit 456ab7c

File tree

3 files changed

+18
-39
lines changed

3 files changed

+18
-39
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.log4j.{Level, Logger}
2121
import scopt.OptionParser
2222

2323
import org.apache.spark.{SparkConf, SparkContext}
24-
import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD}
24+
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD}
2525
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
2626
import org.apache.spark.mllib.util.MLUtils
2727
import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater}
@@ -66,7 +66,8 @@ object BinaryClassification {
6666
.text("number of iterations")
6767
.action((x, c) => c.copy(numIterations = x))
6868
opt[Double]("stepSize")
69-
.text(s"initial step size, default: ${defaultParams.stepSize}")
69+
.text("initial step size (ignored by logistic regression), " +
70+
s"default: ${defaultParams.stepSize}")
7071
.action((x, c) => c.copy(stepSize = x))
7172
opt[String]("algorithm")
7273
.text(s"algorithm (${Algorithm.values.mkString(",")}), " +
@@ -125,10 +126,9 @@ object BinaryClassification {
125126

126127
val model = params.algorithm match {
127128
case LR =>
128-
val algorithm = new LogisticRegressionWithSGD()
129+
val algorithm = new LogisticRegressionWithLBFGS()
129130
algorithm.optimizer
130131
.setNumIterations(params.numIterations)
131-
.setStepSize(params.stepSize)
132132
.setUpdater(updater)
133133
.setRegParam(params.regParam)
134134
algorithm.run(training).clearThreshold()

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class LogisticRegressionModel (
7373
/**
7474
* Train a classification model for Logistic Regression using Stochastic Gradient Descent.
7575
* NOTE: Labels used in Logistic Regression should be {0, 1}
76+
*
77+
* Using [[LogisticRegressionWithLBFGS]] is recommended over this.
7678
*/
7779
class LogisticRegressionWithSGD private (
7880
private var stepSize: Double,
@@ -191,51 +193,19 @@ object LogisticRegressionWithSGD {
191193

192194
/**
193195
* Train a classification model for Logistic Regression using Limited-memory BFGS.
196+
* Standard feature scaling and L2 regularization are used by default.
194197
* NOTE: Labels used in Logistic Regression should be {0, 1}
195198
*/
196-
class LogisticRegressionWithLBFGS private (
197-
private var convergenceTol: Double,
198-
private var maxNumIterations: Int,
199-
private var regParam: Double)
199+
class LogisticRegressionWithLBFGS
200200
extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
201201

202-
/**
203-
* Construct a LogisticRegression object with default parameters
204-
*/
205-
def this() = this(1E-4, 100, 0.0)
206-
207202
this.setFeatureScaling(true)
208203

209-
private val gradient = new LogisticGradient()
210-
private val updater = new SimpleUpdater()
211-
// Have to return new LBFGS object every time since users can reset the parameters anytime.
212-
override def optimizer = new LBFGS(gradient, updater)
213-
.setNumCorrections(10)
214-
.setConvergenceTol(convergenceTol)
215-
.setMaxNumIterations(maxNumIterations)
216-
.setRegParam(regParam)
204+
override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)
217205

218206
override protected val validators = List(DataValidators.binaryLabelValidator)
219207

220-
/**
221-
* Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
222-
* Smaller value will lead to higher accuracy with the cost of more iterations.
223-
*/
224-
def setConvergenceTol(convergenceTol: Double): this.type = {
225-
this.convergenceTol = convergenceTol
226-
this
227-
}
228-
229-
/**
230-
* Set the maximal number of iterations for L-BFGS. Default 100.
231-
*/
232-
def setNumIterations(numIterations: Int): this.type = {
233-
this.maxNumIterations = numIterations
234-
this
235-
}
236-
237208
override protected def createModel(weights: Vector, intercept: Double) = {
238209
new LogisticRegressionModel(weights, intercept)
239210
}
240-
241211
}

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,17 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
6969

7070
/**
 * Set the maximal number of iterations for L-BFGS. Default 100.
 *
 * Kept only for backward compatibility; it forwards to [[setNumIterations]].
 *
 * @param iters the maximal number of iterations
 * @return this optimizer, so that setter calls can be chained
 * @deprecated use [[setNumIterations]] instead
 */
@deprecated("use setNumIterations instead", "1.1.0")
def setMaxNumIterations(iters: Int): this.type = {
  // Forward to the iteration setter. The number of corrections is a separate
  // setting (setNumCorrections) and must not be touched by this method.
  this.setNumIterations(iters)
}
78+
79+
/**
 * Set the maximal number of iterations for L-BFGS. Default 100.
 *
 * @param iters the maximal number of iterations
 * @return this optimizer, so that setter calls can be chained
 */
def setNumIterations(iters: Int): this.type = {
  maxNumIterations = iters
  this
}

0 commit comments

Comments
 (0)