Skip to content

Commit 6b7de78

Browse files
committed
minor refactoring and tests
Signed-off-by: Manish Amde <[email protected]>
1 parent d504eb1 commit 6b7de78

File tree

2 files changed

+63
-54
lines changed

2 files changed

+63
-54
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,15 @@ object DecisionTree extends Serializable with Logging {
152152
//Find the number of features by looking at the first sample
153153
val numFeatures = input.take(1)(0).features.length
154154
logDebug("numFeatures = " + numFeatures)
155-
val numSplits = strategy.numBins
156-
logDebug("numSplits = " + numSplits)
155+
val numBins = strategy.numBins
156+
logDebug("numBins = " + numBins)
157157

158158
/*Find the filters used before reaching the current node*/
159159
def findParentFilters(nodeIndex: Int): List[Filter] = {
160160
if (level == 0) {
161161
List[Filter]()
162162
} else {
163163
val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
164-
//val parentFilterIndex = nodeFilterIndex / 2
165-
//TODO: Check left or right filter
166164
filters(nodeFilterIndex)
167165
}
168166
}
@@ -204,9 +202,9 @@ object DecisionTree extends Serializable with Logging {
204202
}
205203

206204
/*Finds the right bin for the given feature*/
207-
def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinous : Boolean) : Int = {
205+
def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = {
208206

209-
if (isFeatureContinous){
207+
if (isFeatureContinuous){
210208
//TODO: Do binary search
211209
for (binIndex <- 0 until strategy.numBins) {
212210
val bin = bins(featureIndex)(binIndex)
@@ -245,11 +243,11 @@ object DecisionTree extends Serializable with Logging {
245243
// calculating bin index and label per feature per node
246244
val arr = new Array[Double](1+(numFeatures * numNodes))
247245
arr(0) = labeledPoint.label
248-
for (index <- 0 until numNodes) {
249-
val parentFilters = findParentFilters(index)
246+
for (nodeIndex <- 0 until numNodes) {
247+
val parentFilters = findParentFilters(nodeIndex)
250248
//Find out whether the sample qualifies for the particular node
251249
val sampleValid = isSampleValid(parentFilters, labeledPoint)
252-
val shift = 1 + numFeatures * index
250+
val shift = 1 + numFeatures * nodeIndex
253251
if (!sampleValid) {
254252
//Add to invalid bin index -1
255253
for (featureIndex <- 0 until numFeatures) {
@@ -274,11 +272,11 @@ object DecisionTree extends Serializable with Logging {
274272
val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false
275273
if (isSampleValidForNode) {
276274
val label = arr(0)
277-
for (feature <- 0 until numFeatures) {
275+
for (featureIndex <- 0 until numFeatures) {
278276
val arrShift = 1 + numFeatures * node
279-
val aggShift = 2 * numSplits * numFeatures * node
280-
val arrIndex = arrShift + feature
281-
val aggIndex = aggShift + 2 * feature * numSplits + arr(arrIndex).toInt * 2
277+
val aggShift = 2 * numBins * numFeatures * node
278+
val arrIndex = arrShift + featureIndex
279+
val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
282280
label match {
283281
case (0.0) => agg(aggIndex) = agg(aggIndex) + 1
284282
case (1.0) => agg(aggIndex + 1) = agg(aggIndex + 1) + 1
@@ -296,9 +294,9 @@ object DecisionTree extends Serializable with Logging {
296294
val label = arr(0)
297295
for (feature <- 0 until numFeatures) {
298296
val arrShift = 1 + numFeatures * node
299-
val aggShift = 3 * numSplits * numFeatures * node
297+
val aggShift = 3 * numBins * numFeatures * node
300298
val arrIndex = arrShift + feature
301-
val aggIndex = aggShift + 3 * feature * numSplits + arr(arrIndex).toInt * 3
299+
val aggIndex = aggShift + 3 * feature * numBins + arr(arrIndex).toInt * 3
302300
//count, sum, sum^2
303301
agg(aggIndex) = agg(aggIndex) + 1
304302
agg(aggIndex + 1) = agg(aggIndex + 1) + label
@@ -318,7 +316,6 @@ object DecisionTree extends Serializable with Logging {
318316
@return Array[Double] storing aggregate calculation of size 2*numBins*numFeatures*numNodes for classification (3*numBins*numFeatures*numNodes for regression)
319317
*/
320318
def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = {
321-
//TODO: Requires logic for regressions
322319
strategy.algo match {
323320
case Classification => classificationBinSeqOp(arr, agg)
324321
//TODO: Implement this
@@ -327,10 +324,9 @@ object DecisionTree extends Serializable with Logging {
327324
agg
328325
}
329326

330-
//TODO: This length is different for regression
331327
val binAggregateLength = strategy.algo match {
332-
case Classification => 2*numSplits * numFeatures * numNodes
333-
case Regression => 3*numSplits * numFeatures * numNodes
328+
case Classification => 2*numBins * numFeatures * numNodes
329+
case Regression => 3*numBins * numFeatures * numNodes
334330
}
335331
logDebug("binAggregateLength = " + binAggregateLength)
336332

@@ -453,52 +449,52 @@ object DecisionTree extends Serializable with Logging {
453449
strategy.algo match {
454450
case Classification => {
455451

456-
val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
457-
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
452+
val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
453+
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
458454
for (featureIndex <- 0 until numFeatures) {
459-
val shift = 2*featureIndex*numSplits
455+
val shift = 2*featureIndex*numBins
460456
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
461457
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
462-
rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1)))
463-
rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1)
464-
for (splitIndex <- 1 until numSplits - 1) {
458+
rightNodeAgg(featureIndex)(2 * (numBins - 2)) = binData(shift + (2 * (numBins - 1)))
459+
rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) = binData(shift + (2 * (numBins - 1)) + 1)
460+
for (splitIndex <- 1 until numBins - 1) {
465461
leftNodeAgg(featureIndex)(2 * splitIndex)
466462
= binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2)
467463
leftNodeAgg(featureIndex)(2 * splitIndex + 1)
468464
= binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
469-
rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex))
470-
= binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex))
471-
rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1)
472-
= binData(shift + (2 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1)
465+
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex))
466+
= binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
467+
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1)
468+
= binData(shift + (2 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
473469
}
474470
}
475471
(leftNodeAgg, rightNodeAgg)
476472
}
477473
case Regression => {
478474

479-
val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1))
480-
val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1))
475+
val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
476+
val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
481477
for (featureIndex <- 0 until numFeatures) {
482-
val shift = 3*featureIndex*numSplits
478+
val shift = 3*featureIndex*numBins
483479
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
484480
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
485481
leftNodeAgg(featureIndex)(2) = binData(shift + 2)
486-
rightNodeAgg(featureIndex)(3 * (numSplits - 2)) = binData(shift + (3 * (numSplits - 1)))
487-
rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 1) = binData(shift + (3 * (numSplits - 1)) + 1)
488-
rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 2) = binData(shift + (3 * (numSplits - 1)) + 2)
489-
for (splitIndex <- 1 until numSplits - 1) {
482+
rightNodeAgg(featureIndex)(3 * (numBins - 2)) = binData(shift + (3 * (numBins - 1)))
483+
rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = binData(shift + (3 * (numBins - 1)) + 1)
484+
rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = binData(shift + (3 * (numBins - 1)) + 2)
485+
for (splitIndex <- 1 until numBins - 1) {
490486
leftNodeAgg(featureIndex)(3 * splitIndex)
491487
= binData(shift + 3*splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3)
492488
leftNodeAgg(featureIndex)(3 * splitIndex + 1)
493489
= binData(shift + 3*splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
494490
leftNodeAgg(featureIndex)(3 * splitIndex + 2)
495491
= binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
496-
rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex))
497-
= binData(shift + (3 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex))
498-
rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 1)
499-
= binData(shift + (3 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 1)
500-
rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 2)
501-
= binData(shift + (3 * (numSplits - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 2)
492+
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex))
493+
= binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
494+
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1)
495+
= binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
496+
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2)
497+
= binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
502498
}
503499
}
504500
(leftNodeAgg, rightNodeAgg)
@@ -509,10 +505,10 @@ object DecisionTree extends Serializable with Logging {
509505
def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double)
510506
: Array[Array[InformationGainStats]] = {
511507

512-
val gains = Array.ofDim[InformationGainStats](numFeatures, numSplits - 1)
508+
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
513509

514510
for (featureIndex <- 0 until numFeatures) {
515-
for (index <- 0 until numSplits -1) {
511+
for (index <- 0 until numBins -1) {
516512
//logDebug("splitIndex = " + index)
517513
gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
518514
}
@@ -521,10 +517,10 @@ object DecisionTree extends Serializable with Logging {
521517
}
522518

523519
/*
524-
Find the best split for a node given bin aggregate data
520+
Find the best split for a node given bin aggregate data
525521
526-
@param binData Array[Double] of size 2*numSplits*numFeatures
527-
*/
522+
@param binData Array[Double] of size 2*numBins*numFeatures for classification (3*numBins*numFeatures for regression)
523+
*/
528524
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = {
529525
logDebug("node impurity = " + nodeImpurity)
530526
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
@@ -536,7 +532,7 @@ object DecisionTree extends Serializable with Logging {
536532
//Initialization with infeasible values
537533
var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1)
538534
for (featureIndex <- 0 until numFeatures) {
539-
for (splitIndex <- 0 until numSplits - 1){
535+
for (splitIndex <- 0 until numBins - 1){
540536
val gainStats = gains(featureIndex)(splitIndex)
541537
if(gainStats.gain > bestGainStats.gain) {
542538
bestGainStats = gainStats
@@ -556,13 +552,13 @@ object DecisionTree extends Serializable with Logging {
556552
def getBinDataForNode(node: Int): Array[Double] = {
557553
strategy.algo match {
558554
case Classification => {
559-
val shift = 2 * node * numSplits * numFeatures
560-
val binsForNode = binAggregates.slice(shift, shift + 2 * numSplits * numFeatures)
555+
val shift = 2 * node * numBins * numFeatures
556+
val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
561557
binsForNode
562558
}
563559
case Regression => {
564-
val shift = 3 * node * numSplits * numFeatures
565-
val binsForNode = binAggregates.slice(shift, shift + 3 * numSplits * numFeatures)
560+
val shift = 3 * node * numBins * numFeatures
561+
val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
566562
binsForNode
567563
}
568564
}

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,20 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
109109
//TODO: Test max feature value > num bins
110110

111111

112-
test("stump with all categorical variables"){
112+
test("classification stump with all categorical variables"){
113+
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
114+
assert(arr.length == 1000)
115+
val rdd = sc.parallelize(arr)
116+
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
117+
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
118+
strategy.numBins = 100
119+
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
120+
println(bestSplits(0)._1)
121+
println(bestSplits(0)._2)
122+
//TODO: Add asserts
123+
}
124+
125+
test("regression stump with all categorical variables"){
113126
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
114127
assert(arr.length == 1000)
115128
val rdd = sc.parallelize(arr)
@@ -123,7 +136,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
123136
}
124137

125138

126-
test("stump with fixed label 0 for Gini"){
139+
test("stump with fixed label 0 for Gini"){
127140
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
128141
assert(arr.length == 1000)
129142
val rdd = sc.parallelize(arr)

0 commit comments

Comments
 (0)