Skip to content

Commit 0e57aa4

Browse files
committed
update Lasso and RidgeRegression to parse the weights correctly from GLM
mark createModel protected mark predictPoint protected
1 parent d7f629f commit 0e57aa4

File tree

4 files changed

+38
-22
lines changed

4 files changed

+38
-22
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept:
4444
* @param weightMatrix Column vector containing the weights of the model
4545
* @param intercept Intercept of the model.
4646
*/
47-
def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
47+
protected def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
4848
intercept: Double): Double
4949

5050
/**

mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ class LassoModel(
3636
extends GeneralizedLinearModel(weights, intercept)
3737
with RegressionModel with Serializable {
3838

39-
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
40-
intercept: Double) = {
39+
override protected def predictPoint(
40+
dataMatrix: DoubleMatrix,
41+
weightMatrix: DoubleMatrix,
42+
intercept: Double): Double = {
4143
dataMatrix.dot(weightMatrix) + intercept
4244
}
4345
}
@@ -66,7 +68,7 @@ class LassoWithSGD private (
6668
.setMiniBatchFraction(miniBatchFraction)
6769

6870
// We don't want to penalize the intercept, so set this to false.
69-
setIntercept(false)
71+
super.setIntercept(false)
7072

7173
var yMean = 0.0
7274
var xColMean: DoubleMatrix = _
@@ -77,10 +79,16 @@ class LassoWithSGD private (
7779
*/
7880
def this() = this(1.0, 100, 1.0, 1.0)
7981

80-
def createModel(weights: Array[Double], intercept: Double) = {
81-
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
82+
override def setIntercept(addIntercept: Boolean): this.type = {
83+
// TODO: Support adding intercept.
84+
if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
85+
this
86+
}
87+
88+
override protected def createModel(weights: Array[Double], intercept: Double) = {
89+
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
8290
val weightsScaled = weightsMat.div(xColSd)
83-
val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
91+
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
8492

8593
new LassoModel(weightsScaled.data, interceptScaled)
8694
}

mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix
3131
* @param intercept Intercept computed for this model.
3232
*/
3333
class LinearRegressionModel(
34-
override val weights: Array[Double],
35-
override val intercept: Double)
36-
extends GeneralizedLinearModel(weights, intercept)
37-
with RegressionModel with Serializable {
38-
39-
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
40-
intercept: Double) = {
34+
override val weights: Array[Double],
35+
override val intercept: Double)
36+
extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
37+
38+
override protected def predictPoint(
39+
dataMatrix: DoubleMatrix,
40+
weightMatrix: DoubleMatrix,
41+
intercept: Double): Double = {
4142
dataMatrix.dot(weightMatrix) + intercept
4243
}
4344
}
@@ -55,8 +56,7 @@ class LinearRegressionWithSGD private (
5556
var stepSize: Double,
5657
var numIterations: Int,
5758
var miniBatchFraction: Double)
58-
extends GeneralizedLinearAlgorithm[LinearRegressionModel]
59-
with Serializable {
59+
extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {
6060

6161
val gradient = new LeastSquaresGradient()
6262
val updater = new SimpleUpdater()
@@ -69,7 +69,7 @@ class LinearRegressionWithSGD private (
6969
*/
7070
def this() = this(1.0, 100, 1.0)
7171

72-
def createModel(weights: Array[Double], intercept: Double) = {
72+
override protected def createModel(weights: Array[Double], intercept: Double) = {
7373
new LinearRegressionModel(weights, intercept)
7474
}
7575
}

mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ class RidgeRegressionModel(
3636
extends GeneralizedLinearModel(weights, intercept)
3737
with RegressionModel with Serializable {
3838

39-
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
40-
intercept: Double) = {
39+
override protected def predictPoint(
40+
dataMatrix: DoubleMatrix,
41+
weightMatrix: DoubleMatrix,
42+
intercept: Double): Double = {
4143
dataMatrix.dot(weightMatrix) + intercept
4244
}
4345
}
@@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private (
6769
.setMiniBatchFraction(miniBatchFraction)
6870

6971
// We don't want to penalize the intercept in RidgeRegression, so set this to false.
70-
setIntercept(false)
72+
super.setIntercept(false)
7173

7274
var yMean = 0.0
7375
var xColMean: DoubleMatrix = _
@@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private (
7880
*/
7981
def this() = this(1.0, 100, 1.0, 1.0)
8082

81-
def createModel(weights: Array[Double], intercept: Double) = {
82-
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
83+
override def setIntercept(addIntercept: Boolean): this.type = {
84+
// TODO: Support adding intercept.
85+
if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
86+
this
87+
}
88+
89+
override protected def createModel(weights: Array[Double], intercept: Double) = {
90+
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
8391
val weightsScaled = weightsMat.div(xColSd)
8492
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
8593

0 commit comments

Comments
 (0)