18
18
package org .apache .spark .mllib .tree
19
19
20
20
import scala .collection .JavaConverters ._
21
- import scala .collection .mutable
22
21
23
22
import org .apache .spark .annotation .Experimental
24
23
import org .apache .spark .api .java .JavaRDD
25
24
import org .apache .spark .Logging
25
+ import org .apache .spark .mllib .rdd .RDDFunctions ._
26
26
import org .apache .spark .mllib .regression .LabeledPoint
27
27
import org .apache .spark .mllib .tree .configuration .Strategy
28
28
import org .apache .spark .mllib .tree .configuration .Algo ._
@@ -100,8 +100,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
100
100
// depth of the decision tree
101
101
val maxDepth = strategy.maxDepth
102
102
// the max number of nodes possible given the depth of the tree
103
- val maxNumNodes = DecisionTree .maxNodesInLevel(maxDepth + 1 ) - 1
104
- // TODO: CHECK val maxNumNodes = (2 << maxDepth) - 1
103
+ val maxNumNodes = Node .maxNodesInLevel(maxDepth + 1 ) - 1
105
104
// Initialize an array to hold parent impurity calculations for each node.
106
105
val parentImpurities = new Array [Double ](maxNumNodes)
107
106
// dummy value for top node (updated during first split calculation)
@@ -153,18 +152,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
153
152
metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
154
153
timer.stop(" findBestSplits" )
155
154
156
- val levelNodeIndexOffset = DecisionTree .maxNodesInLevel(level) - 1
155
+ val levelNodeIndexOffset = Node .maxNodesInLevel(level) - 1
157
156
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
158
157
/* println(s"splitsStatsForLevel: index=$index")
159
158
println(s"\t split: ${nodeSplitStats._1}")
160
159
println(s"\t gain stats: ${nodeSplitStats._2}")*/
161
160
val nodeIndex = levelNodeIndexOffset + index
162
- val isLeftChild = level != 0 && nodeIndex % 2 == 1
163
- val parentNodeIndex = if (isLeftChild) { // -1 for root node
164
- (nodeIndex - 1 ) / 2
165
- } else {
166
- (nodeIndex - 2 ) / 2
167
- }
161
+ val isLeftChild = Node .isLeftChild(nodeIndex)
162
+ val parentNodeIndex = Node .parentIndex(nodeIndex) // -1 for root node
163
+
168
164
// if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf))
169
165
// TODO: Use above check to skip unused branch of tree
170
166
@@ -192,7 +188,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
192
188
timer.stop(" extractInfoForLowerLevels" )
193
189
logDebug(" final best split = " + nodeSplitStats._1)
194
190
}
195
- require(DecisionTree .maxNodesInLevel(level) == splitsStatsForLevel.length)
191
+ require(Node .maxNodesInLevel(level) == splitsStatsForLevel.length)
196
192
// Check whether all the nodes at the current level are leaves.
197
193
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 )
198
194
logDebug(" all leaf = " + allLeaf)
@@ -232,8 +228,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
232
228
if (level >= maxDepth) {
233
229
return
234
230
}
235
- // TODO: Move nodeIndexOffset calc out of function?
236
- val leftNodeIndex = (2 << level) - 1 + 2 * index
231
+ val leftNodeIndex = Node .maxNodesInSubtree(level) + 2 * index
237
232
val leftImpurity = nodeSplitStats._2.leftImpurity
238
233
logDebug(" leftNodeIndex = " + leftNodeIndex + " , impurity = " + leftImpurity)
239
234
parentImpurities(leftNodeIndex) = leftImpurity
@@ -547,7 +542,7 @@ object DecisionTree extends Serializable with Logging {
547
542
548
543
// numNodes: Number of nodes in this (level of tree, group),
549
544
// where nodes at deeper (larger) levels may be divided into groups.
550
- val numNodes = DecisionTree .maxNodesInLevel(level) / numGroups
545
+ val numNodes = Node .maxNodesInLevel(level) / numGroups
551
546
logDebug(" numNodes = " + numNodes)
552
547
553
548
// Find the number of features by looking at the first sample.
@@ -619,16 +614,8 @@ object DecisionTree extends Serializable with Logging {
619
614
}
620
615
}
621
616
622
- def nodeIndexToLevel (idx : Int ): Int = {
623
- if (idx == 0 ) {
624
- 0
625
- } else {
626
- math.floor(math.log(idx) / math.log(2 )).toInt
627
- }
628
- }
629
-
630
617
// Used for treePointToNodeIndex
631
- val levelOffset = DecisionTree .maxNodesInLevel(level) - 1
618
+ val levelOffset = Node .maxNodesInLevel(level) - 1
632
619
633
620
/**
634
621
* Find the node index for the given example.
@@ -787,7 +774,7 @@ object DecisionTree extends Serializable with Logging {
787
774
timer.start(" aggregation" )
788
775
val binAggregates = {
789
776
val initAgg = getEmptyBinAggregates(metadata, numNodes)
790
- input.aggregate (initAgg)(binSeqOp, binCombOp)
777
+ input.treeAggregate (initAgg)(binSeqOp, binCombOp)
791
778
}
792
779
timer.stop(" aggregation" )
793
780
/*
@@ -804,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
804
791
// Calculate best splits for all nodes at a given level
805
792
timer.start(" chooseSplits" )
806
793
val bestSplits = new Array [(Split , InformationGainStats )](numNodes)
807
- val nodeIndexOffset = DecisionTree .maxNodesInLevel(level) - 1
794
+ val nodeIndexOffset = Node .maxNodesInLevel(level) - 1
808
795
// Iterating over all nodes at this level
809
796
var nodeIndex = 0
810
797
while (nodeIndex < numNodes) {
@@ -1160,7 +1147,6 @@ object DecisionTree extends Serializable with Logging {
1160
1147
* For multiclass classification with a low-arity feature
1161
1148
* (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
1162
1149
* the feature is split based on subsets of categories.
1163
- * There are (1 << maxFeatureValue - 1) - 1 splits.
1164
1150
* (b) "ordered features"
1165
1151
* For regression and binary classification,
1166
1152
* and for multiclass classification with a high-arity feature,
@@ -1366,12 +1352,4 @@ object DecisionTree extends Serializable with Logging {
1366
1352
categories
1367
1353
}
1368
1354
1369
- private [tree] def maxNodesInLevel (level : Int ): Int = {
1370
- math.pow(2 , level).toInt
1371
- }
1372
-
1373
- private [tree] def numUnorderedBins (arity : Int ): Int = {
1374
- (math.pow(2 , arity - 1 ) - 1 ).toInt
1375
- }
1376
-
1377
1355
}
0 commit comments