17
17
18
18
package org .apache .spark .mllib .regression
19
19
20
- import breeze .linalg .{Vector => BV }
21
-
22
20
import org .apache .spark .SparkContext
23
21
import org .apache .spark .rdd .RDD
24
22
import org .apache .spark .mllib .optimization ._
25
23
import org .apache .spark .mllib .util .MLUtils
26
- import org .apache .spark .mllib .linalg .{ Vectors , Vector }
24
+ import org .apache .spark .mllib .linalg .Vector
27
25
28
26
/**
29
27
* Regression model trained using RidgeRegression.
@@ -58,8 +56,7 @@ class RidgeRegressionWithSGD private (
58
56
var numIterations : Int ,
59
57
var regParam : Double ,
60
58
var miniBatchFraction : Double )
61
- extends GeneralizedLinearAlgorithm [RidgeRegressionModel ]
62
- with Serializable {
59
+ extends GeneralizedLinearAlgorithm [RidgeRegressionModel ] with Serializable {
63
60
64
61
val gradient = new LeastSquaresGradient ()
65
62
val updater = new SquaredL2Updater ()
@@ -72,10 +69,6 @@ class RidgeRegressionWithSGD private (
72
69
// We don't want to penalize the intercept in RidgeRegression, so set this to false.
73
70
super .setIntercept(false )
74
71
75
- private var yMean = 0.0
76
- private var xColMean : BV [Double ] = _
77
- private var xColSd : BV [Double ] = _
78
-
79
72
/**
80
73
* Construct a RidgeRegression object with default parameters
81
74
*/
@@ -88,35 +81,7 @@ class RidgeRegressionWithSGD private (
88
81
}
89
82
90
83
override protected def createModel (weights : Vector , intercept : Double ) = {
91
- val weightsMat = weights.toBreeze
92
- val weightsScaled = weightsMat :/ xColSd
93
- val interceptScaled = yMean - weightsMat.dot(xColMean :/ xColSd)
94
-
95
- new RidgeRegressionModel (Vectors .fromBreeze(weightsScaled), interceptScaled)
96
- }
97
-
98
- override def run (
99
- input : RDD [LabeledPoint ],
100
- initialWeights : Vector )
101
- : RidgeRegressionModel =
102
- {
103
- val nfeatures : Int = input.first().features.size
104
- val nexamples : Long = input.count()
105
-
106
- // To avoid penalizing the intercept, we center and scale the data.
107
- val stats = MLUtils .computeStats(input, nfeatures, nexamples)
108
- yMean = stats._1
109
- xColMean = stats._2.toBreeze
110
- xColSd = stats._3.toBreeze
111
-
112
- val normalizedData = input.map { point =>
113
- val yNormalized = point.label - yMean
114
- val featuresMat = point.features.toBreeze
115
- val featuresNormalized = (featuresMat - xColMean) :/ xColSd
116
- LabeledPoint (yNormalized, Vectors .fromBreeze(featuresNormalized))
117
- }
118
-
119
- super .run(normalizedData, initialWeights)
84
+ new RidgeRegressionModel (weights, intercept)
120
85
}
121
86
}
122
87
@@ -145,9 +110,7 @@ object RidgeRegressionWithSGD {
145
110
stepSize : Double ,
146
111
regParam : Double ,
147
112
miniBatchFraction : Double ,
148
- initialWeights : Vector )
149
- : RidgeRegressionModel =
150
- {
113
+ initialWeights : Vector ): RidgeRegressionModel = {
151
114
new RidgeRegressionWithSGD (stepSize, numIterations, regParam, miniBatchFraction).run(
152
115
input, initialWeights)
153
116
}
@@ -168,9 +131,7 @@ object RidgeRegressionWithSGD {
168
131
numIterations : Int ,
169
132
stepSize : Double ,
170
133
regParam : Double ,
171
- miniBatchFraction : Double )
172
- : RidgeRegressionModel =
173
- {
134
+ miniBatchFraction : Double ): RidgeRegressionModel = {
174
135
new RidgeRegressionWithSGD (stepSize, numIterations, regParam, miniBatchFraction).run(input)
175
136
}
176
137
@@ -189,9 +150,7 @@ object RidgeRegressionWithSGD {
189
150
input : RDD [LabeledPoint ],
190
151
numIterations : Int ,
191
152
stepSize : Double ,
192
- regParam : Double )
193
- : RidgeRegressionModel =
194
- {
153
+ regParam : Double ): RidgeRegressionModel = {
195
154
train(input, numIterations, stepSize, regParam, 1.0 )
196
155
}
197
156
@@ -206,9 +165,7 @@ object RidgeRegressionWithSGD {
206
165
*/
207
166
def train (
208
167
input : RDD [LabeledPoint ],
209
- numIterations : Int )
210
- : RidgeRegressionModel =
211
- {
168
+ numIterations : Int ): RidgeRegressionModel = {
212
169
train(input, numIterations, 1.0 , 1.0 , 1.0 )
213
170
}
214
171
0 commit comments