Commit f04fe8a

remove normalization from RidgeRegression and update tests
1 parent d088552 commit f04fe8a

3 files changed: +36 -83 lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala

Lines changed: 7 additions & 50 deletions
@@ -17,13 +17,11 @@
 
 package org.apache.spark.mllib.regression
 
-import breeze.linalg.{Vector => BV}
-
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.linalg.Vector
 
 /**
  * Regression model trained using RidgeRegression.
@@ -58,8 +56,7 @@ class RidgeRegressionWithSGD private (
     var numIterations: Int,
     var regParam: Double,
     var miniBatchFraction: Double)
-  extends GeneralizedLinearAlgorithm[RidgeRegressionModel]
-  with Serializable {
+  extends GeneralizedLinearAlgorithm[RidgeRegressionModel] with Serializable {
 
   val gradient = new LeastSquaresGradient()
   val updater = new SquaredL2Updater()
@@ -72,10 +69,6 @@ class RidgeRegressionWithSGD private (
   // We don't want to penalize the intercept in RidgeRegression, so set this to false.
   super.setIntercept(false)
 
-  private var yMean = 0.0
-  private var xColMean: BV[Double] = _
-  private var xColSd: BV[Double] = _
-
   /**
    * Construct a RidgeRegression object with default parameters
    */
@@ -88,35 +81,7 @@ class RidgeRegressionWithSGD private (
   }
 
   override protected def createModel(weights: Vector, intercept: Double) = {
-    val weightsMat = weights.toBreeze
-    val weightsScaled = weightsMat :/ xColSd
-    val interceptScaled = yMean - weightsMat.dot(xColMean :/ xColSd)
-
-    new RidgeRegressionModel(Vectors.fromBreeze(weightsScaled), interceptScaled)
-  }
-
-  override def run(
-      input: RDD[LabeledPoint],
-      initialWeights: Vector)
-    : RidgeRegressionModel =
-  {
-    val nfeatures: Int = input.first().features.size
-    val nexamples: Long = input.count()
-
-    // To avoid penalizing the intercept, we center and scale the data.
-    val stats = MLUtils.computeStats(input, nfeatures, nexamples)
-    yMean = stats._1
-    xColMean = stats._2.toBreeze
-    xColSd = stats._3.toBreeze
-
-    val normalizedData = input.map { point =>
-      val yNormalized = point.label - yMean
-      val featuresMat = point.features.toBreeze
-      val featuresNormalized = (featuresMat - xColMean) :/ xColSd
-      LabeledPoint(yNormalized, Vectors.fromBreeze(featuresNormalized))
-    }
-
-    super.run(normalizedData, initialWeights)
+    new RidgeRegressionModel(weights, intercept)
   }
 }
 
@@ -145,9 +110,7 @@ object RidgeRegressionWithSGD {
       stepSize: Double,
       regParam: Double,
       miniBatchFraction: Double,
-      initialWeights: Vector)
-    : RidgeRegressionModel =
-  {
+      initialWeights: Vector): RidgeRegressionModel = {
     new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(
       input, initialWeights)
   }
@@ -168,9 +131,7 @@ object RidgeRegressionWithSGD {
       numIterations: Int,
       stepSize: Double,
       regParam: Double,
-      miniBatchFraction: Double)
-    : RidgeRegressionModel =
-  {
+      miniBatchFraction: Double): RidgeRegressionModel = {
     new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
   }
 
@@ -189,9 +150,7 @@ object RidgeRegressionWithSGD {
       input: RDD[LabeledPoint],
      numIterations: Int,
      stepSize: Double,
-      regParam: Double)
-    : RidgeRegressionModel =
-  {
+      regParam: Double): RidgeRegressionModel = {
     train(input, numIterations, stepSize, regParam, 1.0)
   }
 
@@ -206,9 +165,7 @@ object RidgeRegressionWithSGD {
    */
   def train(
       input: RDD[LabeledPoint],
-      numIterations: Int)
-    : RidgeRegressionModel =
-  {
+      numIterations: Int): RidgeRegressionModel = {
     train(input, numIterations, 1.0, 1.0, 1.0)
   }
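
For context, the run() override removed above centered the labels and standardized each feature column (using the statistics from MLUtils.computeStats) before handing the data to the optimizer, and createModel() then divided the learned weights by the column standard deviations and recomputed the intercept from the column means. The sketch below restates that transformation in plain Scala, without Spark or Breeze; it is only an illustration of the math the removed code performed, and every name and value in it (NormalizationSketch, the tiny feature matrix, the assumed optimizer output) is made up rather than taken from MLlib.

object NormalizationSketch {
  def main(args: Array[String]): Unit = {
    // A tiny, hand-made dataset standing in for the RDD[LabeledPoint] input.
    val features = Array(Array(1.0, 10.0), Array(2.0, 20.0), Array(3.0, 30.0))
    val labels = Array(1.0, 2.0, 3.0)
    val n = features.length
    val d = features.head.length

    // Column means and standard deviations, playing the role of xColMean and
    // xColSd, with the label mean playing the role of yMean.
    val xColMean = Array.tabulate(d)(j => features.map(_(j)).sum / n)
    val xColSd = Array.tabulate(d) { j =>
      math.sqrt(features.map(row => math.pow(row(j) - xColMean(j), 2)).sum / n)
    }
    val yMean = labels.sum / n

    // What the removed run() did: center labels and standardize features
    // before optimization.
    val normalized = features.zip(labels).map { case (row, y) =>
      (row.indices.map(j => (row(j) - xColMean(j)) / xColSd(j)).toArray, y - yMean)
    }

    // Pretend the optimizer returned these weights on the normalized data
    // (hypothetical values), then undo the scaling the way the removed
    // createModel() did: weights :/ xColSd and yMean - w.dot(xColMean :/ xColSd).
    val weights = Array(0.5, 0.5)
    val weightsOnOriginalScale = weights.indices.map(j => weights(j) / xColSd(j))
    val intercept = yMean - weights.indices.map(j => weights(j) * xColMean(j) / xColSd(j)).sum

    println(s"normalized first row: ${normalized.head._1.mkString(", ")}")
    println(s"rescaled weights: ${weightsOnOriginalScale.mkString(", ")}, intercept: $intercept")
  }
}

After this commit, none of that happens inside RidgeRegressionWithSGD: createModel simply wraps the optimizer's weights and intercept, so any centering or scaling of the input would now have to happen before the data reaches the algorithm.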

mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java

Lines changed: 17 additions & 21 deletions
@@ -55,30 +55,27 @@ public void tearDown() {
     return errorSum / validationData.size();
   }
 
-  List<LabeledPoint> generateRidgeData(int numPoints, int nfeatures, double eps) {
+  List<LabeledPoint> generateRidgeData(int numPoints, int numFeatures, double std) {
     org.jblas.util.Random.seed(42);
     // Pick weights as random values distributed uniformly in [-0.5, 0.5]
-    DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5);
-    // Set first two weights to eps
-    w.put(0, 0, eps);
-    w.put(1, 0, eps);
-    return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps);
+    DoubleMatrix w = DoubleMatrix.rand(numFeatures, 1).subi(0.5);
+    return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, std);
   }
 
   @Test
   public void runRidgeRegressionUsingConstructor() {
-    int nexamples = 200;
-    int nfeatures = 20;
-    double eps = 10.0;
-    List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+    int numExamples = 50;
+    int numFeatures = 20;
+    List<LabeledPoint> data = generateRidgeData(2*numExamples, numFeatures, 10.0);
 
-    JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
-    List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+    JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+    List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
 
     RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
-    ridgeSGDImpl.optimizer().setStepSize(1.0)
-      .setRegParam(0.0)
-      .setNumIterations(200);
+    ridgeSGDImpl.optimizer()
+      .setStepSize(1.0)
+      .setRegParam(0.0)
+      .setNumIterations(200);
     RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
     double unRegularizedErr = predictionError(validationData, model);
 
@@ -91,13 +88,12 @@ public void runRidgeRegressionUsingConstructor() {
 
   @Test
   public void runRidgeRegressionUsingStaticMethods() {
-    int nexamples = 200;
-    int nfeatures = 20;
-    double eps = 10.0;
-    List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+    int numExamples = 50;
+    int numFeatures = 20;
+    List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
 
-    JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
-    List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+    JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+    List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
 
     RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
     double unRegularizedErr = predictionError(validationData, model);
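
The two Java tests exercise the same pair of entry points seen in the Scala diff above: configuring the optimizer on a RidgeRegressionWithSGD instance and calling run(), versus calling the companion object's train() helper. A rough Scala sketch of those two call paths follows, assuming the MLlib API shown in this commit; the local SparkContext, the tiny inline dataset, and the object name are invented for illustration (the tests generate their data with LinearDataGenerator instead).

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, RidgeRegressionWithSGD}

object RidgeEntryPointsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "RidgeEntryPointsSketch")
    val trainingData = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(0.0, 1.1)),
      LabeledPoint(2.0, Vectors.dense(1.0, 2.1)),
      LabeledPoint(3.0, Vectors.dense(2.0, 3.2))))

    // Entry point 1: mirror runRidgeRegressionUsingConstructor by tuning the
    // optimizer on an instance before calling run().
    val ridge = new RidgeRegressionWithSGD()
    ridge.optimizer
      .setStepSize(1.0)
      .setRegParam(0.0)
      .setNumIterations(200)
    val modelFromInstance = ridge.run(trainingData)

    // Entry point 2: mirror runRidgeRegressionUsingStaticMethods with the
    // train(input, numIterations, stepSize, regParam) helper.
    val modelFromTrain = RidgeRegressionWithSGD.train(trainingData, 200, 1.0, 0.0)

    println(modelFromInstance.weights + " " + modelFromTrain.weights)
    sc.stop()
  }
}

With these settings the two paths configure the same optimizer (the remaining miniBatchFraction defaults to 1.0 in both), so they should behave equivalently.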

mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala

Lines changed: 12 additions & 12 deletions
@@ -31,22 +31,22 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
     }.reduceLeft(_ + _) / predictions.size
   }
 
-  test("regularization with skewed weights") {
-    val nexamples = 200
-    val nfeatures = 20
-    val eps = 10
+  test("ridge regression can help avoid overfitting") {
+
+    // For a small number of examples and a large variance of the error distribution,
+    // ridge regression should give a smaller generalization error than linear regression.
+
+    val numExamples = 50
+    val numFeatures = 20
 
     org.jblas.util.Random.seed(42)
     // Pick weights as random values distributed uniformly in [-0.5, 0.5]
-    val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
-    // Set first two weights to eps
-    w.put(0, 0, eps)
-    w.put(1, 0, eps)
+    val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5)
 
     // Use half of data for training and other half for validation
-    val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps)
-    val testData = data.take(nexamples)
-    val validationData = data.takeRight(nexamples)
+    val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2 * numExamples, 42, 10.0)
+    val testData = data.take(numExamples)
+    val validationData = data.takeRight(numExamples)
 
     val testRDD = sc.parallelize(testData, 2).cache()
     val validationRDD = sc.parallelize(validationData, 2).cache()
@@ -68,7 +68,7 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
     val ridgeErr = predictionError(
       ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData)
 
-    // Ridge CV-error should be lower than linear regression
+    // Ridge validation error should be lower than linear regression.
     assert(ridgeErr < linearErr,
       "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
   }
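
The renamed test asserts that the ridge model's held-out error is lower than the linear model's, with both measured by the suite's predictionError helper; only that helper's closing line (}.reduceLeft(_ + _) / predictions.size) is visible as context above. The sketch below shows a mean squared error computed in that shape; the squaring step and every name in it are assumptions made for illustration, not the suite's exact code.

object PredictionErrorSketch {
  // Average of squared residuals, ending with the same
  // `.reduceLeft(_ + _) / predictions.size` shape seen in the diff context.
  def predictionError(predictions: Seq[Double], labels: Seq[Double]): Double = {
    predictions.zip(labels).map { case (p, y) =>
      val residual = p - y
      residual * residual
    }.reduceLeft(_ + _) / predictions.size
  }

  def main(args: Array[String]): Unit = {
    val labels = Seq(1.0, 2.0, 3.0)
    val closePredictions = Seq(1.1, 1.9, 3.2)
    val farPredictions = Seq(3.0, 0.0, 6.0)
    // Mirrors the suite's style of assertion: the better model has the lower error.
    assert(predictionError(closePredictions, labels) < predictionError(farPredictions, labels),
      "closer predictions should give a smaller error")
    println("close: " + predictionError(closePredictions, labels) +
      ", far: " + predictionError(farPredictions, labels))
  }
}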
