
Commit 5f94342

Added treeAggregate since not yet merged from master. Moved node indexing functions to Node.
1 parent 61c4509 commit 5f94342

3 files changed: 64 additions, 37 deletions

3 files changed

+64
-37
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 12 additions & 34 deletions
@@ -18,11 +18,11 @@
 package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.Logging
+import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
@@ -100,8 +100,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
     // depth of the decision tree
     val maxDepth = strategy.maxDepth
     // the max number of nodes possible given the depth of the tree
-    val maxNumNodes = DecisionTree.maxNodesInLevel(maxDepth + 1) - 1
-    // TODO: CHECK val maxNumNodes = (2 << maxDepth) - 1
+    val maxNumNodes = Node.maxNodesInLevel(maxDepth + 1) - 1
     // Initialize an array to hold parent impurity calculations for each node.
     val parentImpurities = new Array[Double](maxNumNodes)
     // dummy value for top node (updated during first split calculation)
@@ -153,18 +152,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
         metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
       timer.stop("findBestSplits")
 
-      val levelNodeIndexOffset = DecisionTree.maxNodesInLevel(level) - 1
+      val levelNodeIndexOffset = Node.maxNodesInLevel(level) - 1
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
         /*println(s"splitsStatsForLevel: index=$index")
         println(s"\t split: ${nodeSplitStats._1}")
         println(s"\t gain stats: ${nodeSplitStats._2}")*/
         val nodeIndex = levelNodeIndexOffset + index
-        val isLeftChild = level != 0 && nodeIndex % 2 == 1
-        val parentNodeIndex = if (isLeftChild) { // -1 for root node
-          (nodeIndex - 1) / 2
-        } else {
-          (nodeIndex - 2) / 2
-        }
+        val isLeftChild = Node.isLeftChild(nodeIndex)
+        val parentNodeIndex = Node.parentIndex(nodeIndex) // -1 for root node
+
         // if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf))
         // TODO: Use above check to skip unused branch of tree
 
@@ -192,7 +188,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
         timer.stop("extractInfoForLowerLevels")
         logDebug("final best split = " + nodeSplitStats._1)
       }
-      require(DecisionTree.maxNodesInLevel(level) == splitsStatsForLevel.length)
+      require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length)
      // Check whether all the nodes at the current level at leaves.
      val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
      logDebug("all leaf = " + allLeaf)
@@ -232,8 +228,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
      if (level >= maxDepth) {
        return
      }
-     // TODO: Move nodeIndexOffset calc out of function?
-     val leftNodeIndex = (2 << level) - 1 + 2 * index
+     val leftNodeIndex = Node.maxNodesInSubtree(level) + 2 * index
      val leftImpurity = nodeSplitStats._2.leftImpurity
      logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity)
      parentImpurities(leftNodeIndex) = leftImpurity
@@ -547,7 +542,7 @@ object DecisionTree extends Serializable with Logging {
 
     // numNodes: Number of nodes in this (level of tree, group),
     // where nodes at deeper (larger) levels may be divided into groups.
-    val numNodes = DecisionTree.maxNodesInLevel(level) / numGroups
+    val numNodes = Node.maxNodesInLevel(level) / numGroups
     logDebug("numNodes = " + numNodes)
 
     // Find the number of features by looking at the first sample.
@@ -619,16 +614,8 @@ object DecisionTree extends Serializable with Logging {
       }
     }
 
-    def nodeIndexToLevel(idx: Int): Int = {
-      if (idx == 0) {
-        0
-      } else {
-        math.floor(math.log(idx) / math.log(2)).toInt
-      }
-    }
-
     // Used for treePointToNodeIndex
-    val levelOffset = DecisionTree.maxNodesInLevel(level) - 1
+    val levelOffset = Node.maxNodesInLevel(level) - 1
 
     /**
      * Find the node index for the given example.
@@ -787,7 +774,7 @@ object DecisionTree extends Serializable with Logging {
     timer.start("aggregation")
     val binAggregates = {
       val initAgg = getEmptyBinAggregates(metadata, numNodes)
-      input.aggregate(initAgg)(binSeqOp, binCombOp)
+      input.treeAggregate(initAgg)(binSeqOp, binCombOp)
     }
     timer.stop("aggregation")
     /*
@@ -804,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
     // Calculate best splits for all nodes at a given level
     timer.start("chooseSplits")
     val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
-    val nodeIndexOffset = DecisionTree.maxNodesInLevel(level) - 1
+    val nodeIndexOffset = Node.maxNodesInLevel(level) - 1
     // Iterating over all nodes at this level
     var nodeIndex = 0
     while (nodeIndex < numNodes) {
@@ -1160,7 +1147,6 @@ object DecisionTree extends Serializable with Logging {
    *       For multiclass classification with a low-arity feature
    *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
    *       the feature is split based on subsets of categories.
-   *       There are (1 << maxFeatureValue - 1) - 1 splits.
    *   (b) "ordered features"
    *       For regression and binary classification,
    *       and for multiclass classification with a high-arity feature,
@@ -1366,12 +1352,4 @@ object DecisionTree extends Serializable with Logging {
     categories
   }
 
-  private[tree] def maxNodesInLevel(level: Int): Int = {
-    math.pow(2, level).toInt
-  }
-
-  private[tree] def numUnorderedBins(arity: Int): Int = {
-    (math.pow(2, arity - 1) - 1).toInt
-  }
-
 }
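
A note on the aggregate -> treeAggregate change above: treeAggregate is provided by org.apache.spark.mllib.rdd.RDDFunctions (hence the new import at the top of the file) and has the same shape as RDD.aggregate, but merges per-partition results in a multi-level tree on the executors instead of shipping every partition's bin aggregate straight to the driver, which matters when the bin-statistics array is large. Below is a minimal standalone sketch of the call shape, using a toy array-summing combiner in place of the real binSeqOp/binCombOp; the object name, local master, and data are illustrative only, not part of this patch.

import org.apache.spark.mllib.rdd.RDDFunctions._ // adds treeAggregate to RDDs
import org.apache.spark.{SparkConf, SparkContext}

object TreeAggregateSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("treeAggregateSketch").setMaster("local[4]"))
    val data = sc.parallelize(1 to 1000000, numSlices = 16)

    // Toy stand-in for the bin-statistics aggregate: sum all values into a 1-element array.
    val zero = new Array[Long](1)
    val seqOp = (agg: Array[Long], x: Int) => { agg(0) += x; agg }
    val combOp = (a: Array[Long], b: Array[Long]) => { a(0) += b(0); a }

    val flat = data.aggregate(zero.clone())(seqOp, combOp)     // all partials shipped to the driver
    val tree = data.treeAggregate(zero.clone())(seqOp, combOp) // partials merged tree-wise on executors

    assert(flat(0) == tree(0)) // same result, different communication pattern
    sc.stop()
  }
}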

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 9 additions & 1 deletion
@@ -102,7 +102,7 @@ private[tree] object DecisionTreeMetadata {
         // Note: The above check is equivalent to checking:
         // numUnorderedBins = (1 << k - 1) - 1 < maxBins
         unorderedFeatures.add(f)
-        numBins(f) = DecisionTree.numUnorderedBins(k)
+        numBins(f) = numUnorderedBins(k)
       } else {
         // TODO: Check the below k <= maxBins.
         // This used to be k < maxPossibleBins, but <= should work.
@@ -129,4 +129,12 @@ private[tree] object DecisionTreeMetadata {
       strategy.impurity, strategy.quantileCalculationStrategy)
   }
 
+  /**
+   * Given the arity of a categorical feature (arity = number of categories),
+   * return the number of bins for the feature if it is to be treated as an unordered feature.
+   */
+  def numUnorderedBins(arity: Int): Int = {
+    (1 << arity - 1) - 1
+  }
+
 }
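
For the new numUnorderedBins helper above: in Scala, - binds more tightly than <<, so (1 << arity - 1) - 1 parses as (1 << (arity - 1)) - 1, i.e. 2^(arity-1) - 1, the same value as the math.pow formula it replaces in DecisionTree and the same expression quoted in the comment at old line 103. A small standalone check; the wrapper object and loop are illustrative, not part of the patch.

object NumUnorderedBinsCheck {
  // Mirrors the helper added in this diff: 2^(arity-1) - 1 bins for an unordered categorical feature.
  def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1

  def main(args: Array[String]): Unit = {
    for (k <- 2 to 6) {
      val viaPow = (math.pow(2, k - 1) - 1).toInt // the formula removed from DecisionTree
      println(s"arity=$k: shift=${numUnorderedBins(k)}, pow=$viaPow")
      assert(numUnorderedBins(k) == viaPow)
    }
  }
}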

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 43 additions & 2 deletions
@@ -52,8 +52,7 @@ class Node (
    */
   def build(nodes: Array[Node]): Unit = {
 
-    logDebug("building node " + id + " at level " +
-      (scala.math.log(id + 1)/scala.math.log(2)).toInt )
+    logDebug("building node " + id + " at level " + Node.indexToLevel(id))
     logDebug("id = " + id + ", split = " + split)
     logDebug("stats = " + stats)
     logDebug("predict = " + predict)
@@ -148,3 +147,45 @@ class Node (
   }
 
 }
+
+private[tree] object Node {
+
+  /**
+   * Return the level of a tree which the given node is in.
+   */
+  def indexToLevel(nodeIndex: Int): Int = {
+    math.floor(math.log(nodeIndex + 1) / math.log(2)).toInt
+  }
+
+  /**
+   * Returns true if this is a left child.
+   * Note: Returns false for the root.
+   */
+  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex != 0 && nodeIndex % 2 == 1
+
+  /**
+   * Get the parent index of the given node, or -1 if it is the root.
+   */
+  def parentIndex(nodeIndex: Int): Int = {
+    if (isLeftChild(nodeIndex)) { // -1 for root node
+      (nodeIndex - 1) / 2
+    } else {
+      (nodeIndex - 2) / 2
+    }
+
+  }
+
+  /**
+   * Return the maximum number of nodes which can be in the given level of the tree.
+   * @param level Level of tree (0 = root).
+   */
+  private[tree] def maxNodesInLevel(level: Int): Int = 1 << level
+
+  /**
+   * Return the maximum number of nodes which can be in or above the given level of the tree
+   * (i.e., for the entire subtree from the root to this level).
+   * @param level Level of tree (0 = root).
+   */
+  private[tree] def maxNodesInSubtree(level: Int): Int = (2 << level) - 1
+
+}
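
The functions gathered into the new Node companion object encode the usual heap-style layout of a complete binary tree stored in a flat array: the root sits at index 0, the children of index i are 2i+1 and 2i+2, level l (0 = root) starts at offset 2^l - 1 and holds at most 2^l nodes (maxNodesInLevel), and levels 0..l together hold (2 << l) - 1 nodes (maxNodesInSubtree). A short standalone sketch that restates the indexing arithmetic from this diff and walks a depth-2 tree; the wrapper object and printout are illustrative only.

object NodeIndexingSketch {
  // Same arithmetic as the Node helpers in this commit, restated for a quick check.
  def indexToLevel(i: Int): Int = math.floor(math.log(i + 1) / math.log(2)).toInt
  def isLeftChild(i: Int): Boolean = i != 0 && i % 2 == 1
  def parentIndex(i: Int): Int = if (isLeftChild(i)) (i - 1) / 2 else (i - 2) / 2
  def maxNodesInLevel(level: Int): Int = 1 << level
  def maxNodesInSubtree(level: Int): Int = (2 << level) - 1

  def main(args: Array[String]): Unit = {
    // Indices 0..6 form a complete tree of depth 2:
    //        0
    //      1   2
    //     3 4 5 6
    for (i <- 0 to 6) {
      println(s"node $i: level=${indexToLevel(i)}, leftChild=${isLeftChild(i)}, parent=${parentIndex(i)}")
    }
    // Levels 0..2 hold 1 + 2 + 4 = 7 nodes, which is exactly maxNodesInSubtree(2).
    assert((0 to 2).map(maxNodesInLevel).sum == maxNodesInSubtree(2))
  }
}

This is the same indexing DecisionTree now relies on: levelNodeIndexOffset = Node.maxNodesInLevel(level) - 1 is the index of the first node in a level, and leftNodeIndex = Node.maxNodesInSubtree(level) + 2 * index is the left child's slot in the next level.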
