17
17
18
18
package org .apache .spark .mllib .tree
19
19
20
+ import java .util .Calendar
21
+
20
22
import org .apache .spark .annotation .Experimental
21
23
import org .apache .spark .Logging
22
24
import org .apache .spark .mllib .regression .LabeledPoint
@@ -29,6 +31,40 @@ import org.apache.spark.mllib.tree.model._
29
31
import org .apache .spark .rdd .RDD
30
32
import org .apache .spark .util .random .XORShiftRandom
31
33
34
/**
 * Internal stopwatch + accumulator used to instrument the stages of decision tree training.
 *
 * Usage pattern (as seen at the call sites in this file):
 *   timer.reset()          // start timing a stage
 *   ...stage work...
 *   timer.initTime += timer.elapsed()   // fold the stage duration into its counter
 *
 * All times are wall-clock milliseconds.
 *
 * NOTE(review): timings use the wall clock (`System.currentTimeMillis()`), not a monotonic
 * clock, so an NTP adjustment during training can skew a measurement — acceptable for the
 * debug/profiling purpose this serves.
 */
class TimeTracker {

  // Epoch millis at the last reset(); elapsed() measures from this point.
  // System.currentTimeMillis() returns the same value as
  // Calendar.getInstance().getTimeInMillis but without allocating a Calendar per call.
  var tmpTime: Long = System.currentTimeMillis()

  /** Restarts the stopwatch. */
  def reset(): Unit = {
    tmpTime = System.currentTimeMillis()
  }

  /** Milliseconds elapsed since the last reset() (or construction). */
  def elapsed(): Long = {
    System.currentTimeMillis() - tmpTime
  }

  // Accumulated per-stage durations (millis). Each is incremented by the caller
  // with `+= elapsed()` after the corresponding stage completes.
  var initTime: Long = 0 // Data retag and cache
  var findSplitsBinsTime: Long = 0
  var extractNodeInfoTime: Long = 0
  var extractInfoForLowerLevelsTime: Long = 0
  var findBestSplitsTime: Long = 0
  var findBinsForLevelTime: Long = 0
  var binAggregatesTime: Long = 0
  var chooseSplitsTime: Long = 0

  /** Human-readable summary of all accumulated stage timings. */
  override def toString: String = {
    s"DecisionTree timing\n" +
      s"initTime: $initTime\n" +
      s"findSplitsBinsTime: $findSplitsBinsTime\n" +
      s"extractNodeInfoTime: $extractNodeInfoTime\n" +
      s"extractInfoForLowerLevelsTime: $extractInfoForLowerLevelsTime\n" +
      s"findBestSplitsTime: $findBestSplitsTime\n" +
      s"findBinsForLevelTime: $findBinsForLevelTime\n" +
      s"binAggregatesTime: $binAggregatesTime\n" +
      s"chooseSplitsTime: $chooseSplitsTime\n"
  }
}
67
+
32
68
/**
33
69
* :: Experimental ::
34
70
* A class which implements a decision tree learning algorithm for classification and regression.
@@ -47,16 +83,24 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
47
83
*/
48
84
def train (input : RDD [LabeledPoint ]): DecisionTreeModel = {
49
85
86
+ val timer = new TimeTracker ()
87
+ timer.reset()
88
+
50
89
// Cache input RDD for speedup during multiple passes.
51
90
val retaggedInput = input.retag(classOf [LabeledPoint ]).cache()
52
91
logDebug(" algo = " + strategy.algo)
53
92
93
+ timer.initTime += timer.elapsed()
94
+ timer.reset()
95
+
54
96
// Find the splits and the corresponding bins (interval between the splits) using a sample
55
97
// of the input data.
56
98
val (splits, bins) = DecisionTree .findSplitsBins(retaggedInput, strategy)
57
99
val numBins = bins(0 ).length
58
100
logDebug(" numBins = " + numBins)
59
101
102
+ timer.findSplitsBinsTime += timer.elapsed()
103
+
60
104
// depth of the decision tree
61
105
val maxDepth = strategy.maxDepth
62
106
// the max number of nodes possible given the depth of the tree
@@ -98,6 +142,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
98
142
* still survived the filters of the parent nodes.
99
143
*/
100
144
145
+ var findBestSplitsTime : Long = 0
146
+ var extractNodeInfoTime : Long = 0
147
+ var extractInfoForLowerLevelsTime : Long = 0
148
+
101
149
var level = 0
102
150
var break = false
103
151
while (level <= maxDepth && ! break) {
@@ -106,16 +154,23 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
106
154
logDebug(" level = " + level)
107
155
logDebug(" #####################################" )
108
156
157
+
109
158
// Find best split for all nodes at a level.
159
+ timer.reset()
110
160
val splitsStatsForLevel = DecisionTree .findBestSplits(retaggedInput, parentImpurities,
111
- strategy, level, filters, splits, bins, maxLevelForSingleGroup)
161
+ strategy, level, filters, splits, bins, timer, maxLevelForSingleGroup)
162
+ timer.findBestSplitsTime += timer.elapsed()
112
163
113
164
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
165
+ timer.reset()
114
166
// Extract info for nodes at the current level.
115
167
extractNodeInfo(nodeSplitStats, level, index, nodes)
168
+ timer.extractNodeInfoTime += timer.elapsed()
169
+ timer.reset()
116
170
// Extract info for nodes at the next lower level.
117
171
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
118
172
filters)
173
+ timer.extractInfoForLowerLevelsTime += timer.elapsed()
119
174
logDebug(" final best split = " + nodeSplitStats._1)
120
175
}
121
176
require(math.pow(2 , level) == splitsStatsForLevel.length)
@@ -129,6 +184,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
129
184
}
130
185
}
131
186
187
+ println(timer)
188
+
132
189
logDebug(" #####################################" )
133
190
logDebug(" Extracting tree model" )
134
191
logDebug(" #####################################" )
@@ -194,6 +251,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
194
251
}
195
252
}
196
253
254
+
197
255
object DecisionTree extends Serializable with Logging {
198
256
199
257
/**
@@ -325,6 +383,7 @@ object DecisionTree extends Serializable with Logging {
325
383
filters : Array [List [Filter ]],
326
384
splits : Array [Array [Split ]],
327
385
bins : Array [Array [Bin ]],
386
+ timer : TimeTracker ,
328
387
maxLevelForSingleGroup : Int ): Array [(Split , InformationGainStats )] = {
329
388
// split into groups to avoid memory overflow during aggregation
330
389
if (level > maxLevelForSingleGroup) {
@@ -339,13 +398,13 @@ object DecisionTree extends Serializable with Logging {
339
398
var groupIndex = 0
340
399
while (groupIndex < numGroups) {
341
400
val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
342
- filters, splits, bins, numGroups, groupIndex)
401
+ filters, splits, bins, timer, numGroups, groupIndex)
343
402
bestSplits = Array .concat(bestSplits, bestSplitsForGroup)
344
403
groupIndex += 1
345
404
}
346
405
bestSplits
347
406
} else {
348
- findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins)
407
+ findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer )
349
408
}
350
409
}
351
410
@@ -372,6 +431,7 @@ object DecisionTree extends Serializable with Logging {
372
431
filters : Array [List [Filter ]],
373
432
splits : Array [Array [Split ]],
374
433
bins : Array [Array [Bin ]],
434
+ timer : TimeTracker ,
375
435
numGroups : Int = 1 ,
376
436
groupIndex : Int = 0 ): Array [(Split , InformationGainStats )] = {
377
437
@@ -628,9 +688,13 @@ object DecisionTree extends Serializable with Logging {
628
688
arr
629
689
}
630
690
631
- // Find feature bins for all nodes at a level.
691
+ timer.reset()
692
+
693
+ // Find feature bins for all nodes at a level.
632
694
val binMappedRDD = input.map(x => findBinsForLevel(x))
633
695
696
+ timer.findBinsForLevelTime += timer.elapsed()
697
+
634
698
/**
635
699
* Increment aggregate in location for (node, feature, bin, label).
636
700
*
@@ -873,12 +937,16 @@ object DecisionTree extends Serializable with Logging {
873
937
combinedAggregate
874
938
}
875
939
940
+ timer.reset()
941
+
876
942
// Calculate bin aggregates.
877
943
val binAggregates = {
878
944
binMappedRDD.aggregate(Array .fill[Double ](binAggregateLength)(0 ))(binSeqOp,binCombOp)
879
945
}
880
946
logDebug(" binAggregates.length = " + binAggregates.length)
881
947
948
+ timer.binAggregatesTime += timer.elapsed()
949
+
882
950
/**
883
951
* Calculates the information gain for all splits based upon left/right split aggregates.
884
952
* @param leftNodeAgg left node aggregates
@@ -1282,6 +1350,8 @@ object DecisionTree extends Serializable with Logging {
1282
1350
}
1283
1351
}
1284
1352
1353
+ timer.reset()
1354
+
1285
1355
// Calculate best splits for all nodes at a given level
1286
1356
val bestSplits = new Array [(Split , InformationGainStats )](numNodes)
1287
1357
// Iterating over all nodes at this level
@@ -1295,6 +1365,8 @@ object DecisionTree extends Serializable with Logging {
1295
1365
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
1296
1366
node += 1
1297
1367
}
1368
+ timer.chooseSplitsTime += timer.elapsed()
1369
+
1298
1370
bestSplits
1299
1371
}
1300
1372
0 commit comments