Skip to content

Commit 8ad263e

Browse files
committed
Adding regularizer type and intercept parameters to LogisticRegressionWithSGD and SVMWithSGD.
1 parent 8e7d5ba commit 8ad263e

File tree

2 files changed

+92
-21
lines changed

2 files changed

+92
-21
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -341,16 +341,26 @@ class PythonMLLibAPI extends Serializable {
341341
stepSize: Double,
342342
regParam: Double,
343343
miniBatchFraction: Double,
344-
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
344+
initialWeightsBA: Array[Byte],
345+
regType: String,
346+
intercept: Boolean): java.util.List[java.lang.Object] = {
347+
val SVMAlg = new SVMWithSGD()
348+
SVMAlg.setIntercept(intercept)
349+
SVMAlg.optimizer
350+
.setNumIterations(numIterations)
351+
.setRegParam(regParam)
352+
.setStepSize(stepSize)
353+
if (regType == "l2") {
354+
SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
355+
} else if (regType == "l1") {
356+
SVMAlg.optimizer.setUpdater(new L1Updater)
357+
} else if (regType != "none") {
358+
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
359+
+ " Can only be initialized using the following string values: [l1, l2, none].")
360+
}
345361
trainRegressionModel(
346362
(data, initialWeights) =>
347-
SVMWithSGD.train(
348-
data,
349-
numIterations,
350-
stepSize,
351-
regParam,
352-
miniBatchFraction,
353-
initialWeights),
363+
SVMAlg.run(data, initialWeights),
354364
dataBytesJRDD,
355365
initialWeightsBA)
356366
}
@@ -363,15 +373,27 @@ class PythonMLLibAPI extends Serializable {
363373
numIterations: Int,
364374
stepSize: Double,
365375
miniBatchFraction: Double,
366-
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
376+
initialWeightsBA: Array[Byte],
377+
regParam: Double,
378+
regType: String,
379+
intercept: Boolean): java.util.List[java.lang.Object] = {
380+
val LogRegAlg = new LogisticRegressionWithSGD()
381+
LogRegAlg.setIntercept(intercept)
382+
LogRegAlg.optimizer
383+
.setNumIterations(numIterations)
384+
.setRegParam(regParam)
385+
.setStepSize(stepSize)
386+
if (regType == "l2") {
387+
LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
388+
} else if (regType == "l1") {
389+
LogRegAlg.optimizer.setUpdater(new L1Updater)
390+
} else if (regType != "none") {
391+
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
392+
+ " Can only be initialized using the following string values: [l1, l2, none].")
393+
}
367394
trainRegressionModel(
368395
(data, initialWeights) =>
369-
LogisticRegressionWithSGD.train(
370-
data,
371-
numIterations,
372-
stepSize,
373-
miniBatchFraction,
374-
initialWeights),
396+
LogRegAlg.run(data, initialWeights),
375397
dataBytesJRDD,
376398
initialWeightsBA)
377399
}

python/pyspark/mllib/classification.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,36 @@ def predict(self, x):
7373

7474
class LogisticRegressionWithSGD(object):
7575
@classmethod
76-
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None):
77-
"""Train a logistic regression model on the given data."""
76+
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
77+
initialWeights=None, regParam=1.0, regType=None, intecept=False):
78+
"""
79+
Train a logistic regression model on the given data.
80+
81+
@param data: The training data.
82+
@param iterations: The number of iterations (default: 100).
83+
@param step: The step parameter used in SGD
84+
(default: 1.0).
85+
@param miniBatchFraction: Fraction of data to be used for each SGD
86+
iteration.
87+
@param initialWeights: The initial weights (default: None).
88+
@param regParam: The regularizer parameter (default: 1.0).
89+
@param regType: The type of regularizer used for training
90+
our model.
91+
Allowed values: "l1" for using L1Updater,
92+
"l2" for using
93+
SquaredL2Updater,
94+
"none" for no regularizer.
95+
(default: "none")
96+
@param intercept: Boolean parameter which indicates the use
97+
or not of the augmented representation for
98+
training data (i.e. whether bias features
99+
are activated or not).
100+
"""
78101
sc = data.context
102+
if regType is None:
103+
regType = "none"
79104
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(
80-
d._jrdd, iterations, step, miniBatchFraction, i)
105+
d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
81106
return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data,
82107
initialWeights)
83108

@@ -115,11 +140,35 @@ def predict(self, x):
115140
class SVMWithSGD(object):
116141
@classmethod
117142
def train(cls, data, iterations=100, step=1.0, regParam=1.0,
118-
miniBatchFraction=1.0, initialWeights=None):
119-
"""Train a support vector machine on the given data."""
143+
miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False):
144+
"""
145+
Train a support vector machine on the given data.
146+
147+
@param data: The training data.
148+
@param iterations: The number of iterations (default: 100).
149+
@param step: The step parameter used in SGD
150+
(default: 1.0).
151+
@param regParam: The regularizer parameter (default: 1.0).
152+
@param miniBatchFraction: Fraction of data to be used for each SGD
153+
iteration.
154+
@param initialWeights: The initial weights (default: None).
155+
@param regType: The type of regularizer used for training
156+
our model.
157+
Allowed values: "l1" for using L1Updater,
158+
"l2" for using
159+
SquaredL2Updater,
160+
"none" for no regularizer.
161+
(default: "none")
162+
@param intercept: Boolean parameter which indicates the use
163+
or not of the augmented representation for
164+
training data (i.e. whether bias features
165+
are activated or not).
166+
"""
120167
sc = data.context
168+
if regType is None:
169+
regType = "none"
121170
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(
122-
d._jrdd, iterations, step, regParam, miniBatchFraction, i)
171+
d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept)
123172
return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights)
124173

125174

0 commit comments

Comments
 (0)