Skip to content

Commit 6e8aa10

Browse files
committed
Made the following changes
Used mapPartitions instead of map; refactored computeError; unpersisted broadcast variables.
1 parent bc99ac6 commit 6e8aa10

File tree

6 files changed

+40
-80
lines changed

6 files changed

+40
-80
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,30 +47,15 @@ object AbsoluteError extends Loss {
4747
if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
4848
}
4949

50-
/**
51-
* Method to calculate loss of the base learner for the gradient boosting calculation.
52-
* Note: This method is not used by the gradient boosting algorithm but is useful for debugging
53-
* purposes.
54-
* @param model Ensemble model
55-
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
56-
* @return Mean absolute error of model on data
57-
*/
58-
override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
59-
data.map { y =>
60-
val err = model.predict(y.features) - y.label
61-
math.abs(err)
62-
}.mean()
63-
}
64-
6550
/**
6651
* Method to calculate loss when the predictions are already known.
6752
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
6853
* predicted values from previously fit trees.
69-
* @param datum: LabeledPoint
70-
* @param prediction: Predicted label.
71-
* @return Absolute error of model on the given datapoint.
54+
* @param prediction Predicted label.
55+
* @param datum LabeledPoint.
56+
* @return Absolute error of model on the given datapoint.
7257
*/
73-
override def computeError(datum: LabeledPoint, prediction: Double): Double = {
58+
override def computeError(prediction: Double, datum: LabeledPoint): Double = {
7459
val err = datum.label - prediction
7560
math.abs(err)
7661
}

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,32 +50,15 @@ object LogLoss extends Loss {
5050
- 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
5151
}
5252

53-
/**
54-
* Method to calculate loss of the base learner for the gradient boosting calculation.
55-
* Note: This method is not used by the gradient boosting algorithm but is useful for debugging
56-
* purposes.
57-
* @param model Ensemble model
58-
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
59-
* @return Mean log loss of model on data
60-
*/
61-
override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
62-
data.map { case point =>
63-
val prediction = model.predict(point.features)
64-
val margin = 2.0 * point.label * prediction
65-
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
66-
2.0 * MLUtils.log1pExp(-margin)
67-
}.mean()
68-
}
69-
7053
/**
7154
* Method to calculate loss when the predictions are already known.
7255
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
7356
* predicted values from previously fit trees.
74-
* @param datum: LabeledPoint
75-
* @param prediction: Predicted label.
57+
* @param prediction Predicted label.
58+
* @param datum LabeledPoint
7659
* @return log loss of model on the datapoint.
7760
*/
78-
override def computeError(datum: LabeledPoint, prediction: Double): Double = {
61+
override def computeError(prediction: Double, datum: LabeledPoint): Double = {
7962
val margin = 2.0 * datum.label * prediction
8063
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
8164
2.0 * MLUtils.log1pExp(-margin)

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,18 @@ trait Loss extends Serializable {
4747
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
4848
* @return Measure of model error on data
4949
*/
50-
def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
50+
def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
51+
data.map(point => computeError(model.predict(point.features), point)).mean()
52+
}
5153

5254
/**
5355
* Method to calculate loss when the predictions are already known.
5456
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
5557
* predicted values from previously fit trees.
56-
* @param datum: LabeledPoint
57-
* @param prediction: Predicted label.
58+
* @param prediction Predicted label.
59+
* @param datum LabeledPoint
5860
* @return Measure of model error on datapoint.
5961
*/
60-
def computeError(datum: LabeledPoint, prediction: Double) : Double
62+
def computeError(prediction: Double, datum: LabeledPoint): Double
6163

6264
}

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,30 +47,15 @@ object SquaredError extends Loss {
4747
2.0 * (model.predict(point.features) - point.label)
4848
}
4949

50-
/**
51-
* Method to calculate loss of the base learner for the gradient boosting calculation.
52-
* Note: This method is not used by the gradient boosting algorithm but is useful for debugging
53-
* purposes.
54-
* @param model Ensemble model
55-
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
56-
* @return Mean squared error of model on data
57-
*/
58-
override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
59-
data.map { y =>
60-
val err = model.predict(y.features) - y.label
61-
err * err
62-
}.mean()
63-
}
64-
6550
/**
6651
* Method to calculate loss when the predictions are already known.
6752
* Note: This method is used in the method evaluateEachIteration to avoid recomputing the
6853
* predicted values from previously fit trees.
69-
* @param datum: LabeledPoint
70-
* @param prediction: Predicted label.
71-
* @return Mean squared error of model on datapoint.
54+
* @param prediction Predicted label.
55+
* @param datum LabeledPoint
56+
* @return Squared error of model on the given datapoint.
7257
*/
73-
override def computeError(datum: LabeledPoint, prediction: Double): Double = {
58+
override def computeError(prediction: Double, datum: LabeledPoint): Double = {
7459
val err = prediction - datum.label
7560
err * err
7661
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -113,47 +113,52 @@ class GradientBoostedTreesModel(
113113

114114
/**
115115
* Method to compute error or loss for every iteration of gradient boosting.
116-
* @param data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
117-
* @param loss: evaluation metric.
116+
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
117+
* @param loss evaluation metric.
118118
* @return an array with index i having the losses or errors for the ensemble
119119
* containing trees 1 to i + 1
120120
*/
121121
def evaluateEachIteration(
122122
data: RDD[LabeledPoint],
123-
loss: Loss) : Array[Double] = {
123+
loss: Loss): Array[Double] = {
124124

125125
val sc = data.sparkContext
126126
val remappedData = algo match {
127127
case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
128128
case _ => data
129129
}
130-
val initialTree = trees(0)
130+
131131
val numIterations = trees.length
132132
val evaluationArray = Array.fill(numIterations)(0.0)
133133

134-
// Initial weight is 1.0
135-
var predictionErrorModel = remappedData.map {i =>
136-
val pred = initialTree.predict(i.features)
137-
val error = loss.computeError(i, pred)
134+
var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
135+
val pred = treeWeights(0) * trees(0).predict(i.features)
136+
val error = loss.computeError(pred, i)
138137
(pred, error)
139138
}
140-
evaluationArray(0) = predictionErrorModel.values.mean()
139+
evaluationArray(0) = predictionAndError.values.mean()
141140

142141
// Avoid the model being copied across numIterations.
143142
val broadcastTrees = sc.broadcast(trees)
144143
val broadcastWeights = sc.broadcast(treeWeights)
145144

146-
(1 until numIterations).map {nTree =>
147-
predictionErrorModel = (remappedData zip predictionErrorModel) map {
148-
case (point, (pred, error)) => {
149-
val newPred = pred + (
150-
broadcastTrees.value(nTree).predict(point.features) * broadcastWeights.value(nTree))
151-
val newError = loss.computeError(point, newPred)
152-
(newPred, newError)
145+
(1 until numIterations).map { nTree =>
146+
val currentTree = broadcastTrees.value(nTree)
147+
val currentTreeWeight = broadcastWeights.value(nTree)
148+
predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
149+
iter map {
150+
case (point, (pred, error)) => {
151+
val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
152+
val newError = loss.computeError(newPred, point)
153+
(newPred, newError)
154+
}
153155
}
154156
}
155-
evaluationArray(nTree) = predictionErrorModel.values.mean()
157+
evaluationArray(nTree) = predictionAndError.values.mean()
156158
}
159+
160+
broadcastTrees.unpersist()
161+
broadcastWeights.unpersist()
157162
evaluationArray
158163
}
159164

mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
197197
assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
198198
var i = 1
199199
while (i < numTrees) {
200-
assert(evaluationArray(i) < evaluationArray(i - 1))
200+
assert(evaluationArray(i) <= evaluationArray(i - 1))
201201
i += 1
202202
}
203203
}

0 commit comments

Comments
 (0)