Skip to content

Commit 1aad911

Browse files
miccagiann authored and mengxr committed
[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods
Related to Jira Issue: [SPARK-2550](https://issues.apache.org/jira/browse/SPARK-2550?jql=project%20%3D%20SPARK%20AND%20resolution%20%3D%20Unresolved%20AND%20priority%20%3D%20Major%20ORDER%20BY%20key%20DESC)

Author: Michael Giannakopoulos <[email protected]>

Closes apache#1775 from miccagiann/linearMethodsReg and squashes the following commits:

cb774c3 [Michael Giannakopoulos] MiniBatchFraction added in related PythonMLLibAPI java stubs.
81fcbc6 [Michael Giannakopoulos] Fixing a typo-error.
8ad263e [Michael Giannakopoulos] Adding regularizer type and intercept parameters to LogisticRegressionWithSGD and SVMWithSGD.
1 parent acff9a7 commit 1aad911

File tree

2 files changed

+95
-21
lines changed

2 files changed

+95
-21
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 40 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -271,6 +271,7 @@ class PythonMLLibAPI extends Serializable {
271271
.setNumIterations(numIterations)
272272
.setRegParam(regParam)
273273
.setStepSize(stepSize)
274+
.setMiniBatchFraction(miniBatchFraction)
274275
if (regType == "l2") {
275276
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
276277
} else if (regType == "l1") {
@@ -341,16 +342,27 @@ class PythonMLLibAPI extends Serializable {
341342
stepSize: Double,
342343
regParam: Double,
343344
miniBatchFraction: Double,
344-
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
345+
initialWeightsBA: Array[Byte],
346+
regType: String,
347+
intercept: Boolean): java.util.List[java.lang.Object] = {
348+
val SVMAlg = new SVMWithSGD()
349+
SVMAlg.setIntercept(intercept)
350+
SVMAlg.optimizer
351+
.setNumIterations(numIterations)
352+
.setRegParam(regParam)
353+
.setStepSize(stepSize)
354+
.setMiniBatchFraction(miniBatchFraction)
355+
if (regType == "l2") {
356+
SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
357+
} else if (regType == "l1") {
358+
SVMAlg.optimizer.setUpdater(new L1Updater)
359+
} else if (regType != "none") {
360+
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
361+
+ " Can only be initialized using the following string values: [l1, l2, none].")
362+
}
345363
trainRegressionModel(
346364
(data, initialWeights) =>
347-
SVMWithSGD.train(
348-
data,
349-
numIterations,
350-
stepSize,
351-
regParam,
352-
miniBatchFraction,
353-
initialWeights),
365+
SVMAlg.run(data, initialWeights),
354366
dataBytesJRDD,
355367
initialWeightsBA)
356368
}
@@ -363,15 +375,28 @@ class PythonMLLibAPI extends Serializable {
363375
numIterations: Int,
364376
stepSize: Double,
365377
miniBatchFraction: Double,
366-
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
378+
initialWeightsBA: Array[Byte],
379+
regParam: Double,
380+
regType: String,
381+
intercept: Boolean): java.util.List[java.lang.Object] = {
382+
val LogRegAlg = new LogisticRegressionWithSGD()
383+
LogRegAlg.setIntercept(intercept)
384+
LogRegAlg.optimizer
385+
.setNumIterations(numIterations)
386+
.setRegParam(regParam)
387+
.setStepSize(stepSize)
388+
.setMiniBatchFraction(miniBatchFraction)
389+
if (regType == "l2") {
390+
LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
391+
} else if (regType == "l1") {
392+
LogRegAlg.optimizer.setUpdater(new L1Updater)
393+
} else if (regType != "none") {
394+
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
395+
+ " Can only be initialized using the following string values: [l1, l2, none].")
396+
}
367397
trainRegressionModel(
368398
(data, initialWeights) =>
369-
LogisticRegressionWithSGD.train(
370-
data,
371-
numIterations,
372-
stepSize,
373-
miniBatchFraction,
374-
initialWeights),
399+
LogRegAlg.run(data, initialWeights),
375400
dataBytesJRDD,
376401
initialWeightsBA)
377402
}

python/pyspark/mllib/classification.py

Lines changed: 55 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -73,11 +73,36 @@ def predict(self, x):
7373

7474
class LogisticRegressionWithSGD(object):
7575
@classmethod
76-
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None):
77-
"""Train a logistic regression model on the given data."""
76+
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
77+
initialWeights=None, regParam=1.0, regType=None, intercept=False):
78+
"""
79+
Train a logistic regression model on the given data.
80+
81+
@param data: The training data.
82+
@param iterations: The number of iterations (default: 100).
83+
@param step: The step parameter used in SGD
84+
(default: 1.0).
85+
@param miniBatchFraction: Fraction of data to be used for each SGD
86+
iteration.
87+
@param initialWeights: The initial weights (default: None).
88+
@param regParam: The regularizer parameter (default: 1.0).
89+
@param regType: The type of regularizer used for training
90+
our model.
91+
Allowed values: "l1" for using L1Updater,
92+
"l2" for using
93+
SquaredL2Updater,
94+
"none" for no regularizer.
95+
(default: "none")
96+
@param intercept: Boolean parameter which indicates the use
97+
or not of the augmented representation for
98+
training data (i.e. whether bias features
99+
are activated or not).
100+
"""
78101
sc = data.context
102+
if regType is None:
103+
regType = "none"
79104
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(
80-
d._jrdd, iterations, step, miniBatchFraction, i)
105+
d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
81106
return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data,
82107
initialWeights)
83108

@@ -115,11 +140,35 @@ def predict(self, x):
115140
class SVMWithSGD(object):
116141
@classmethod
117142
def train(cls, data, iterations=100, step=1.0, regParam=1.0,
118-
miniBatchFraction=1.0, initialWeights=None):
119-
"""Train a support vector machine on the given data."""
143+
miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False):
144+
"""
145+
Train a support vector machine on the given data.
146+
147+
@param data: The training data.
148+
@param iterations: The number of iterations (default: 100).
149+
@param step: The step parameter used in SGD
150+
(default: 1.0).
151+
@param regParam: The regularizer parameter (default: 1.0).
152+
@param miniBatchFraction: Fraction of data to be used for each SGD
153+
iteration.
154+
@param initialWeights: The initial weights (default: None).
155+
@param regType: The type of regularizer used for training
156+
our model.
157+
Allowed values: "l1" for using L1Updater,
158+
"l2" for using
159+
SquaredL2Updater,
160+
"none" for no regularizer.
161+
(default: "none")
162+
@param intercept: Boolean parameter which indicates the use
163+
or not of the augmented representation for
164+
training data (i.e. whether bias features
165+
are activated or not).
166+
"""
120167
sc = data.context
168+
if regType is None:
169+
regType = "none"
121170
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(
122-
d._jrdd, iterations, step, regParam, miniBatchFraction, i)
171+
d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept)
123172
return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights)
124173

125174

0 commit comments

Comments
 (0)