@@ -41,7 +41,8 @@ import org.apache.spark.util.StatCounter
41
41
* Params for linear regression.
42
42
*/
43
43
private [regression] trait LinearRegressionParams extends PredictorParams
44
- with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
44
+ with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol with
45
+ HasIntercept
45
46
46
47
/**
47
48
* :: Experimental ::
@@ -121,8 +122,9 @@ class LinearRegression(override val uid: String)
121
122
})
122
123
123
124
val numFeatures = summarizer.mean.size
124
- val yMean = statCounter.mean
125
- val yStd = math.sqrt(statCounter.variance)
125
+ val yMean = if (hasIntercept) statCounter.mean else 0.0
126
+ val yStd = if (hasIntercept) math.sqrt(statCounter.variance) else
127
+ // look at glmnet6.m L761 maaaybe that has info
126
128
127
129
// If the yStd is zero, then the intercept is yMean with zero weights;
128
130
// as a result, training is not needed.
@@ -180,6 +182,7 @@ class LinearRegression(override val uid: String)
180
182
// The intercept in R's GLMNET is computed using closed form after the coefficients are
181
183
// converged. See the following discussion for detail.
182
184
// http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
185
+ // Also see the scikit learn impl at https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/base.py
183
186
val intercept = yMean - dot(weights, Vectors .dense(featuresMean))
184
187
if (handlePersistence) instances.unpersist()
185
188
0 commit comments