Skip to content

Commit 9fd4933

Browse files
committed
add unit test for pipeline
1 parent 2a0df46 commit 9fd4933

File tree

11 files changed

+146
-44
lines changed

11 files changed

+146
-44
lines changed

mllib/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@
101101
<scope>test</scope>
102102
</dependency>
103103
<dependency>
104+
<groupId>org.mockito</groupId>
105+
<artifactId>mockito-all</artifactId>
106+
<version>1.9.0</version>
107+
<scope>test</scope>
108+
</dependency>
109+
<dependency>
104110
<groupId>org.apache.spark</groupId>
105111
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
106112
<version>${project.version}</version>

mllib/src/main/scala/org/apache/spark/ml/Estimator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.sql.api.java.JavaSchemaRDD
2727
/**
2828
* Abstract class for estimators that fit models to data.
2929
*/
30-
abstract class Estimator[M <: Model] extends PipelineStage with Params {
30+
abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
3131

3232
/**
3333
* Fits a single model to the input data with optional parameters.

mllib/src/main/scala/org/apache/spark/ml/Model.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ import org.apache.spark.ml.param.ParamMap
2222
/**
2323
* A fitted model.
2424
*/
25-
abstract class Model extends Transformer {
25+
abstract class Model[M <: Model[M]] extends Transformer {
2626
/**
2727
* The parent estimator that produced this model.
2828
*/
29-
val parent: Estimator[_]
29+
val parent: Estimator[M]
3030

3131
/**
3232
* Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,15 @@ class Pipeline extends Estimator[PipelineModel] {
105105
class PipelineModel(
106106
override val parent: Pipeline,
107107
override val fittingParamMap: ParamMap,
108-
val transformers: Array[Transformer]) extends Model with Logging {
108+
val transformers: Array[Transformer]) extends Model[PipelineModel] with Logging {
109109

110110
/**
111111
* Gets the model produced by the input estimator. Throws a NoSuchElementException if the input
112112
* estimator does not exist in the pipeline.
113113
*/
114-
def getModel[M <: Model](estimator: Estimator[M]): M = {
114+
def getModel[M <: Model[M]](estimator: Estimator[M]): M = {
115115
val matched = transformers.filter {
116-
case m: Model => m.parent.eq(estimator)
116+
case m: Model[_] => m.parent.eq(estimator)
117117
case _ => false
118118
}
119119
if (matched.isEmpty) {

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ abstract class Transformer extends PipelineStage with Params {
7171
* Abstract class for transformers that take one input column, apply transformation, and output the
7272
* result as a new column.
7373
*/
74-
abstract class UnaryTransformer[IN, OUT: TypeTag, SELF <: UnaryTransformer[IN, OUT, SELF]]
74+
abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
7575
extends Transformer with HasInputCol with HasOutputCol with Logging {
7676

77-
def setInputCol(value: String): SELF = { set(inputCol, value); this.asInstanceOf[SELF] }
78-
def setOutputCol(value: String): SELF = { set(outputCol, value); this.asInstanceOf[SELF] }
77+
def setInputCol(value: String): T = { set(inputCol, value); this.asInstanceOf[T] }
78+
def setOutputCol(value: String): T = { set(outputCol, value); this.asInstanceOf[T] }
7979

8080
/**
8181
* Creates the transform function using the given param map. The input param map already takes

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
9696
class LogisticRegressionModel private[ml] (
9797
override val parent: LogisticRegression,
9898
override val fittingParamMap: ParamMap,
99-
val weights: Vector) extends Model with LogisticRegressionParams {
99+
val weights: Vector) extends Model[LogisticRegressionModel] with LogisticRegressionParams {
100100

101101
def setThreshold(value: Double): this.type = { set(threshold, value); this }
102102
def setFeaturesCol(value: String): this.type = { set(featuresCol, value); this }

mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
7171
class StandardScalerModel private[ml] (
7272
override val parent: StandardScaler,
7373
override val fittingParamMap: ParamMap,
74-
scaler: feature.StandardScalerModel) extends Model with StandardScalerParams {
74+
scaler: feature.StandardScalerModel) extends Model[StandardScalerModel]
75+
with StandardScalerParams {
7576

7677
def setInputCol(value: String): this.type = { set(inputCol, value); this }
7778
def setOutputCol(value: String): this.type = { set(outputCol, value); this }

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
7979
val validationDataset = sqlCtx.applySchema(validation, schema).cache()
8080
// multi-model training
8181
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
82-
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model]]
82+
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
8383
var i = 0
8484
while (i < numModels) {
8585
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map)
@@ -93,7 +93,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
9393
val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
9494
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
9595
logInfo(s"Best cross-validation metric: $bestMetric.")
96-
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model]
96+
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
9797
val cvModel = new CrossValidatorModel(this, map, bestModel)
9898
Params.copyValues(this, cvModel)
9999
cvModel
@@ -111,7 +111,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
111111
class CrossValidatorModel private[ml] (
112112
override val parent: CrossValidator,
113113
override val fittingParamMap: ParamMap,
114-
val bestModel: Model) extends Model with CrossValidatorParams {
114+
val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams {
115115

116116
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
117117
bestModel.transform(dataset, paramMap)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml
19+
20+
import org.mockito.Matchers.{any, eq => meq}
21+
import org.mockito.Mockito.when
22+
import org.scalatest.FunSuite
23+
import org.scalatest.mock.MockitoSugar.mock
24+
25+
import org.apache.spark.ml.param.ParamMap
26+
import org.apache.spark.sql.SchemaRDD
27+
28+
class PipelineSuite extends FunSuite {
29+
30+
abstract class MyModel extends Model[MyModel]
31+
32+
test("pipeline") {
33+
val estimator0 = mock[Estimator[MyModel]]
34+
val model0 = mock[MyModel]
35+
val transformer1 = mock[Transformer]
36+
val estimator2 = mock[Estimator[MyModel]]
37+
val model2 = mock[MyModel]
38+
val transformer3 = mock[Transformer]
39+
val dataset0 = mock[SchemaRDD]
40+
val dataset1 = mock[SchemaRDD]
41+
val dataset2 = mock[SchemaRDD]
42+
val dataset3 = mock[SchemaRDD]
43+
val dataset4 = mock[SchemaRDD]
44+
45+
when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0)
46+
when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1)
47+
when(model0.parent).thenReturn(estimator0)
48+
when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2)
49+
when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2)
50+
when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3)
51+
when(model2.parent).thenReturn(estimator2)
52+
when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4)
53+
54+
val pipeline = new Pipeline()
55+
.setStages(Array(estimator0, transformer1, estimator2, transformer3))
56+
val pipelineModel = pipeline.fit(dataset0)
57+
58+
assert(pipelineModel.transformers(0).eq(model0))
59+
assert(pipelineModel.transformers(1).eq(transformer1))
60+
assert(pipelineModel.transformers(2).eq(model2))
61+
assert(pipelineModel.transformers(3).eq(transformer3))
62+
63+
assert(pipelineModel.getModel(estimator0).eq(model0))
64+
assert(pipelineModel.getModel(estimator2).eq(model2))
65+
intercept[NoSuchElementException] {
66+
pipelineModel.getModel(mock[Estimator[MyModel]])
67+
}
68+
val output = pipelineModel.transform(dataset0)
69+
assert(output.eq(dataset4))
70+
}
71+
}

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -70,34 +70,4 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
7070
.select('label, 'probability, 'prediction)
7171
.foreach(println)
7272
}
73-
74-
test("logistic regression with cross validation") {
75-
val lr = new LogisticRegression
76-
val lrParamMaps = new ParamGridBuilder()
77-
.addGrid(lr.regParam, Array(0.1, 100.0))
78-
.addGrid(lr.maxIter, Array(0, 5))
79-
.build()
80-
val eval = new BinaryClassificationEvaluator
81-
val cv = new CrossValidator()
82-
.setEstimator(lr)
83-
.setEstimatorParamMaps(lrParamMaps)
84-
.setEvaluator(eval)
85-
.setNumFolds(3)
86-
val bestModel = cv.fit(dataset)
87-
}
88-
89-
test("logistic regression with pipeline") {
90-
val scaler = new StandardScaler()
91-
.setInputCol("features")
92-
.setOutputCol("scaledFeatures")
93-
val lr = new LogisticRegression()
94-
.setFeaturesCol("scaledFeatures")
95-
val pipeline = new Pipeline()
96-
.setStages(Array(scaler, lr))
97-
val model = pipeline.fit(dataset)
98-
val predictions = model.transform(dataset)
99-
.select('label, 'score, 'prediction)
100-
.collect()
101-
.foreach(println)
102-
}
10373
}

0 commit comments

Comments
 (0)