Skip to content

Commit 006246c

Browse files
committed
murp?
1 parent e9471d3 commit 006246c

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ import org.apache.spark.util.StatCounter
4141
* Params for linear regression.
4242
*/
4343
private[regression] trait LinearRegressionParams extends PredictorParams
44-
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
44+
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol with
45+
HasIntercept
4546

4647
/**
4748
* :: Experimental ::
@@ -121,8 +122,9 @@ class LinearRegression(override val uid: String)
121122
})
122123

123124
val numFeatures = summarizer.mean.size
124-
val yMean = statCounter.mean
125-
val yStd = math.sqrt(statCounter.variance)
125+
val yMean = if (hasIntercept) statCounter.mean else 0.0
126+
val yStd = if (hasIntercept) math.sqrt(statCounter.variance) else
127+
// TODO: look at glmnet6.m (line 761); it may have relevant info for this case.
126128

127129
// If the yStd is zero, then the intercept is yMean with zero weights;
128130
// as a result, training is not needed.
@@ -180,6 +182,7 @@ class LinearRegression(override val uid: String)
180182
// The intercept in R's GLMNET is computed using closed form after the coefficients are
181183
// converged. See the following discussion for detail.
182184
// http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
185+
// Also see the scikit-learn implementation at https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/base.py
183186
val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
184187
if (handlePersistence) instances.unpersist()
185188

0 commit comments

Comments
 (0)