17
17
18
18
package org .apache .spark .mllib .tree
19
19
20
-
21
20
import scala .collection .JavaConverters ._
22
21
23
22
import org .apache .spark .annotation .Experimental
@@ -32,6 +31,7 @@ import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
32
31
import org .apache .spark .mllib .tree .impurity .{Impurities , Impurity }
33
32
import org .apache .spark .mllib .tree .model ._
34
33
import org .apache .spark .rdd .RDD
34
+ import org .apache .spark .storage .StorageLevel
35
35
import org .apache .spark .util .random .XORShiftRandom
36
36
37
37
@@ -59,11 +59,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
59
59
60
60
timer.start(" total" )
61
61
62
- // Cache input RDD for speedup during multiple passes.
63
62
timer.start(" init" )
63
+
64
64
val retaggedInput = input.retag(classOf [LabeledPoint ])
65
65
logDebug(" algo = " + strategy.algo)
66
- timer.stop(" init" )
67
66
68
67
// Find the splits and the corresponding bins (interval between the splits) using a sample
69
68
// of the input data.
@@ -73,9 +72,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
73
72
timer.stop(" findSplitsBins" )
74
73
logDebug(" numBins = " + numBins)
75
74
76
- timer.start( " init " )
77
- val treeInput = TreePoint .convertToTreeRDD(retaggedInput, strategy, bins).cache()
78
- timer.stop( " init " )
75
+ // Cache input RDD for speedup during multiple passes.
76
+ val treeInput = TreePoint .convertToTreeRDD(retaggedInput, strategy, bins)
77
+ .persist( StorageLevel . MEMORY_AND_DISK )
79
78
80
79
// depth of the decision tree
81
80
val maxDepth = strategy.maxDepth
@@ -90,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
90
89
// dummy value for top node (updated during first split calculation)
91
90
val nodes = new Array [Node ](maxNumNodes)
92
91
// num features
93
- val numFeatures = treeInput.take(1 )(0 ).features .size
92
+ val numFeatures = treeInput.take(1 )(0 ).binnedFeatures .size
94
93
95
94
// Calculate level for single group construction
96
95
@@ -110,6 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
110
109
(math.log(maxNumberOfNodesPerGroup) / math.log(2 )).floor.toInt, 0 )
111
110
logDebug(" max level for single group = " + maxLevelForSingleGroup)
112
111
112
+ timer.stop(" init" )
113
+
113
114
/*
114
115
* The main idea here is to perform level-wise training of the decision tree nodes thus
115
116
* reducing the passes over the data from l to log2(l) where l is the total number of nodes.
@@ -126,7 +127,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
126
127
logDebug(" level = " + level)
127
128
logDebug(" #####################################" )
128
129
129
-
130
130
// Find best split for all nodes at a level.
131
131
timer.start(" findBestSplits" )
132
132
val splitsStatsForLevel = DecisionTree .findBestSplits(treeInput, parentImpurities,
@@ -167,8 +167,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
167
167
168
168
timer.stop(" total" )
169
169
170
- logDebug (" Internal timing for DecisionTree:" )
171
- logDebug (s " $timer" )
170
+ logInfo (" Internal timing for DecisionTree:" )
171
+ logInfo (s " $timer" )
172
172
173
173
new DecisionTreeModel (topNode, strategy.algo)
174
174
}
@@ -226,7 +226,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
226
226
}
227
227
}
228
228
229
-
230
229
object DecisionTree extends Serializable with Logging {
231
230
232
231
/**
@@ -536,7 +535,7 @@ object DecisionTree extends Serializable with Logging {
536
535
logDebug(" numNodes = " + numNodes)
537
536
538
537
// Find the number of features by looking at the first sample.
539
- val numFeatures = input.first().features .size
538
+ val numFeatures = input.first().binnedFeatures .size
540
539
logDebug(" numFeatures = " + numFeatures)
541
540
542
541
// numBins: Number of bins = 1 + number of possible splits
@@ -578,12 +577,12 @@ object DecisionTree extends Serializable with Logging {
578
577
}
579
578
580
579
// Apply each filter and check sample validity. Return false when invalid condition found.
581
- for (filter <- parentFilters) {
580
+ parentFilters.foreach { filter =>
582
581
val featureIndex = filter.split.feature
583
582
val comparison = filter.comparison
584
583
val isFeatureContinuous = filter.split.featureType == Continuous
585
584
if (isFeatureContinuous) {
586
- val binId = treePoint.features (featureIndex)
585
+ val binId = treePoint.binnedFeatures (featureIndex)
587
586
val bin = bins(featureIndex)(binId)
588
587
val featureValue = bin.highSplit.threshold
589
588
val threshold = filter.split.threshold
@@ -598,9 +597,9 @@ object DecisionTree extends Serializable with Logging {
598
597
val isUnorderedFeature =
599
598
isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
600
599
val featureValue = if (isUnorderedFeature) {
601
- treePoint.features (featureIndex)
600
+ treePoint.binnedFeatures (featureIndex)
602
601
} else {
603
- val binId = treePoint.features (featureIndex)
602
+ val binId = treePoint.binnedFeatures (featureIndex)
604
603
bins(featureIndex)(binId).category
605
604
}
606
605
val containsFeature = filter.split.categories.contains(featureValue)
@@ -648,9 +647,8 @@ object DecisionTree extends Serializable with Logging {
648
647
arr(shift) = InvalidBinIndex
649
648
} else {
650
649
var featureIndex = 0
651
- // TODO: Vectorize this
652
650
while (featureIndex < numFeatures) {
653
- arr(shift + featureIndex) = treePoint.features (featureIndex)
651
+ arr(shift + featureIndex) = treePoint.binnedFeatures (featureIndex)
654
652
featureIndex += 1
655
653
}
656
654
}
@@ -660,9 +658,8 @@ object DecisionTree extends Serializable with Logging {
660
658
}
661
659
662
660
// Find feature bins for all nodes at a level.
663
- timer.start(" findBinsForLevel " )
661
+ timer.start(" aggregation " )
664
662
val binMappedRDD = input.map(x => findBinsForLevel(x))
665
- timer.stop(" findBinsForLevel" )
666
663
667
664
/**
668
665
* Increment aggregate in location for (node, feature, bin, label).
@@ -907,13 +904,11 @@ object DecisionTree extends Serializable with Logging {
907
904
combinedAggregate
908
905
}
909
906
910
-
911
907
// Calculate bin aggregates.
912
- timer.start(" binAggregates" )
913
908
val binAggregates = {
914
909
binMappedRDD.aggregate(Array .fill[Double ](binAggregateLength)(0 ))(binSeqOp,binCombOp)
915
910
}
916
- timer.stop(" binAggregates " )
911
+ timer.stop(" aggregation " )
917
912
logDebug(" binAggregates.length = " + binAggregates.length)
918
913
919
914
/**
@@ -1225,12 +1220,16 @@ object DecisionTree extends Serializable with Logging {
1225
1220
nodeImpurity : Double ): Array [Array [InformationGainStats ]] = {
1226
1221
val gains = Array .ofDim[InformationGainStats ](numFeatures, numBins - 1 )
1227
1222
1228
- for (featureIndex <- 0 until numFeatures) {
1223
+ var featureIndex = 0
1224
+ while (featureIndex < numFeatures) {
1229
1225
val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
1230
- for (splitIndex <- 0 until numSplitsForFeature) {
1226
+ var splitIndex = 0
1227
+ while (splitIndex < numSplitsForFeature) {
1231
1228
gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
1232
1229
splitIndex, rightNodeAgg, nodeImpurity)
1230
+ splitIndex += 1
1233
1231
}
1232
+ featureIndex += 1
1234
1233
}
1235
1234
gains
1236
1235
}
0 commit comments