Skip to content

[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods. #1624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
import org.apache.spark.mllib.recommendation._
Expand Down Expand Up @@ -252,15 +254,27 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
initialWeightsBA: Array[Byte],
regParam: Double,
regType: String,
intercept: Boolean): java.util.List[java.lang.Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
lrAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
.setStepSize(stepSize)
if (regType == "l2") {
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
lrAlg.optimizer.setUpdater(new L1Updater)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not using enumerations for regType parameter anymore. Switched to string values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is safer to add

    else if (regType != "none")
      throw IllegalArgumentException("...")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By adding the exception to the scala code, I am going to remove the ValueError exception used in the python code.

} else if (regType != "none") {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
(data, initialWeights) =>
LinearRegressionWithSGD.train(
data,
numIterations,
stepSize,
miniBatchFraction,
initialWeights),
lrAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
Expand Down
32 changes: 28 additions & 4 deletions python/pyspark/mllib/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,36 @@ class LinearRegressionModel(LinearRegressionModelBase):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use two empty lines to separate methods in pyspark. (I don't know the exact reason ...)

class LinearRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0,
miniBatchFraction=1.0, initialWeights=None):
"""Train a linear regression model on the given data."""
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
initialWeights=None, regParam=1.0, regType=None, intercept=False):
"""
Train a linear regression model on the given data.

@param data: The training data.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not necessary to align the doc, especially when limited by 78 characters.

@param iterations: The number of iterations (default: 100).
@param step: The step parameter used in SGD
(default: 1.0).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: please keep this empty line

@param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
@param initialWeights: The initial weights (default: None).
@param regParam: The regularizer parameter (default: 1.0).
@param regType: The type of regularizer used for training
our model.
Allowed values: "l1" for using L1Updater,
"l2" for using
SquaredL2Updater,
"none" for no regularizer.
(default: "none")
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
"""
sc = data.context
if regType is None:
regType = "none"
train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
d._jrdd, iterations, step, miniBatchFraction, i)
d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights)


Expand Down