Skip to content

Commit dbda033

Browse files
committed
[SPARK-6025] Add helper method evaluateEachIteration to extract learning curve
1 parent 1f1fccc commit dbda033

File tree

7 files changed

+125
-5
lines changed

7 files changed

+125
-5
lines changed

docs/mllib-ensembles.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,8 @@ first one being the training dataset and the second being the validation dataset
464464
The training is stopped when the improvement in the validation error is not more than a certain tolerance
465465
(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
466466
decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
467-
and the user is advised to set a large enough negative tolerance and examine the validation curve to to tune the number of
468-
iterations.
467+
and the user is advised to set a large enough negative tolerance and examine the validation curve using `evaluateEachIteration`
468+
(which gives the error or loss per iteration) to tune the number of iterations.
469469

470470
### Examples
471471

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.configuration.BoostingStrategy
2525
import org.apache.spark.mllib.tree.configuration.Algo._
2626
import org.apache.spark.mllib.tree.impl.TimeTracker
2727
import org.apache.spark.mllib.tree.impurity.Variance
28+
import org.apache.spark.mllib.tree.loss.Loss
2829
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
2930
import org.apache.spark.rdd.RDD
3031
import org.apache.spark.storage.StorageLevel
@@ -52,14 +53,18 @@ import org.apache.spark.storage.StorageLevel
5253
class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
5354
extends Serializable with Logging {
5455

56+
private val numIterations = boostingStrategy.numIterations
57+
private var baseLearners = new Array[DecisionTreeModel](numIterations)
58+
private var baseLearnerWeights = new Array[Double](numIterations)
59+
5560
/**
5661
* Method to train a gradient boosting model
5762
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
5863
* @return a gradient boosted trees model that can be used for prediction
5964
*/
6065
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
6166
val algo = boostingStrategy.treeStrategy.algo
62-
algo match {
67+
val fitGradientBoostingModel = algo match {
6368
case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
6469
case Classification =>
6570
// Map labels to -1, +1 so binary classification can be treated as regression.
@@ -69,6 +74,42 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
6974
case _ =>
7075
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
7176
}
77+
baseLearners = fitGradientBoostingModel.trees
78+
baseLearnerWeights = fitGradientBoostingModel.treeWeights
79+
fitGradientBoostingModel
80+
}
81+
82+
/**
83+
* Method to compute error or loss for every iteration of gradient boosting.
84+
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
85+
* @param loss evaluation metric that defaults to boostingStrategy.loss
86+
* @return an array with index i having the losses or errors for the ensemble
87+
* containing trees 1 to i + 1
88+
*/
89+
def evaluateEachIteration(
    data: RDD[LabeledPoint],
    loss: Loss = boostingStrategy.loss): Array[Double] = {
  // The learner arrays are only filled in by `run`; until then every slot is null and
  // predicting with it would fail inside a Spark task with an opaque error.
  require(baseLearners.nonEmpty && baseLearners(0) != null,
    "evaluateEachIteration requires a fitted ensemble; call run() first.")

  val algo = boostingStrategy.treeStrategy.algo
  // Classification labels are mapped from {0, 1} to {-1, +1}, mirroring the remapping
  // applied before boosting in `run`.
  val remappedData = algo match {
    case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
    case _ => data
  }
  val initialTree = baseLearners(0)
  val evaluationArray = Array.fill(numIterations)(0.0)

  // Initial weight is 1.0
  var predictionRDD = remappedData.map(i => initialTree.predict(i.features))
  evaluationArray(0) = loss.computeError(remappedData, predictionRDD)

  // NOTE(review): predictionRDD is never persisted here, so the lineage evaluated at
  // iteration k still contains the k earlier trees; consider caching/checkpointing for
  // large ensembles.
  (1 until numIterations).foreach { nTree =>
    // Copy the tree and its weight into locals so the closure below does not capture
    // `this` (and with it every learner) when it is serialized.
    val tree = baseLearners(nTree)
    val weight = baseLearnerWeights(nTree)
    predictionRDD = remappedData.zip(predictionRDD).map {
      case (point, pred) => pred + tree.predict(point.features) * weight
    }
    evaluationArray(nTree) = loss.computeError(remappedData, predictionRDD)
  }
  evaluationArray
}
73114

74115
/**

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,23 @@ object AbsoluteError extends Loss {
6161
math.abs(err)
6262
}.mean()
6363
}
64+
65+
/**
66+
* Method to calculate loss when the predictions are already known.
67+
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
68+
* predicted values from previously fit trees.
69+
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
70+
* @param prediction RDD[Double] of predicted labels.
71+
* @return Mean absolute error of model on data
72+
*/
73+
override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
  // Pair each labeled point with its precomputed prediction and average |label - prediction|.
  data.zip(prediction).map { case (point, predicted) =>
    math.abs(point.label - predicted)
  }.mean()
}
82+
6483
}

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,23 @@ object LogLoss extends Loss {
6666
2.0 * MLUtils.log1pExp(-margin)
6767
}.mean()
6868
}
69+
70+
/**
71+
* Method to calculate loss when the predictions are already known.
72+
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
73+
* predicted values from previously fit trees.
74+
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
75+
* @param prediction RDD[Double] of predicted labels.
76+
* @return Mean log loss of model on data
77+
*/
78+
override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
  // Per-sample log loss computed from predictions that are already available.
  val perSampleLoss = data.zip(prediction).map { case (point, predicted) =>
    val scaledMargin = 2.0 * point.label * predicted
    // Same value as 2.0 * log(1 + exp(-scaledMargin)), computed via the more numerically
    // stable helper.
    2.0 * MLUtils.log1pExp(-scaledMargin)
  }
  perSampleLoss.mean()
}
87+
6988
}

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,14 @@ trait Loss extends Serializable {
4949
*/
5050
def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
5151

52+
/**
53+
* Method to calculate loss when the predictions are already known.
54+
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
55+
* predicted values from previously fit trees.
56+
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
57+
* @param prediction RDD[Double] of predicted labels.
58+
* @return Measure of model error on data
59+
*/
60+
def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]) : Double
61+
5262
}

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,22 @@ object SquaredError extends Loss {
6161
err * err
6262
}.mean()
6363
}
64+
65+
/**
66+
* Method to calculate loss when the predictions are already known.
67+
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
68+
* predicted values from previously fit trees.
69+
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
70+
* @param prediction RDD[Double] of predicted labels.
71+
* @return Mean squared error of model on data
72+
*/
73+
override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
  // Average of the squared residual (prediction - label) over all paired samples.
  data.zip(prediction).map { case (point, predicted) =>
    val residual = predicted - point.label
    residual * residual
  }.mean()
}
81+
6482
}

mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,12 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
175175
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
176176
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
177177
.runWithValidation(trainRdd, validateRdd)
178-
assert(gbtValidate.numTrees !== numIterations)
178+
val numTrees = gbtValidate.numTrees
179+
assert(numTrees !== numIterations)
179180

180181
// Test that it performs better on the validation dataset.
181-
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
182+
val gbtModel = new GradientBoostedTrees(boostingStrategy)
183+
val gbt = gbtModel.run(trainRdd)
182184
val (errorWithoutValidation, errorWithValidation) = {
183185
if (algo == Classification) {
184186
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -188,6 +190,17 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
188190
}
189191
}
190192
assert(errorWithValidation <= errorWithoutValidation)
193+
194+
// Test that results from evaluateEachIteration comply with runWithValidation.
195+
// Note that validationTol is set to 0.0
196+
val evaluationArray = gbtModel.evaluateEachIteration(validateRdd)
197+
assert(evaluationArray.length === numIterations)
198+
assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
199+
var i = 1
200+
while (i < numTrees) {
201+
assert(evaluationArray(i) < evaluationArray(i - 1))
202+
i += 1
203+
}
191204
}
192205
}
193206
}

0 commit comments

Comments
 (0)