Skip to content

Commit a95bc22

Browse files
committed
timing for DecisionTree internals
1 parent 1c5555a commit a95bc22

File tree

1 file changed

+76
-4
lines changed

1 file changed

+76
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.mllib.tree
1919

20+
import java.util.Calendar
21+
2022
import org.apache.spark.annotation.Experimental
2123
import org.apache.spark.Logging
2224
import org.apache.spark.mllib.regression.LabeledPoint
@@ -29,6 +31,40 @@ import org.apache.spark.mllib.tree.model._
2931
import org.apache.spark.rdd.RDD
3032
import org.apache.spark.util.random.XORShiftRandom
3133

34+
/**
 * Simple accumulator of wall-clock timings for the DecisionTree training phases.
 *
 * Usage pattern: call `reset()` immediately before a timed section, then add
 * `elapsed()` into the matching counter immediately after it. All counters are
 * in milliseconds.
 */
class TimeTracker {

  // Timestamp (millis since the epoch) of the most recent reset(); the start
  // point for elapsed(). System.currentTimeMillis() yields the same value as
  // Calendar.getInstance().getTimeInMillis but without allocating a Calendar
  // object on every call.
  var tmpTime: Long = System.currentTimeMillis()

  /** Restart the stopwatch: record the current time as the new start point. */
  def reset(): Unit = {
    tmpTime = System.currentTimeMillis()
  }

  /** Milliseconds elapsed since the last reset() (or since construction). */
  def elapsed(): Long = {
    System.currentTimeMillis() - tmpTime
  }

  // Accumulated per-phase totals (milliseconds), incremented by the caller.
  var initTime: Long = 0 // Data retag and cache
  var findSplitsBinsTime: Long = 0
  var extractNodeInfoTime: Long = 0
  var extractInfoForLowerLevelsTime: Long = 0
  var findBestSplitsTime: Long = 0
  var findBinsForLevelTime: Long = 0
  var binAggregatesTime: Long = 0
  var chooseSplitsTime: Long = 0

  /** Human-readable summary of all accumulated phase timings. */
  override def toString: String = {
    s"DecisionTree timing\n" +
    s"initTime: $initTime\n" +
    s"findSplitsBinsTime: $findSplitsBinsTime\n" +
    s"extractNodeInfoTime: $extractNodeInfoTime\n" +
    s"extractInfoForLowerLevelsTime: $extractInfoForLowerLevelsTime\n" +
    s"findBestSplitsTime: $findBestSplitsTime\n" +
    s"findBinsForLevelTime: $findBinsForLevelTime\n" +
    s"binAggregatesTime: $binAggregatesTime\n" +
    s"chooseSplitsTime: $chooseSplitsTime\n"
  }
}
67+
3268
/**
3369
* :: Experimental ::
3470
* A class which implements a decision tree learning algorithm for classification and regression.
@@ -47,16 +83,24 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
4783
*/
4884
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
4985

86+
val timer = new TimeTracker()
87+
timer.reset()
88+
5089
// Cache input RDD for speedup during multiple passes.
5190
val retaggedInput = input.retag(classOf[LabeledPoint]).cache()
5291
logDebug("algo = " + strategy.algo)
5392

93+
timer.initTime += timer.elapsed()
94+
timer.reset()
95+
5496
// Find the splits and the corresponding bins (interval between the splits) using a sample
5597
// of the input data.
5698
val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
5799
val numBins = bins(0).length
58100
logDebug("numBins = " + numBins)
59101

102+
timer.findSplitsBinsTime += timer.elapsed()
103+
60104
// depth of the decision tree
61105
val maxDepth = strategy.maxDepth
62106
// the max number of nodes possible given the depth of the tree
@@ -98,6 +142,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
98142
* still survived the filters of the parent nodes.
99143
*/
100144

145+
var findBestSplitsTime: Long = 0
146+
var extractNodeInfoTime: Long = 0
147+
var extractInfoForLowerLevelsTime: Long = 0
148+
101149
var level = 0
102150
var break = false
103151
while (level <= maxDepth && !break) {
@@ -106,16 +154,23 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
106154
logDebug("level = " + level)
107155
logDebug("#####################################")
108156

157+
109158
// Find best split for all nodes at a level.
159+
timer.reset()
110160
val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities,
111-
strategy, level, filters, splits, bins, maxLevelForSingleGroup)
161+
strategy, level, filters, splits, bins, timer, maxLevelForSingleGroup)
162+
timer.findBestSplitsTime += timer.elapsed()
112163

113164
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
165+
timer.reset()
114166
// Extract info for nodes at the current level.
115167
extractNodeInfo(nodeSplitStats, level, index, nodes)
168+
timer.extractNodeInfoTime += timer.elapsed()
169+
timer.reset()
116170
// Extract info for nodes at the next lower level.
117171
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
118172
filters)
173+
timer.extractInfoForLowerLevelsTime += timer.elapsed()
119174
logDebug("final best split = " + nodeSplitStats._1)
120175
}
121176
require(math.pow(2, level) == splitsStatsForLevel.length)
@@ -129,6 +184,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
129184
}
130185
}
131186

187+
println(timer)
188+
132189
logDebug("#####################################")
133190
logDebug("Extracting tree model")
134191
logDebug("#####################################")
@@ -194,6 +251,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
194251
}
195252
}
196253

254+
197255
object DecisionTree extends Serializable with Logging {
198256

199257
/**
@@ -325,6 +383,7 @@ object DecisionTree extends Serializable with Logging {
325383
filters: Array[List[Filter]],
326384
splits: Array[Array[Split]],
327385
bins: Array[Array[Bin]],
386+
timer: TimeTracker,
328387
maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = {
329388
// split into groups to avoid memory overflow during aggregation
330389
if (level > maxLevelForSingleGroup) {
@@ -339,13 +398,13 @@ object DecisionTree extends Serializable with Logging {
339398
var groupIndex = 0
340399
while (groupIndex < numGroups) {
341400
val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
342-
filters, splits, bins, numGroups, groupIndex)
401+
filters, splits, bins, timer, numGroups, groupIndex)
343402
bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
344403
groupIndex += 1
345404
}
346405
bestSplits
347406
} else {
348-
findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins)
407+
findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer)
349408
}
350409
}
351410

@@ -372,6 +431,7 @@ object DecisionTree extends Serializable with Logging {
372431
filters: Array[List[Filter]],
373432
splits: Array[Array[Split]],
374433
bins: Array[Array[Bin]],
434+
timer: TimeTracker,
375435
numGroups: Int = 1,
376436
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
377437

@@ -628,9 +688,13 @@ object DecisionTree extends Serializable with Logging {
628688
arr
629689
}
630690

631-
// Find feature bins for all nodes at a level.
691+
timer.reset()
692+
693+
// Find feature bins for all nodes at a level.
632694
val binMappedRDD = input.map(x => findBinsForLevel(x))
633695

696+
timer.findBinsForLevelTime += timer.elapsed()
697+
634698
/**
635699
* Increment aggregate in location for (node, feature, bin, label).
636700
*
@@ -873,12 +937,16 @@ object DecisionTree extends Serializable with Logging {
873937
combinedAggregate
874938
}
875939

940+
timer.reset()
941+
876942
// Calculate bin aggregates.
877943
val binAggregates = {
878944
binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
879945
}
880946
logDebug("binAggregates.length = " + binAggregates.length)
881947

948+
timer.binAggregatesTime += timer.elapsed()
949+
882950
/**
883951
* Calculates the information gain for all splits based upon left/right split aggregates.
884952
* @param leftNodeAgg left node aggregates
@@ -1282,6 +1350,8 @@ object DecisionTree extends Serializable with Logging {
12821350
}
12831351
}
12841352

1353+
timer.reset()
1354+
12851355
// Calculate best splits for all nodes at a given level
12861356
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
12871357
// Iterating over all nodes at this level
@@ -1295,6 +1365,8 @@ object DecisionTree extends Serializable with Logging {
12951365
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
12961366
node += 1
12971367
}
1368+
timer.chooseSplitsTime += timer.elapsed()
1369+
12981370
bestSplits
12991371
}
13001372

0 commit comments

Comments (0)