Skip to content

Commit 25e271d

Browse files
MechCoder authored and jkbradley committed
[SPARK-6025] [MLlib] Add helper method evaluateEachIteration to extract learning curve
Added evaluateEachIteration to allow the user to manually extract the error for each iteration of GradientBoosting. The internal optimisation can be dealt with later.

Author: MechCoder <[email protected]>

Closes apache#4906 from MechCoder/spark-6025 and squashes the following commits:

67146ab [MechCoder] Minor
352001f [MechCoder] Minor
6e8aa10 [MechCoder] Made the following changes: used mapPartition instead of map; refactored computeError and unpersisted broadcast variables
bc99ac6 [MechCoder] Refactor the method and stuff
dbda033 [MechCoder] [SPARK-6025] Add helper method evaluateEachIteration to extract learning curve
1 parent a95043b commit 25e271d

File tree

7 files changed

+96
-46
lines changed

7 files changed

+96
-46
lines changed

docs/mllib-ensembles.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,8 @@ first one being the training dataset and the second being the validation dataset
464464
The training is stopped when the improvement in the validation error is not more than a certain tolerance
465465
(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
466466
decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
467-
and the user is advised to set a large enough negative tolerance and examine the validation curve to to tune the number of
468-
iterations.
467+
and the user is advised to set a large enough negative tolerance and examine the validation curve using `evaluateEachIteration`
468+
(which gives the error or loss per iteration) to tune the number of iterations.
469469

470470
### Examples
471471

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,9 @@ object AbsoluteError extends Loss {
4747
if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
4848
}
4949

50-
/**
51-
* Method to calculate loss of the base learner for the gradient boosting calculation.
52-
* Note: This method is not used by the gradient boosting algorithm but is useful for debugging
53-
* purposes.
54-
* @param model Ensemble model
55-
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
56-
* @return Mean absolute error of model on data
57-
*/
58-
override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
59-
data.map { y =>
60-
val err = model.predict(y.features) - y.label
61-
math.abs(err)
62-
}.mean()
50+
override def computeError(prediction: Double, label: Double): Double = {
51+
val err = label - prediction
52+
math.abs(err)
6353
}
54+
6455
}

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,10 @@ object LogLoss extends Loss {
5050
- 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
5151
}
5252

53-
/**
54-
* Method to calculate loss of the base learner for the gradient boosting calculation.
55-
* Note: This method is not used by the gradient boosting algorithm but is useful for debugging
56-
* purposes.
57-
* @param model Ensemble model
58-
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
59-
* @return Mean log loss of model on data
60-
*/
61-
override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
62-
data.map { case point =>
63-
val prediction = model.predict(point.features)
64-
val margin = 2.0 * point.label * prediction
65-
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
66-
2.0 * MLUtils.log1pExp(-margin)
67-
}.mean()
53+
override def computeError(prediction: Double, label: Double): Double = {
54+
val margin = 2.0 * label * prediction
55+
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
56+
2.0 * MLUtils.log1pExp(-margin)
6857
}
58+
6959
}

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ trait Loss extends Serializable {
4747
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
4848
* @return Measure of model error on data
4949
*/
50-
def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
50+
def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
51+
data.map(point => computeError(model.predict(point.features), point.label)).mean()
52+
}
53+
54+
/**
55+
* Method to calculate loss when the predictions are already known.
56+
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
57+
* predicted values from previously fit trees.
58+
* @param prediction Predicted label.
59+
* @param label True label.
60+
* @return Measure of model error on datapoint.
61+
*/
62+
def computeError(prediction: Double, label: Double): Double
5163

5264
}

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,9 @@ object SquaredError extends Loss {
4747
2.0 * (model.predict(point.features) - point.label)
4848
}
4949

50-
/**
51-
* Method to calculate loss of the base learner for the gradient boosting calculation.
52-
* Note: This method is not used by the gradient boosting algorithm but is useful for debugging
53-
* purposes.
54-
* @param model Ensemble model
55-
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
56-
* @return Mean squared error of model on data
57-
*/
58-
override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
59-
data.map { y =>
60-
val err = model.predict(y.features) - y.label
61-
err * err
62-
}.mean()
50+
override def computeError(prediction: Double, label: Double): Double = {
51+
val err = prediction - label
52+
err * err
6353
}
54+
6455
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext}
2828
import org.apache.spark.annotation.Experimental
2929
import org.apache.spark.api.java.JavaRDD
3030
import org.apache.spark.mllib.linalg.Vector
31+
import org.apache.spark.mllib.regression.LabeledPoint
3132
import org.apache.spark.mllib.tree.configuration.Algo
3233
import org.apache.spark.mllib.tree.configuration.Algo._
3334
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
35+
import org.apache.spark.mllib.tree.loss.Loss
3436
import org.apache.spark.mllib.util.{Loader, Saveable}
3537
import org.apache.spark.rdd.RDD
3638
import org.apache.spark.sql.SQLContext
@@ -108,6 +110,58 @@ class GradientBoostedTreesModel(
108110
}
109111

110112
override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
113+
114+
/**
115+
* Method to compute error or loss for every iteration of gradient boosting.
116+
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
117+
* @param loss evaluation metric.
118+
* @return an array with index i having the losses or errors for the ensemble
119+
* containing the first i+1 trees
120+
*/
121+
def evaluateEachIteration(
122+
data: RDD[LabeledPoint],
123+
loss: Loss): Array[Double] = {
124+
125+
val sc = data.sparkContext
126+
val remappedData = algo match {
127+
case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
128+
case _ => data
129+
}
130+
131+
val numIterations = trees.length
132+
val evaluationArray = Array.fill(numIterations)(0.0)
133+
134+
var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
135+
val pred = treeWeights(0) * trees(0).predict(i.features)
136+
val error = loss.computeError(pred, i.label)
137+
(pred, error)
138+
}
139+
evaluationArray(0) = predictionAndError.values.mean()
140+
141+
// Avoid the model being copied across numIterations.
142+
val broadcastTrees = sc.broadcast(trees)
143+
val broadcastWeights = sc.broadcast(treeWeights)
144+
145+
(1 until numIterations).map { nTree =>
146+
predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
147+
val currentTree = broadcastTrees.value(nTree)
148+
val currentTreeWeight = broadcastWeights.value(nTree)
149+
iter.map {
150+
case (point, (pred, error)) => {
151+
val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
152+
val newError = loss.computeError(newPred, point.label)
153+
(newPred, newError)
154+
}
155+
}
156+
}
157+
evaluationArray(nTree) = predictionAndError.values.mean()
158+
}
159+
160+
broadcastTrees.unpersist()
161+
broadcastWeights.unpersist()
162+
evaluationArray
163+
}
164+
111165
}
112166

113167
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {

mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,11 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
175175
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
176176
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
177177
.runWithValidation(trainRdd, validateRdd)
178-
assert(gbtValidate.numTrees !== numIterations)
178+
val numTrees = gbtValidate.numTrees
179+
assert(numTrees !== numIterations)
179180

180181
// Test that it performs better on the validation dataset.
181-
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
182+
val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
182183
val (errorWithoutValidation, errorWithValidation) = {
183184
if (algo == Classification) {
184185
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -188,6 +189,17 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
188189
}
189190
}
190191
assert(errorWithValidation <= errorWithoutValidation)
192+
193+
// Test that results from evaluateEachIteration comply with runWithValidation.
194+
// Note that validationTol is set to 0.0
195+
val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
196+
assert(evaluationArray.length === numIterations)
197+
assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
198+
var i = 1
199+
while (i < numTrees) {
200+
assert(evaluationArray(i) <= evaluationArray(i - 1))
201+
i += 1
202+
}
191203
}
192204
}
193205
}

0 commit comments

Comments
 (0)