Skip to content

Commit b5f52c1

Browse files
committed
Add param to CrossValidator for choosing whether to maximize evaulation value.
1 parent 54557f3 commit b5f52c1

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
297297

298298
/**
299299
* :: Experimental ::
300-
* A param amd its value.
300+
* A param and its value.
301301
*/
302302
@Experimental
303303
case class ParamPair[T](param: Param[T], value: T) {

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,20 @@ private[ml] trait CrossValidatorParams extends Params {
7676
def getNumFolds: Int = $(numFolds)
7777

7878
setDefault(numFolds -> 3)
79+
80+
/**
81+
* Param for whether maximize the evaluation value during cross validation.
82+
* If false, turn to minimize the evaluation value.
83+
* Default: true
84+
* @group param
85+
*/
86+
val useMax: BooleanParam = new BooleanParam(this, "useMax",
87+
"whether maximize the evaluation value durin cross validation")
88+
89+
/** @group getParam */
90+
def getUseMax: Boolean = $(useMax)
91+
92+
setDefault(useMax -> true)
7993
}
8094

8195
/**
@@ -102,6 +116,9 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
102116
/** @group setParam */
103117
def setNumFolds(value: Int): this.type = set(numFolds, value)
104118

119+
/** @group setParam */
120+
def setUseMax(value: Boolean): this.type = set(useMax, value)
121+
105122
override def fit(dataset: DataFrame): CrossValidatorModel = {
106123
val schema = dataset.schema
107124
transformSchema(schema, logging = true)
@@ -131,7 +148,11 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
131148
}
132149
f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
133150
logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
134-
val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
151+
val (bestMetric, bestIndex) = if ($(useMax)) {
152+
metrics.zipWithIndex.maxBy(_._1)
153+
} else {
154+
metrics.zipWithIndex.minBy(_._1)
155+
}
135156
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
136157
logInfo(s"Best cross-validation metric: $bestMetric.")
137158
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ import org.apache.spark.SparkFunSuite
2121

2222
import org.apache.spark.ml.{Estimator, Model}
2323
import org.apache.spark.ml.classification.LogisticRegression
24-
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
24+
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
2525
import org.apache.spark.ml.param.ParamMap
2626
import org.apache.spark.ml.param.shared.HasInputCol
27+
import org.apache.spark.ml.regression.LinearRegression
2728
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
28-
import org.apache.spark.mllib.util.MLlibTestSparkContext
29+
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
2930
import org.apache.spark.sql.{DataFrame, SQLContext}
3031
import org.apache.spark.sql.types.StructType
3132

@@ -59,6 +60,30 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
5960
assert(cvModel.avgMetrics.length === lrParamMaps.length)
6061
}
6162

63+
test("cross validation with linear regression") {
64+
val dataset = sqlContext.createDataFrame(
65+
sc.parallelize(LinearDataGenerator.generateLinearInput(
66+
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
67+
68+
val trainer = new LinearRegression
69+
val lrParamMaps = new ParamGridBuilder()
70+
.addGrid(trainer.regParam, Array(1000.0, 0.001))
71+
.addGrid(trainer.maxIter, Array(0, 10))
72+
.build()
73+
val eval = new RegressionEvaluator()
74+
val cv = new CrossValidator()
75+
.setEstimator(trainer)
76+
.setEstimatorParamMaps(lrParamMaps)
77+
.setEvaluator(eval)
78+
.setNumFolds(3)
79+
.setUseMax(false)
80+
val cvModel = cv.fit(dataset)
81+
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
82+
assert(parent.getRegParam === 0.001)
83+
assert(parent.getMaxIter === 10)
84+
assert(cvModel.avgMetrics.length === lrParamMaps.length)
85+
}
86+
6287
test("validateParams should check estimatorParamMaps") {
6388
import CrossValidatorSuite._
6489

0 commit comments

Comments
 (0)