Skip to content

Commit 6b5651e

Browse files
committed
Updates based on code review. 1 major change: persisting to memory + disk, not just memory.
Details: DecisionTree * Changed: .cache() -> .persist(StorageLevel.MEMORY_AND_DISK) ** This gave major performance improvements on small tests. E.g., 500K examples, 500 features, depth 5, on MacBook, took 292 sec with cache() and 112 when using disk as well. * Change for to while loops * Small cleanups TimeTracker * Removed useless timing in DecisionTree TreePoint * Renamed features to binnedFeatures
1 parent 2d2aaaf commit 6b5651e

File tree

4 files changed

+46
-53
lines changed

4 files changed

+46
-53
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 24 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.mllib.tree
1919

20-
2120
import scala.collection.JavaConverters._
2221

2322
import org.apache.spark.annotation.Experimental
@@ -32,6 +31,7 @@ import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
3231
import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
3332
import org.apache.spark.mllib.tree.model._
3433
import org.apache.spark.rdd.RDD
34+
import org.apache.spark.storage.StorageLevel
3535
import org.apache.spark.util.random.XORShiftRandom
3636

3737

@@ -59,11 +59,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
5959

6060
timer.start("total")
6161

62-
// Cache input RDD for speedup during multiple passes.
6362
timer.start("init")
63+
6464
val retaggedInput = input.retag(classOf[LabeledPoint])
6565
logDebug("algo = " + strategy.algo)
66-
timer.stop("init")
6766

6867
// Find the splits and the corresponding bins (interval between the splits) using a sample
6968
// of the input data.
@@ -73,9 +72,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7372
timer.stop("findSplitsBins")
7473
logDebug("numBins = " + numBins)
7574

76-
timer.start("init")
77-
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache()
78-
timer.stop("init")
75+
// Cache input RDD for speedup during multiple passes.
76+
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
77+
.persist(StorageLevel.MEMORY_AND_DISK)
7978

8079
// depth of the decision tree
8180
val maxDepth = strategy.maxDepth
@@ -90,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
9089
// dummy value for top node (updated during first split calculation)
9190
val nodes = new Array[Node](maxNumNodes)
9291
// num features
93-
val numFeatures = treeInput.take(1)(0).features.size
92+
val numFeatures = treeInput.take(1)(0).binnedFeatures.size
9493

9594
// Calculate level for single group construction
9695

@@ -110,6 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
110109
(math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
111110
logDebug("max level for single group = " + maxLevelForSingleGroup)
112111

112+
timer.stop("init")
113+
113114
/*
114115
* The main idea here is to perform level-wise training of the decision tree nodes thus
115116
* reducing the passes over the data from l to log2(l) where l is the total number of nodes.
@@ -126,7 +127,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
126127
logDebug("level = " + level)
127128
logDebug("#####################################")
128129

129-
130130
// Find best split for all nodes at a level.
131131
timer.start("findBestSplits")
132132
val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
@@ -167,8 +167,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
167167

168168
timer.stop("total")
169169

170-
logDebug("Internal timing for DecisionTree:")
171-
logDebug(s"$timer")
170+
logInfo("Internal timing for DecisionTree:")
171+
logInfo(s"$timer")
172172

173173
new DecisionTreeModel(topNode, strategy.algo)
174174
}
@@ -226,7 +226,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
226226
}
227227
}
228228

229-
230229
object DecisionTree extends Serializable with Logging {
231230

232231
/**
@@ -536,7 +535,7 @@ object DecisionTree extends Serializable with Logging {
536535
logDebug("numNodes = " + numNodes)
537536

538537
// Find the number of features by looking at the first sample.
539-
val numFeatures = input.first().features.size
538+
val numFeatures = input.first().binnedFeatures.size
540539
logDebug("numFeatures = " + numFeatures)
541540

542541
// numBins: Number of bins = 1 + number of possible splits
@@ -578,12 +577,12 @@ object DecisionTree extends Serializable with Logging {
578577
}
579578

580579
// Apply each filter and check sample validity. Return false when invalid condition found.
581-
for (filter <- parentFilters) {
580+
parentFilters.foreach { filter =>
582581
val featureIndex = filter.split.feature
583582
val comparison = filter.comparison
584583
val isFeatureContinuous = filter.split.featureType == Continuous
585584
if (isFeatureContinuous) {
586-
val binId = treePoint.features(featureIndex)
585+
val binId = treePoint.binnedFeatures(featureIndex)
587586
val bin = bins(featureIndex)(binId)
588587
val featureValue = bin.highSplit.threshold
589588
val threshold = filter.split.threshold
@@ -598,9 +597,9 @@ object DecisionTree extends Serializable with Logging {
598597
val isUnorderedFeature =
599598
isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
600599
val featureValue = if (isUnorderedFeature) {
601-
treePoint.features(featureIndex)
600+
treePoint.binnedFeatures(featureIndex)
602601
} else {
603-
val binId = treePoint.features(featureIndex)
602+
val binId = treePoint.binnedFeatures(featureIndex)
604603
bins(featureIndex)(binId).category
605604
}
606605
val containsFeature = filter.split.categories.contains(featureValue)
@@ -648,9 +647,8 @@ object DecisionTree extends Serializable with Logging {
648647
arr(shift) = InvalidBinIndex
649648
} else {
650649
var featureIndex = 0
651-
// TODO: Vectorize this
652650
while (featureIndex < numFeatures) {
653-
arr(shift + featureIndex) = treePoint.features(featureIndex)
651+
arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex)
654652
featureIndex += 1
655653
}
656654
}
@@ -660,9 +658,8 @@ object DecisionTree extends Serializable with Logging {
660658
}
661659

662660
// Find feature bins for all nodes at a level.
663-
timer.start("findBinsForLevel")
661+
timer.start("aggregation")
664662
val binMappedRDD = input.map(x => findBinsForLevel(x))
665-
timer.stop("findBinsForLevel")
666663

667664
/**
668665
* Increment aggregate in location for (node, feature, bin, label).
@@ -907,13 +904,11 @@ object DecisionTree extends Serializable with Logging {
907904
combinedAggregate
908905
}
909906

910-
911907
// Calculate bin aggregates.
912-
timer.start("binAggregates")
913908
val binAggregates = {
914909
binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
915910
}
916-
timer.stop("binAggregates")
911+
timer.stop("aggregation")
917912
logDebug("binAggregates.length = " + binAggregates.length)
918913

919914
/**
@@ -1225,12 +1220,16 @@ object DecisionTree extends Serializable with Logging {
12251220
nodeImpurity: Double): Array[Array[InformationGainStats]] = {
12261221
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
12271222

1228-
for (featureIndex <- 0 until numFeatures) {
1223+
var featureIndex = 0
1224+
while (featureIndex < numFeatures) {
12291225
val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
1230-
for (splitIndex <- 0 until numSplitsForFeature) {
1226+
var splitIndex = 0
1227+
while (splitIndex < numSplitsForFeature) {
12311228
gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
12321229
splitIndex, rightNodeAgg, nodeImpurity)
1230+
splitIndex += 1
12331231
}
1232+
featureIndex += 1
12341233
}
12351234
gains
12361235
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,7 @@ import org.apache.spark.annotation.Experimental
2525
* Time tracker implementation which holds labeled timers.
2626
*/
2727
@Experimental
28-
private[tree]
29-
class TimeTracker extends Serializable {
28+
private[tree] class TimeTracker extends Serializable {
3029

3130
private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
3231

@@ -36,24 +35,24 @@ class TimeTracker extends Serializable {
3635
* Starts a new timer, or re-starts a stopped timer.
3736
*/
3837
def start(timerLabel: String): Unit = {
39-
val tmpTime = System.nanoTime()
38+
val currentTime = System.nanoTime()
4039
if (starts.contains(timerLabel)) {
4140
throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" +
4241
s" timerLabel = $timerLabel before that timer was stopped.")
4342
}
44-
starts(timerLabel) = tmpTime
43+
starts(timerLabel) = currentTime
4544
}
4645

4746
/**
4847
* Stops a timer and returns the elapsed time in seconds.
4948
*/
5049
def stop(timerLabel: String): Double = {
51-
val tmpTime = System.nanoTime()
50+
val currentTime = System.nanoTime()
5251
if (!starts.contains(timerLabel)) {
5352
throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" +
5453
s" timerLabel = $timerLabel, but that timer was not started.")
5554
}
56-
val elapsed = tmpTime - starts(timerLabel)
55+
val elapsed = currentTime - starts(timerLabel)
5756
starts.remove(timerLabel)
5857
if (totals.contains(timerLabel)) {
5958
totals(timerLabel) += elapsed

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -35,13 +35,12 @@ import org.apache.spark.rdd.RDD
3535
* or any categorical feature used in regression or binary classification.
3636
*
3737
* @param label Label from LabeledPoint
38-
* @param features Binned feature values.
39-
* Same length as LabeledPoint.features, but values are bin indices.
38+
* @param binnedFeatures Binned feature values.
39+
* Same length as LabeledPoint.features, but values are bin indices.
4040
*/
41-
private[tree] class TreePoint(val label: Double, val features: Array[Int]) extends Serializable {
41+
private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) extends Serializable {
4242
}
4343

44-
4544
private[tree] object TreePoint {
4645

4746
/**
@@ -76,7 +75,7 @@ private[tree] object TreePoint {
7675
val numFeatures = labeledPoint.features.size
7776
val numBins = bins(0).size
7877
val arr = new Array[Int](numFeatures)
79-
var featureIndex = 0 // offset by 1 for label
78+
var featureIndex = 0
8079
while (featureIndex < numFeatures) {
8180
val featureInfo = categoricalFeaturesInfo.get(featureIndex)
8281
val isFeatureContinuous = featureInfo.isEmpty
@@ -98,7 +97,6 @@ private[tree] object TreePoint {
9897
new TreePoint(labeledPoint.label, arr)
9998
}
10099

101-
102100
/**
103101
* Find bin for one (labeledPoint, feature).
104102
*
@@ -129,11 +127,9 @@ private[tree] object TreePoint {
129127
val highThreshold = bin.highSplit.threshold
130128
if ((lowThreshold < feature) && (highThreshold >= feature)) {
131129
return mid
132-
}
133-
else if (lowThreshold >= feature) {
130+
} else if (lowThreshold >= feature) {
134131
right = mid - 1
135-
}
136-
else {
132+
} else {
137133
left = mid + 1
138134
}
139135
}
@@ -181,7 +177,8 @@ private[tree] object TreePoint {
181177
// Perform binary search for finding bin for continuous features.
182178
val binIndex = binarySearchForBins()
183179
if (binIndex == -1) {
184-
throw new UnknownError("No bin was found for continuous feature." +
180+
throw new RuntimeException("No bin was found for continuous feature." +
181+
" This error can occur when given invalid data values (such as NaN)." +
185182
s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
186183
}
187184
binIndex
@@ -193,7 +190,8 @@ private[tree] object TreePoint {
193190
sequentialBinSearchForOrderedCategoricalFeature()
194191
}
195192
if (binIndex == -1) {
196-
throw new UnknownError("No bin was found for categorical feature." +
193+
throw new RuntimeException("No bin was found for categorical feature." +
194+
" This error can occur when given invalid data values (such as NaN)." +
197195
s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
198196
}
199197
binIndex

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -17,17 +17,16 @@
1717

1818
package org.apache.spark.mllib.tree
1919

20-
import org.apache.spark.mllib.tree.impl.TreePoint
21-
2220
import scala.collection.JavaConverters._
2321

2422
import org.scalatest.FunSuite
2523

26-
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
27-
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
28-
import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
2924
import org.apache.spark.mllib.tree.configuration.Algo._
3025
import org.apache.spark.mllib.tree.configuration.FeatureType._
26+
import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
27+
import org.apache.spark.mllib.tree.impl.TreePoint
28+
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
29+
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
3130
import org.apache.spark.mllib.linalg.Vectors
3231
import org.apache.spark.mllib.util.LocalSparkContext
3332
import org.apache.spark.mllib.regression.LabeledPoint
@@ -43,10 +42,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
4342
prediction != expected.label
4443
}
4544
val accuracy = (input.length - numOffPredictions).toDouble / input.length
46-
if (accuracy < requiredAccuracy) {
47-
println(s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
48-
}
49-
assert(accuracy >= requiredAccuracy)
45+
assert(accuracy >= requiredAccuracy,
46+
s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
5047
}
5148

5249
def validateRegressor(
@@ -59,7 +56,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
5956
err * err
6057
}.sum
6158
val mse = squaredError / input.length
62-
assert(mse <= requiredMSE)
59+
assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
6360
}
6461

6562
test("split and bin calculation") {

0 commit comments

Comments (0)