Skip to content

Commit 2b1111d

Browse files
holdenk authored and DB Tsai committed
[SPARK-7888] Be able to disable intercept in linear regression in ml package
Author: Holden Karau <[email protected]> Closes apache#6927 from holdenk/SPARK-7888-Be-able-to-disable-intercept-in-Linear-Regression-in-ML-package and squashes the following commits: 0ad384c [Holden Karau] Add MiMa excludes 4016fac [Holden Karau] Switch to wild card import, remove extra blank lines ae5baa8 [Holden Karau] CR feedback, move the fitIntercept down rather than changing ymean and etc above f34971c [Holden Karau] Fix some more long lines 319bd3f [Holden Karau] Fix long lines 3bb9ee1 [Holden Karau] Update the regression suite tests 7015b9f [Holden Karau] Our code performs the same with R, except we need more than one data point but that seems reasonable 0b0c8c0 [Holden Karau] fix the issue with the sample R code e2140ba [Holden Karau] Add a test, it fails! 5e84a0b [Holden Karau] Write out thoughts and use the correct trait 91ffc0a [Holden Karau] more murh 006246c [Holden Karau] murp?
1 parent 6f4cadf commit 2b1111d

File tree

3 files changed

+172
-12
lines changed

3 files changed

+172
-12
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.Logging
2626
import org.apache.spark.annotation.Experimental
2727
import org.apache.spark.ml.PredictorParams
2828
import org.apache.spark.ml.param.ParamMap
29-
import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
29+
import org.apache.spark.ml.param.shared._
3030
import org.apache.spark.ml.util.Identifiable
3131
import org.apache.spark.mllib.linalg.{Vector, Vectors}
3232
import org.apache.spark.mllib.linalg.BLAS._
@@ -41,7 +41,8 @@ import org.apache.spark.util.StatCounter
4141
* Params for linear regression.
4242
*/
4343
private[regression] trait LinearRegressionParams extends PredictorParams
44-
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
44+
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
45+
with HasFitIntercept
4546

4647
/**
4748
* :: Experimental ::
@@ -72,6 +73,14 @@ class LinearRegression(override val uid: String)
7273
def setRegParam(value: Double): this.type = set(regParam, value)
7374
setDefault(regParam -> 0.0)
7475

76+
/**
77+
* Set whether the intercept should be fitted.
78+
* Default is true.
79+
* @group setParam
80+
*/
81+
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
82+
setDefault(fitIntercept -> true)
83+
7584
/**
7685
* Set the ElasticNet mixing parameter.
7786
* For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
@@ -123,6 +132,7 @@ class LinearRegression(override val uid: String)
123132
val numFeatures = summarizer.mean.size
124133
val yMean = statCounter.mean
125134
val yStd = math.sqrt(statCounter.variance)
135+
// See glmnet5.m (around line 761); it may have relevant information.
126136

127137
// If the yStd is zero, then the intercept is yMean with zero weights;
128138
// as a result, training is not needed.
@@ -142,7 +152,7 @@ class LinearRegression(override val uid: String)
142152
val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
143153
val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
144154

145-
val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
155+
val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
146156
featuresStd, featuresMean, effectiveL2RegParam)
147157

148158
val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
@@ -180,7 +190,7 @@ class LinearRegression(override val uid: String)
180190
// The intercept in R's GLMNET is computed using closed form after the coefficients are
181191
// converged. See the following discussion for detail.
182192
// http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
183-
val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
193+
val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
184194
if (handlePersistence) instances.unpersist()
185195

186196
// TODO: Convert to sparse format based on the storage, though scoring speed may be a better basis.
@@ -234,13 +244,18 @@ class LinearRegressionModel private[ml] (
234244
* See this discussion for detail.
235245
* http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
236246
*
247+
* When training with intercept enabled,
237248
* the objective function in the scaled space is given by
238249
* {{{
239250
* L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
240251
* }}}
241252
* where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i,
242253
* \bar{y} is the mean of label, and \hat{y} is the standard deviation of label.
243254
*
255+
* If fitting the intercept is disabled (that is, the fit is forced through 0.0),
256+
* we can use the same equation except we set \bar{y} and \bar{x_i} to 0 instead
257+
* of the respective means.
258+
*
244259
* This can be rewritten as
245260
* {{{
246261
* L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y}
@@ -255,6 +270,7 @@ class LinearRegressionModel private[ml] (
255270
* \sum_i w_i^\prime x_i - y / \hat{y} + offset
256271
* }}}
257272
*
273+
*
258274
* Note that the effective weights and offset don't depend on training dataset,
259275
* so they can be precomputed.
260276
*
@@ -301,6 +317,7 @@ private class LeastSquaresAggregator(
301317
weights: Vector,
302318
labelStd: Double,
303319
labelMean: Double,
320+
fitIntercept: Boolean,
304321
featuresStd: Array[Double],
305322
featuresMean: Array[Double]) extends Serializable {
306323

@@ -321,7 +338,7 @@ private class LeastSquaresAggregator(
321338
}
322339
i += 1
323340
}
324-
(weightsArray, -sum + labelMean / labelStd, weightsArray.length)
341+
(weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length)
325342
}
326343

327344
private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
@@ -404,6 +421,7 @@ private class LeastSquaresCostFun(
404421
data: RDD[(Double, Vector)],
405422
labelStd: Double,
406423
labelMean: Double,
424+
fitIntercept: Boolean,
407425
featuresStd: Array[Double],
408426
featuresMean: Array[Double],
409427
effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
@@ -412,7 +430,7 @@ private class LeastSquaresCostFun(
412430
val w = Vectors.fromBreeze(weights)
413431

414432
val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
415-
labelMean, featuresStd, featuresMean))(
433+
labelMean, fitIntercept, featuresStd, featuresMean))(
416434
seqOp = (c, v) => (c, v) match {
417435
case (aggregator, (label, features)) => aggregator.add(label, features)
418436
},

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 143 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, Row}
2626
class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
2727

2828
@transient var dataset: DataFrame = _
29+
@transient var datasetWithoutIntercept: DataFrame = _
2930

3031
/**
3132
* In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -34,14 +35,24 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
3435
*
3536
* import org.apache.spark.mllib.util.LinearDataGenerator
3637
* val data =
37-
* sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2)
38-
* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path")
38+
* sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
39+
* Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
40+
* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
41+
* .saveAsTextFile("path")
3942
*/
4043
override def beforeAll(): Unit = {
4144
super.beforeAll()
4245
dataset = sqlContext.createDataFrame(
4346
sc.parallelize(LinearDataGenerator.generateLinearInput(
4447
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
48+
/**
49+
* datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating
50+
* training a model without an intercept
51+
*/
52+
datasetWithoutIntercept = sqlContext.createDataFrame(
53+
sc.parallelize(LinearDataGenerator.generateLinearInput(
54+
0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
55+
4556
}
4657

4758
test("linear regression with intercept without regularization") {
@@ -78,6 +89,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
7889
}
7990
}
8091

92+
test("linear regression without intercept without regularization") {
93+
val trainer = (new LinearRegression).setFitIntercept(false)
94+
val model = trainer.fit(dataset)
95+
val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)
96+
97+
/**
98+
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
99+
* intercept = FALSE))
100+
* > weights
101+
* 3 x 1 sparse Matrix of class "dgCMatrix"
102+
* s0
103+
* (Intercept) .
104+
* as.numeric.data.V2. 6.995908
105+
* as.numeric.data.V3. 5.275131
106+
*/
107+
val weightsR = Array(6.995908, 5.275131)
108+
109+
assert(model.intercept ~== 0 relTol 1E-3)
110+
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
111+
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
112+
/**
113+
* Then again with the data with no intercept:
114+
* > weightsWithoutIntercept
115+
* 3 x 1 sparse Matrix of class "dgCMatrix"
116+
* s0
117+
* (Intercept) .
118+
* as.numeric.data3.V2. 4.70011
119+
* as.numeric.data3.V3. 7.19943
120+
*/
121+
val weightsWithoutInterceptR = Array(4.70011, 7.19943)
122+
123+
assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3)
124+
assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3)
125+
assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
126+
}
127+
81128
test("linear regression with intercept with L1 regularization") {
82129
val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
83130
val model = trainer.fit(dataset)
@@ -87,11 +134,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
87134
* > weights
88135
* 3 x 1 sparse Matrix of class "dgCMatrix"
89136
* s0
90-
* (Intercept) 6.311546
91-
* as.numeric.data.V2. 2.123522
92-
* as.numeric.data.V3. 4.605651
137+
* (Intercept) 6.24300
138+
* as.numeric.data.V2. 4.024821
139+
* as.numeric.data.V3. 6.679841
93140
*/
94-
val interceptR = 6.243000
141+
val interceptR = 6.24300
95142
val weightsR = Array(4.024821, 6.679841)
96143

97144
assert(model.intercept ~== interceptR relTol 1E-3)
@@ -106,6 +153,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
106153
}
107154
}
108155

156+
test("linear regression without intercept with L1 regularization") {
157+
val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
158+
.setFitIntercept(false)
159+
val model = trainer.fit(dataset)
160+
161+
/**
162+
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
163+
* intercept=FALSE))
164+
* > weights
165+
* 3 x 1 sparse Matrix of class "dgCMatrix"
166+
* s0
167+
* (Intercept) .
168+
* as.numeric.data.V2. 6.299752
169+
* as.numeric.data.V3. 4.772913
170+
*/
171+
val interceptR = 0.0
172+
val weightsR = Array(6.299752, 4.772913)
173+
174+
assert(model.intercept ~== interceptR relTol 1E-3)
175+
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
176+
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
177+
178+
model.transform(dataset).select("features", "prediction").collect().foreach {
179+
case Row(features: DenseVector, prediction1: Double) =>
180+
val prediction2 =
181+
features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
182+
assert(prediction1 ~== prediction2 relTol 1E-5)
183+
}
184+
}
185+
109186
test("linear regression with intercept with L2 regularization") {
110187
val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
111188
val model = trainer.fit(dataset)
@@ -134,6 +211,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
134211
}
135212
}
136213

214+
test("linear regression without intercept with L2 regularization") {
215+
val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
216+
.setFitIntercept(false)
217+
val model = trainer.fit(dataset)
218+
219+
/**
220+
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
221+
* intercept = FALSE))
222+
* > weights
223+
* 3 x 1 sparse Matrix of class "dgCMatrix"
224+
* s0
225+
* (Intercept) .
226+
* as.numeric.data.V2. 5.522875
227+
* as.numeric.data.V3. 4.214502
228+
*/
229+
val interceptR = 0.0
230+
val weightsR = Array(5.522875, 4.214502)
231+
232+
assert(model.intercept ~== interceptR relTol 1E-3)
233+
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
234+
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
235+
236+
model.transform(dataset).select("features", "prediction").collect().foreach {
237+
case Row(features: DenseVector, prediction1: Double) =>
238+
val prediction2 =
239+
features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
240+
assert(prediction1 ~== prediction2 relTol 1E-5)
241+
}
242+
}
243+
137244
test("linear regression with intercept with ElasticNet regularization") {
138245
val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
139246
val model = trainer.fit(dataset)
@@ -161,4 +268,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
161268
assert(prediction1 ~== prediction2 relTol 1E-5)
162269
}
163270
}
271+
272+
test("linear regression without intercept with ElasticNet regularization") {
273+
val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
274+
.setFitIntercept(false)
275+
val model = trainer.fit(dataset)
276+
277+
/**
278+
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
279+
* intercept=FALSE))
280+
* > weights
281+
* 3 x 1 sparse Matrix of class "dgCMatrix"
282+
* s0
283+
* (Intercept) .
284+
* as.numeric.dataM.V2. 5.673348
285+
* as.numeric.dataM.V3. 4.322251
286+
*/
287+
val interceptR = 0.0
288+
val weightsR = Array(5.673348, 4.322251)
289+
290+
assert(model.intercept ~== interceptR relTol 1E-3)
291+
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
292+
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
293+
294+
model.transform(dataset).select("features", "prediction").collect().foreach {
295+
case Row(features: DenseVector, prediction1: Double) =>
296+
val prediction2 =
297+
features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
298+
assert(prediction1 ~== prediction2 relTol 1E-5)
299+
}
300+
}
164301
}

project/MimaExcludes.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ object MimaExcludes {
5353
// Removing a testing method from a private class
5454
ProblemFilters.exclude[MissingMethodProblem](
5555
"org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
56+
// Although these classes are private, MiMa still flags their constructor changes.
57+
ProblemFilters.exclude[MissingMethodProblem](
58+
"org.apache.spark.ml.regression.LeastSquaresAggregator.this"),
59+
ProblemFilters.exclude[MissingMethodProblem](
60+
"org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
5661
// SQL execution is considered private.
5762
excludePackage("org.apache.spark.sql.execution"),
5863
// NanoTime and CatalystTimestampConverter is only used inside catalyst,

0 commit comments

Comments
 (0)