[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods. #1624
Changes from all commits
3ac8874
78853ec
b962744
ec50ee9
638be47
8eba9c5
44e6ff0
fed8eaa
8dcb888
c02e5f5
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
import org.apache.spark.mllib.recommendation._
|
@@ -252,15 +254,27 @@ class PythonMLLibAPI extends Serializable {
       numIterations: Int,
       stepSize: Double,
       miniBatchFraction: Double,
-      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+      initialWeightsBA: Array[Byte],
+      regParam: Double,
+      regType: String,
+      intercept: Boolean): java.util.List[java.lang.Object] = {
+    val lrAlg = new LinearRegressionWithSGD()
+    lrAlg.setIntercept(intercept)
+    lrAlg.optimizer
+      .setNumIterations(numIterations)
+      .setRegParam(regParam)
+      .setStepSize(stepSize)
+    if (regType == "l2") {
+      lrAlg.optimizer.setUpdater(new SquaredL2Updater)
+    } else if (regType == "l1") {
+      lrAlg.optimizer.setUpdater(new L1Updater)
Review comment: It is safer to add
Review comment: By adding the exception to the Scala code, I am going to remove the
+    } else if (regType != "none") {
+      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+        + " Can only be initialized using the following string values: [l1, l2, none].")
+    }
     trainRegressionModel(
       (data, initialWeights) =>
-        LinearRegressionWithSGD.train(
-          data,
-          numIterations,
-          stepSize,
-          miniBatchFraction,
-          initialWeights),
+        lrAlg.run(data, initialWeights),
       dataBytesJRDD,
       initialWeightsBA)
   }
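The `regType` branch above picks an updater by string and rejects anything other than `l1`, `l2`, or `none`. A minimal Python sketch of that dispatch follows. The function names (`select_updater`, `l1_update`, and so on) are hypothetical and not part of Spark. The update rules reflect a simplified reading of `SimpleUpdater`, `SquaredL2Updater`, and `L1Updater`: it assumes L2 shrinks the weights multiplicatively and L1 applies soft-thresholding, and it ignores the per-iteration step-size decay.

```python
import math


def soft_threshold(w, shrink):
    """Move w toward zero by `shrink`, clamping at zero (the L1 proximal step)."""
    return math.copysign(max(abs(w) - shrink, 0.0), w)


def simple_update(weights, gradient, step, reg_param):
    # No regularization: plain gradient step; reg_param is ignored.
    return [w - step * g for w, g in zip(weights, gradient)]


def l2_update(weights, gradient, step, reg_param):
    # Squared-L2 penalty: shrink each weight, then take the gradient step.
    return [w * (1.0 - step * reg_param) - step * g
            for w, g in zip(weights, gradient)]


def l1_update(weights, gradient, step, reg_param):
    # L1 penalty: gradient step first, then soft-threshold toward zero.
    return [soft_threshold(w - step * g, step * reg_param)
            for w, g in zip(weights, gradient)]


# Hypothetical lookup table mirroring the string values accepted in the diff.
_UPDATERS = {"none": simple_update, "l1": l1_update, "l2": l2_update}


def select_updater(reg_type):
    """Return the update function for reg_type, or raise on an unknown value."""
    try:
        return _UPDATERS[reg_type]
    except KeyError:
        raise ValueError(
            "Invalid value for 'regType' parameter. "
            "Can only be initialized using the following string values: "
            "[l1, l2, none].")
```

Validating once on the JVM side, as the diff does, means the Python caller gets an error for a misspelled `regType` instead of silently training without regularization.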
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -112,12 +112,36 @@ class LinearRegressionModel(LinearRegressionModelBase):

Review comment: We use two empty lines to separate methods in pyspark. (I don't know the exact reason ...)
 class LinearRegressionWithSGD(object):
     @classmethod
-    def train(cls, data, iterations=100, step=1.0,
-              miniBatchFraction=1.0, initialWeights=None):
-        """Train a linear regression model on the given data."""
+    def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
+              initialWeights=None, regParam=1.0, regType=None, intercept=False):
+        """
+        Train a linear regression model on the given data.
+
+        @param data: The training data.
Review comment: It is not necessary to align the doc, especially when limited by 78 characters.
+        @param iterations: The number of iterations (default: 100).
+        @param step: The step parameter used in SGD
+            (default: 1.0).
Review comment: ditto: please keep this empty line
+        @param miniBatchFraction: Fraction of data to be used for each SGD
+            iteration.
+        @param initialWeights: The initial weights (default: None).
+        @param regParam: The regularizer parameter (default: 1.0).
+        @param regType: The type of regularizer used for training
+            our model.
+            Allowed values: "l1" for using L1Updater,
+            "l2" for using
+            SquaredL2Updater,
+            "none" for no regularizer.
+            (default: "none")
+        @param intercept: Boolean parameter which indicates the use
+            or not of the augmented representation for
+            training data (i.e. whether bias features
+            are activated or not).
+        """
         sc = data.context
+        if regType is None:
+            regType = "none"
         train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
-            d._jrdd, iterations, step, miniBatchFraction, i)
+            d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
         return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights)
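The docstring describes `intercept` as switching on the "augmented representation" of the training data, and the wrapper maps `regType=None` to the string `"none"` before calling into the JVM. The sketch below illustrates both ideas in plain Python. All helper names are hypothetical, and it assumes the bias feature is a constant `1.0` appended at the end of each feature vector, so that the last augmented weight plays the role of the intercept.

```python
def normalize_reg_type(reg_type):
    """Map the Python-side default None to the string the Scala side expects."""
    return "none" if reg_type is None else reg_type


def augment(features):
    """Append a constant bias feature (assumed here to be 1.0, placed last)."""
    return list(features) + [1.0]


def predict(weights, intercept, features):
    """Linear prediction with an explicit intercept term."""
    return sum(w * x for w, x in zip(weights, features)) + intercept


def predict_augmented(aug_weights, features):
    """Same prediction, with the intercept folded in as the last weight."""
    return sum(w * x for w, x in zip(aug_weights, augment(features)))


weights, intercept = [2.0, 3.0], 0.5
x = [1.0, 1.0]

# Both forms give the same value: the intercept is just one more weight
# attached to a feature that is always 1.0.
explicit = predict(weights, intercept, x)
folded = predict_augmented(weights + [intercept], x)
```

With `intercept=False` no bias feature is added, so the fitted line is forced through the origin.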
Not using enumerations for the regType parameter anymore. Switched to string values.