Skip to content

Commit b09dc98

Browse files
committed
minor refactoring
Signed-off-by: Manish Amde <[email protected]>
1 parent 6b7de78 commit b09dc98

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -367,18 +367,18 @@ object DecisionTree extends Serializable with Logging {
367367

368368
def calculateGainForSplit(leftNodeAgg: Array[Array[Double]],
369369
featureIndex: Int,
370-
index: Int,
370+
splitIndex: Int,
371371
rightNodeAgg: Array[Array[Double]],
372372
topImpurity: Double) : InformationGainStats = {
373373
strategy.algo match {
374374
case Classification => {
375375

376-
val left0Count = leftNodeAgg(featureIndex)(2 * index)
377-
val left1Count = leftNodeAgg(featureIndex)(2 * index + 1)
376+
val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
377+
val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
378378
val leftCount = left0Count + left1Count
379379

380-
val right0Count = rightNodeAgg(featureIndex)(2 * index)
381-
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
380+
val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
381+
val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
382382
val rightCount = right0Count + right1Count
383383

384384
val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
@@ -405,13 +405,13 @@ object DecisionTree extends Serializable with Logging {
405405
new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict)
406406
}
407407
case Regression => {
408-
val leftCount = leftNodeAgg(featureIndex)(3 * index)
409-
val leftSum = leftNodeAgg(featureIndex)(3 * index + 1)
410-
val leftSumSquares = leftNodeAgg(featureIndex)(3 * index + 2)
408+
val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
409+
val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
410+
val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)
411411

412-
val rightCount = rightNodeAgg(featureIndex)(3 * index)
413-
val rightSum = rightNodeAgg(featureIndex)(3 * index + 1)
414-
val rightSumSquares = rightNodeAgg(featureIndex)(3 * index + 2)
412+
val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
413+
val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
414+
val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)
415415

416416
val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares)
417417

@@ -463,9 +463,9 @@ object DecisionTree extends Serializable with Logging {
463463
leftNodeAgg(featureIndex)(2 * splitIndex + 1)
464464
= binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
465465
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex))
466-
= binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
466+
= binData(shift + (2 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
467467
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1)
468-
= binData(shift + (2 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
468+
= binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
469469
}
470470
}
471471
(leftNodeAgg, rightNodeAgg)
@@ -490,11 +490,11 @@ object DecisionTree extends Serializable with Logging {
490490
leftNodeAgg(featureIndex)(3 * splitIndex + 2)
491491
= binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
492492
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex))
493-
= binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
493+
= binData(shift + (3 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
494494
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1)
495-
= binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
495+
= binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
496496
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2)
497-
= binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
497+
= binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
498498
}
499499
}
500500
(leftNodeAgg, rightNodeAgg)
@@ -508,9 +508,9 @@ object DecisionTree extends Serializable with Logging {
508508
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
509509

510510
for (featureIndex <- 0 until numFeatures) {
511-
for (index <- 0 until numBins -1) {
511+
for (splitIndex <- 0 until numBins -1) {
512512
//logDebug("splitIndex = " + index)
513-
gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
513+
gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity)
514514
}
515515
}
516516
gains
@@ -544,6 +544,8 @@ object DecisionTree extends Serializable with Logging {
544544
(bestFeatureIndex,bestSplitIndex,bestGainStats)
545545
}
546546

547+
logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
548+
logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex))
547549
(splits(bestFeatureIndex)(bestSplitIndex),gainStats)
548550
}
549551

@@ -614,13 +616,14 @@ object DecisionTree extends Serializable with Logging {
614616

615617
//Find all splits
616618
for (featureIndex <- 0 until numFeatures){
617-
val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
618-
if (isFeatureContinous) {
619+
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
620+
if (isFeatureContinuous) {
619621
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
620622

621623
val stride : Double = numSamples.toDouble/numBins
622624
logDebug("stride = " + stride)
623625
for (index <- 0 until numBins-1) {
626+
//TODO: Investigate this
624627
val sampleIndex = (index+1)*stride.toInt
625628
val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List())
626629
splits(featureIndex)(index) = split

0 commit comments

Comments
 (0)