Skip to content

Commit 2d040b3

Browse files
committed
implement setters inside each class, add Params.copyValues [ci skip]
1 parent fd751fc commit 2d040b3

File tree

8 files changed

+82
-124
lines changed

8 files changed

+82
-124
lines changed

mllib/src/main/scala/org/apache/spark/ml/example/BinaryClassificationEvaluator.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params
2828

2929
setMetricName("areaUnderROC")
3030

31+
def setMetricName(value: String): this.type = { set(metricName, value); this }
32+
def setScoreCol(value: String): this.type = { set(scoreCol, value); this }
33+
def setLabelCol(value: String): this.type = { set(labelCol, value); this }
34+
3135
override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
3236
import dataset.sqlContext._
3337
val map = this.paramMap ++ paramMap

mllib/src/main/scala/org/apache/spark/ml/example/CrossValidator.scala

Lines changed: 14 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -20,75 +20,33 @@ package org.apache.spark.ml.example
2020
import com.github.fommil.netlib.F2jBLAS
2121

2222
import org.apache.spark.ml._
23-
import org.apache.spark.ml.param.{ParamMap, Params, Param}
23+
import org.apache.spark.ml.param.{Param, ParamMap, Params}
2424
import org.apache.spark.mllib.util.MLUtils
2525
import org.apache.spark.sql.SchemaRDD
2626

27-
trait HasEstimator extends Params {
27+
class CrossValidator extends Estimator[CrossValidatorModel] with Params {
2828

29-
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
30-
31-
def setEstimator(estimator: Estimator[_]): this.type = {
32-
set(this.estimator, estimator)
33-
this
34-
}
35-
36-
def getEstimator: Estimator[_] = {
37-
get(this.estimator)
38-
}
39-
}
40-
41-
trait HasEvaluator extends Params {
42-
43-
val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
44-
45-
def setEvaluator(evaluator: Evaluator): this.type = {
46-
set(this.evaluator, evaluator)
47-
this
48-
}
49-
50-
def getEvaluator: Evaluator = {
51-
get(evaluator)
52-
}
53-
}
29+
private val f2jBLAS = new F2jBLAS
5430

55-
trait HasEstimatorParamMaps extends Params {
31+
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
32+
def setEstimator(value: Estimator[_]): this.type = { set(estimator, value); this }
33+
def getEstimator: Estimator[_] = get(estimator)
5634

5735
val estimatorParamMaps: Param[Array[ParamMap]] =
5836
new Param(this, "estimatorParamMaps", "param maps for the estimator")
59-
60-
def setEstimatorParamMaps(estimatorParamMaps: Array[ParamMap]): this.type = {
61-
set(this.estimatorParamMaps, estimatorParamMaps)
37+
def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps)
38+
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = {
39+
set(estimatorParamMaps, value)
6240
this
6341
}
6442

65-
def getEstimatorParamMaps: Array[ParamMap] = {
66-
get(estimatorParamMaps)
67-
}
68-
}
69-
70-
71-
class CrossValidator extends Estimator[CrossValidatorModel] with Params
72-
with HasEstimator with HasEstimatorParamMaps with HasEvaluator {
73-
74-
private val f2jBLAS = new F2jBLAS
75-
76-
// Overwrite return type for Java users.
77-
override def setEstimator(estimator: Estimator[_]): this.type = super.setEstimator(estimator)
78-
override def setEstimatorParamMaps(estimatorParamMaps: Array[ParamMap]): this.type =
79-
super.setEstimatorParamMaps(estimatorParamMaps)
80-
override def setEvaluator(evaluator: Evaluator): this.type = super.setEvaluator(evaluator)
43+
val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
44+
def setEvaluator(value: Evaluator): this.type = { set(evaluator, value); this }
45+
def getEvaluator: Evaluator = get(evaluator)
8146

8247
val numFolds: Param[Int] = new Param(this, "numFolds", "number of folds for cross validation", 3)
83-
84-
def setNumFolds(numFolds: Int): this.type = {
85-
set(this.numFolds, numFolds)
86-
this
87-
}
88-
89-
def getNumFolds: Int = {
90-
get(numFolds)
91-
}
48+
def setNumFolds(value: Int): this.type = { set(numFolds, value); this }
49+
def getNumFolds: Int = get(numFolds)
9250

9351
/**
9452
* Fits a single model to the input data with provided parameter map.
@@ -135,4 +93,3 @@ class CrossValidatorModel(bestModel: Model, metric: Double) extends Model {
13593
bestModel.transform(dataset, paramMap)
13694
}
13795
}
138-

mllib/src/main/scala/org/apache/spark/ml/example/LogisticRegression.scala

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
3636
setRegParam(0.1)
3737
setMaxIter(100)
3838

39-
// Overwrite the return type of setters for Java users.
40-
override def setRegParam(regParam: Double): this.type = super.setRegParam(regParam)
41-
override def setMaxIter(maxIter: Int): this.type = super.setMaxIter(maxIter)
42-
override def setLabelCol(labelCol: String): this.type = super.setLabelCol(labelCol)
43-
override def setFeaturesCol(featuresCol: String): this.type = super.setFeaturesCol(featuresCol)
39+
def setRegParam(value: Double): this.type = { set(regParam, value); this }
40+
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
41+
def setLabelCol(value: String): this.type = { set(labelCol, value); this }
42+
def setFeaturesCol(value: String): this.type = { set(featuresCol, value); this }
4443

45-
override final val modelParams: LogisticRegressionModelParams = new LogisticRegressionModelParams {}
44+
override final val modelParams: LogisticRegressionModelParams =
45+
new LogisticRegressionModelParams {}
4646

4747
override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
4848
import dataset.sqlContext._
@@ -58,23 +58,27 @@ class LogisticRegression extends Estimator[LogisticRegressionModel]
5858
.setNumIterations(maxIter)
5959
val lrm = new LogisticRegressionModel(lr.run(instances).weights)
6060
instances.unpersist()
61-
this.modelParams.params.foreach { param =>
62-
if (map.contains(param)) {
63-
lrm.paramMap.put(lrm.getParam(param.name), map(param))
64-
}
65-
}
61+
Params.copyValues(modelParams, lrm)
6662
if (!lrm.paramMap.contains(lrm.featuresCol) && map.contains(lrm.featuresCol)) {
6763
lrm.setFeaturesCol(featuresCol)
6864
}
6965
lrm
7066
}
67+
68+
/**
69+
* Validates parameters specified by the input parameter map.
70+
* Raises an exception if any parameter belonging to this object is invalid.
71+
*/
72+
override def validateParams(paramMap: ParamMap): Unit = {
73+
super.validateParams(paramMap)
74+
}
7175
}
7276

7377
trait LogisticRegressionModelParams extends Params with HasThreshold with HasFeaturesCol
74-
with HasScoreCol {
75-
override def setThreshold(threshold: Double): this.type = super.setThreshold(threshold)
76-
override def setFeaturesCol(featuresCol: String): this.type = super.setFeaturesCol(featuresCol)
77-
override def setScoreCol(scoreCol: String): this.type = super.setScoreCol(scoreCol)
78+
with HasScoreCol {
79+
def setThreshold(value: Double): this.type = { set(threshold, value); this }
80+
def setFeaturesCol(value: String): this.type = { set(featuresCol, value); this }
81+
def setScoreCol(value: String): this.type = { set(scoreCol, value); this }
7882
}
7983

8084
class LogisticRegressionModel(

mllib/src/main/scala/org/apache/spark/ml/example/StandardScaler.scala

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,45 @@ import org.apache.spark.sql.catalyst.analysis.Star
2626
import org.apache.spark.sql.catalyst.dsl._
2727
import org.apache.spark.sql.catalyst.expressions.Row
2828

29-
class StandardScaler extends Transformer with Params with HasInputCol with HasOutputCol {
29+
class StandardScaler extends Estimator[StandardScalerModel] with HasInputCol {
3030

31-
override def setInputCol(inputCol: String): this.type = super.setInputCol(inputCol)
32-
override def setOutputCol(outputCol: String): this.type = super.setOutputCol(outputCol)
31+
def setInputCol(value: String): this.type = { set(inputCol, value); this }
3332

34-
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
33+
override val modelParams: StandardScalerModelParams = new StandardScalerModelParams {}
34+
35+
override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = {
3536
import dataset.sqlContext._
3637
val map = this.paramMap ++ paramMap
3738
import map.implicitMapping
3839
val input = dataset.select((inputCol: String).attr)
3940
.map { case Row(v: Vector) =>
4041
v
41-
}.cache()
42+
}
4243
val scaler = new feature.StandardScaler().fit(input)
44+
val model = new StandardScalerModel(scaler)
45+
Params.copyValues(modelParams, model)
46+
if (!model.paramMap.contains(model.inputCol)) {
47+
model.setInputCol(inputCol)
48+
}
49+
model
50+
}
51+
}
52+
53+
trait StandardScalerModelParams extends Params with HasInputCol with HasOutputCol {
54+
def setInputCol(value: String): this.type = { set(inputCol, value); this }
55+
def setOutputCol(value: String): this.type = { set(outputCol, value); this }
56+
}
57+
58+
class StandardScalerModel private[ml] (
59+
scaler: feature.StandardScalerModel) extends Model with StandardScalerModelParams {
60+
61+
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
62+
import dataset.sqlContext._
63+
val map = this.paramMap ++ paramMap
64+
import map.implicitMapping
4365
val scale: (Vector) => Vector = (v) => {
4466
scaler.transform(v)
4567
}
46-
dataset.select(Star(None), scale.call((inputCol: String).attr) as Symbol(outputCol))
68+
dataset.select(Star(None), scale.call((inputCol: String).attr) as outputCol)
4769
}
4870
}

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class Param[T] private[param] (
7777
}
7878
}
7979

80+
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
81+
8082
class DoubleParam(parent: Params, name: String, doc: String, default: Option[Double] = None)
8183
extends Param[Double](parent, name, doc, default) {
8284
override def w(value: Double): ParamPair[Double] = ParamPair(this, value)
@@ -152,6 +154,17 @@ private[ml] object Params {
152154
val empty: Params = new Params {
153155
override def params: Array[Param[_]] = Array.empty
154156
}
157+
158+
/**
159+
* Copy parameter values from one Params instance to another.
160+
*/
161+
def copyValues[F <: Params, T <: F](from: F, to: T): Unit = {
162+
from.params.foreach { param =>
163+
if (from.paramMap.contains(param)) {
164+
to.paramMap.put(to.getParam(param.name), from.paramMap(param))
165+
}
166+
}
167+
}
155168
}
156169

157170
/**

mllib/src/main/scala/org/apache/spark/ml/param/shared.scala

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,6 @@ trait HasRegParam extends Params {
2121

2222
val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
2323

24-
def setRegParam(regParam: Double): this.type = {
25-
set(this.regParam, regParam)
26-
this
27-
}
28-
2924
def getRegParam: Double = {
3025
get(regParam)
3126
}
@@ -35,11 +30,6 @@ trait HasMaxIter extends Params {
3530

3631
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
3732

38-
def setMaxIter(maxIter: Int): this.type = {
39-
set(this.maxIter, maxIter)
40-
this
41-
}
42-
4333
def getMaxIter: Int = {
4434
get(maxIter)
4535
}
@@ -50,11 +40,6 @@ trait HasFeaturesCol extends Params {
5040
val featuresCol: Param[String] =
5141
new Param(this, "featuresCol", "features column name", "features")
5242

53-
def setFeaturesCol(featuresCol: String): this.type = {
54-
set(this.featuresCol, featuresCol)
55-
this
56-
}
57-
5843
def getFeaturesCol: String = {
5944
get(featuresCol)
6045
}
@@ -64,23 +49,14 @@ trait HasLabelCol extends Params {
6449

6550
val labelCol: Param[String] = new Param(this, "labelCol", "label column name", "label")
6651

67-
def setLabelCol(labelCol: String): this.type = {
68-
set(this.labelCol, labelCol)
69-
this
70-
}
71-
7252
def getLabelCol: String = {
7353
get(labelCol)
7454
}
7555
}
7656

7757
trait HasScoreCol extends Params {
78-
val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", "score")
7958

80-
def setScoreCol(scoreCol: String): this.type = {
81-
set(this.scoreCol, scoreCol)
82-
this
83-
}
59+
val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", "score")
8460

8561
def getScoreCol: String = {
8662
get(scoreCol)
@@ -91,11 +67,6 @@ trait HasThreshold extends Params {
9167

9268
val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold for prediction")
9369

94-
def setThreshold(threshold: Double): this.type = {
95-
set(this.threshold, threshold)
96-
this
97-
}
98-
9970
def getThreshold: Double = {
10071
get(threshold)
10172
}
@@ -105,11 +76,6 @@ trait HasMetricName extends Params {
10576

10677
val metricName: Param[String] = new Param(this, "metricName", "metric name for evaluation")
10778

108-
def setMetricName(metricName: String): this.type = {
109-
set(this.metricName, metricName)
110-
this
111-
}
112-
11379
def getMetricName: String = {
11480
get(metricName)
11581
}
@@ -119,11 +85,6 @@ trait HasInputCol extends Params {
11985

12086
val inputCol: Param[String] = new Param(this, "inputCol", "input column name")
12187

122-
def setInputCol(inputCol: String): this.type = {
123-
set(this.inputCol, inputCol)
124-
this
125-
}
126-
12788
def getInputCol: String = {
12889
get(inputCol)
12990
}
@@ -133,11 +94,6 @@ trait HasOutputCol extends Params {
13394

13495
val outputCol: Param[String] = new Param(this, "outputCol", "output column name")
13596

136-
def setOutputCol(outputCol: String): this.type = {
137-
set(this.outputCol, outputCol)
138-
this
139-
}
140-
14197
def getOutputCol: String = {
14298
get(outputCol)
14399
}

mllib/src/test/java/org/apache/spark/ml/example/JavaLogisticRegressionSuite.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ public void logisticRegressionWithCrossValidation() {
117117
@Test
118118
public void logisticRegressionWithPipeline() {
119119
StandardScaler scaler = new StandardScaler()
120-
.setInputCol("features")
120+
.setInputCol("features");
121+
scaler.modelParams()
121122
.setOutputCol("scaledFeatures");
122123
LogisticRegression lr = new LogisticRegression()
123124
.setFeaturesCol("scaledFeatures");

mllib/src/test/scala/org/apache/spark/ml/example/LogisticRegressionSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class LogisticRegressionSuite extends FunSuite {
8383
test("logistic regression with pipeline") {
8484
val scaler = new StandardScaler()
8585
.setInputCol("features")
86+
scaler.modelParams
8687
.setOutputCol("scaledFeatures")
8788
val lr = new LogisticRegression()
8889
.setFeaturesCol("scaledFeatures")

0 commit comments

Comments
 (0)