
Commit 9d2d35d

test varargs and chain model params
1 parent f46e927

4 files changed: +34 -62 lines changed

mllib/src/main/scala/org/apache/spark/ml/Estimator.scala

Lines changed: 12 additions & 27 deletions
@@ -17,24 +17,28 @@
 
 package org.apache.spark.ml
 
-import org.apache.spark.ml.param.{ParamMap, Params, ParamPair}
-import org.apache.spark.sql.SchemaRDD
-
 import scala.annotation.varargs
 
+import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
+import org.apache.spark.sql.SchemaRDD
+
 /**
  * Abstract class for estimators that fits models to data.
  */
-abstract class Estimator[M <: Model] extends Identifiable with Params with PipelineStage {
+abstract class Estimator[M <: Model] extends PipelineStage with Params {
 
   /**
-   * Fits a single model to the input data with default parameters.
+   * Fits a single model to the input data with optional parameters.
    *
    * @param dataset input dataset
+   * @param paramPairs optional list of param pairs, overwrite embedded params
    * @return fitted model
    */
-  def fit(dataset: SchemaRDD): M = {
-    fit(dataset, ParamMap.empty)
+  @varargs
+  def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = {
+    val map = new ParamMap()
+    paramPairs.foreach(map.put(_))
+    fit(dataset, map)
   }
 
   /**
@@ -46,25 +50,6 @@ abstract class Estimator[M <: Model] extends Identifiable with Params with Pipel
    */
   def fit(dataset: SchemaRDD, paramMap: ParamMap): M
 
-  /**
-   * Fits a single model to the input data with provided parameters.
-   *
-   * @param dataset input dataset
-   * @param firstParamPair first parameter
-   * @param otherParamPairs other parameters
-   * @return fitted model
-   */
-  @varargs
-  def fit[T](
-      dataset: SchemaRDD,
-      firstParamPair: ParamPair[_],
-      otherParamPairs: ParamPair[_]*): M = {
-    val map = new ParamMap()
-    map.put(firstParamPair)
-    otherParamPairs.foreach(map.put(_))
-    fit(dataset, map)
-  }
-
   /**
    * Fits multiple models to the input data with multiple sets of parameters.
    * The default implementation uses a for loop on each parameter map.
@@ -74,7 +59,7 @@ abstract class Estimator[M <: Model] extends Identifiable with Params with Pipel
    * @param paramMaps an array of parameter maps
    * @return fitted models, matching the input parameter maps
   */
-  def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
+  def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { // how to return an array?
     paramMaps.map(fit(dataset, _))
   }
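The core of this change: the old pair of overloads (a no-arg default and a firstParamPair/otherParamPairs split) collapses into one @varargs method that folds any number of ParamPairs into a ParamMap and delegates to the abstract ParamMap overload. A minimal, self-contained sketch of that pattern follows; Param, ParamPair, and ParamMap here are simplified stand-ins for the real org.apache.spark.ml.param classes, and EstimatorSketch is an invented name.

import scala.annotation.varargs

// Simplified stand-ins for org.apache.spark.ml.param.{Param, ParamPair, ParamMap};
// only the shape of the API is meant to match.
case class Param[T](name: String) {
  def ->(value: T): ParamPair[T] = ParamPair(this, value)
}

case class ParamPair[T](param: Param[T], value: T)

class ParamMap {
  private val settings = scala.collection.mutable.Map.empty[Param[_], Any]
  def put(pair: ParamPair[_]): this.type = {
    settings(pair.param) = pair.value
    this
  }
  override def toString: String =
    settings.map { case (p, v) => s"${p.name}=$v" }.mkString("{", ", ", "}")
}

object EstimatorSketch {
  // Mirrors the new Estimator.fit: fold the varargs pairs into one ParamMap,
  // then delegate to the ParamMap-based overload.
  @varargs
  def fit(dataset: Seq[Double], paramPairs: ParamPair[_]*): Unit = {
    val map = new ParamMap()
    paramPairs.foreach(map.put(_))
    fit(dataset, map)
  }

  def fit(dataset: Seq[Double], paramMap: ParamMap): Unit =
    println(s"fitting on ${dataset.size} rows with params $paramMap")

  def main(args: Array[String]): Unit = {
    val maxIter  = Param[Int]("maxIter")
    val regParam = Param[Double]("regParam")
    fit(Seq(1.0, 2.0, 3.0))                                  // zero pairs is now legal
    fit(Seq(1.0, 2.0, 3.0), maxIter -> 10, regParam -> 1.0)  // any number of pairs
  }
}

Note the trade-off: dropping the explicit firstParamPair parameter loses the compile-time guarantee of at least one pair, but it also removes the need for a separate no-arg fit(dataset), since zero varargs now go through the same overload.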

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 8 additions & 31 deletions
@@ -17,20 +17,24 @@
 
 package org.apache.spark.ml
 
-import org.apache.spark.ml.param.{ParamMap, Params, ParamPair}
+import scala.annotation.varargs
+
+import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
 import org.apache.spark.sql.SchemaRDD
 
 /**
  * Abstract class for transformers that transform one dataset into another.
  */
-abstract class Transformer extends Identifiable with Params with PipelineStage {
+abstract class Transformer extends PipelineStage with Params {
 
   /**
-   * Transforms the dataset with the default parameters.
+   * Transforms the dataset with optional parameters
    * @param dataset input dataset
+   * @param paramPairs optional list of param pairs, overwrite embedded params
    * @return transformed dataset
    */
-  def transform(dataset: SchemaRDD): SchemaRDD = {
+  @varargs
+  def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = {
     transform(dataset, ParamMap.empty)
   }
 
@@ -41,31 +45,4 @@ abstract class Transformer extends Identifiable with Params with PipelineStage {
    * @return transformed dataset
    */
   def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
-
-  /**
-   * Transforms the dataset with provided parameter pairs.
-   * @param dataset input dataset
-   * @param firstParamPair first parameter pair
-   * @param otherParamPairs second parameter pair
-   * @return transformed dataset
-   */
-  def transform(
-      dataset: SchemaRDD,
-      firstParamPair: ParamPair[_],
-      otherParamPairs: ParamPair[_]*): SchemaRDD = {
-    val map = new ParamMap()
-    map.put(firstParamPair)
-    otherParamPairs.foreach(map.put(_))
-    transform(dataset, map)
-  }
-
-  /**
-   * Transforms the dataset with multiple sets of parameters.
-   * @param dataset input dataset
-   * @param paramMaps an array of parameter maps
-   * @return transformed dataset
-   */
-  def transform(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Array[SchemaRDD] = {
-    paramMaps.map(transform(dataset, _))
-  }
 }
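One thing to flag: unlike the new Estimator.fit, the varargs transform above still delegates to transform(dataset, ParamMap.empty), so whatever paramPairs the caller passes are silently dropped. Presumably the intended body mirrors fit; a sketch of the presumed fix, reusing the stand-in ParamPair/ParamMap from the Estimator sketch above (TransformerSketch is an invented name, not the code as committed):

import scala.annotation.varargs

// Presumed fix: fold the pairs into a ParamMap, as the new Estimator.fit does,
// instead of discarding them via ParamMap.empty.
abstract class TransformerSketch {
  def transform(dataset: Seq[Double], paramMap: ParamMap): Seq[Double]

  @varargs
  def transform(dataset: Seq[Double], paramPairs: ParamPair[_]*): Seq[Double] = {
    val map = new ParamMap()
    paramPairs.foreach(map.put(_))  // fold pairs in rather than ignoring them
    transform(dataset, map)
  }
}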

mllib/src/test/java/org/apache/spark/ml/example/JavaLogisticRegressionSuite.java

Lines changed: 8 additions & 3 deletions
@@ -64,15 +64,20 @@ public void logisticRegression() {
       .setMaxIter(10)
       .setRegParam(1.0);
     lr.model().setThreshold(0.8);
-    // In Java we can access baseSchemaRDD, while in Scala we cannot.
-    LogisticRegressionModel model = lr.fit(dataset.baseSchemaRDD());
-    model.transform(dataset.baseSchemaRDD()).registerTempTable("prediction");
+    LogisticRegressionModel model = lr.fit(dataset.schemaRDD());
+    model.transform(dataset.schemaRDD()).registerTempTable("prediction");
     JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
     for (Row r: predictions.collect()) {
       System.out.println(r);
     }
   }
 
+  @Test
+  public void logisticRegressionFitWithVarargs() {
+    LogisticRegression lr = new LogisticRegression();
+    lr.fit(dataset.schemaRDD(), lr.maxIter().w(10), lr.regParam().w(1.0));
+  }
+
   @Test
   public void logisticRegressionWithCrossValidation() {
     LogisticRegression lr = new LogisticRegression();
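The new Java test can pass ParamPairs positionally only because the Scala side is annotated with @varargs: a plain Scala ParamPair[_]* parameter compiles to a single Seq[ParamPair[_]] argument, which is clumsy to construct from Java, while @varargs makes the compiler emit an additional ParamPair<?>... forwarder. A tiny illustration of the annotation's effect (the Greeter class is made up):

import scala.annotation.varargs

class Greeter {
  // Without @varargs, Java callers would have to build a scala.collection.Seq.
  // With it, the compiler also emits a greet(String...) overload for Java.
  @varargs
  def greet(names: String*): Unit =
    names.foreach(name => println(s"hello, $name"))
}

From Java, new Greeter().greet("ada", "grace") then resolves to the generated String... overload.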

mllib/src/test/scala/org/apache/spark/ml/example/LogisticRegressionSuite.scala

Lines changed: 6 additions & 1 deletion
@@ -35,11 +35,16 @@ class LogisticRegressionSuite extends FunSuite {
       .setMaxIter(10)
       .setRegParam(1.0)
     val model = lr.fit(dataset)
-    model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
+    model.transform(dataset, lr.model.threshold -> 0.8) // overwrite threshold
       .select('label, 'score, 'prediction).collect()
       .foreach(println)
   }
 
+  test("logistic regression fit with varargs") {
+    val lr = new LogisticRegression
+    lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
+  }
+
   test("logistic regression with cross validation") {
     val lr = new LogisticRegression
     val lrParamMaps = new ParamGridBuilder()
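The switch from model.threshold to lr.model.threshold is the "chain model params" half of the commit message: the threshold is addressed through the estimator's embedded model object rather than through the fitted model returned by fit. A guess at the shape of that API, again reusing the stand-in Param/ParamPair from the Estimator sketch; all names below are invented to illustrate chained param access, not the real spark.ml classes:

// Stand-in for the embedded model's params.
class ModelParamsSketch {
  val threshold = Param[Double]("threshold")
}

class LogisticRegressionSketch {
  val maxIter  = Param[Int]("maxIter")
  val regParam = Param[Double]("regParam")
  // The estimator exposes its model's params, so a caller can write
  // lr.model.threshold -> 0.8 before or after fitting.
  val model = new ModelParamsSketch
}

object ChainedParamsDemo extends App {
  val lr = new LogisticRegressionSketch
  println(lr.model.threshold -> 0.8) // prints ParamPair(Param(threshold),0.8)
}

In the real suite this only type-checks if LogisticRegression exposes a model member whose params are Param instances, which is exactly what the updated test exercises.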
