@@ -367,18 +367,18 @@ object DecisionTree extends Serializable with Logging {
367
367
368
368
def calculateGainForSplit (leftNodeAgg : Array [Array [Double ]],
369
369
featureIndex : Int ,
370
- index : Int ,
370
+ splitIndex : Int ,
371
371
rightNodeAgg : Array [Array [Double ]],
372
372
topImpurity : Double ) : InformationGainStats = {
373
373
strategy.algo match {
374
374
case Classification => {
375
375
376
- val left0Count = leftNodeAgg(featureIndex)(2 * index )
377
- val left1Count = leftNodeAgg(featureIndex)(2 * index + 1 )
376
+ val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex )
377
+ val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1 )
378
378
val leftCount = left0Count + left1Count
379
379
380
- val right0Count = rightNodeAgg(featureIndex)(2 * index )
381
- val right1Count = rightNodeAgg(featureIndex)(2 * index + 1 )
380
+ val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex )
381
+ val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1 )
382
382
val rightCount = right0Count + right1Count
383
383
384
384
val impurity = if (level > 0 ) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
@@ -405,13 +405,13 @@ object DecisionTree extends Serializable with Logging {
405
405
new InformationGainStats (gain,impurity,leftImpurity,rightImpurity,predict)
406
406
}
407
407
case Regression => {
408
- val leftCount = leftNodeAgg(featureIndex)(3 * index )
409
- val leftSum = leftNodeAgg(featureIndex)(3 * index + 1 )
410
- val leftSumSquares = leftNodeAgg(featureIndex)(3 * index + 2 )
408
+ val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex )
409
+ val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1 )
410
+ val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2 )
411
411
412
- val rightCount = rightNodeAgg(featureIndex)(3 * index )
413
- val rightSum = rightNodeAgg(featureIndex)(3 * index + 1 )
414
- val rightSumSquares = rightNodeAgg(featureIndex)(3 * index + 2 )
412
+ val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex )
413
+ val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1 )
414
+ val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2 )
415
415
416
416
val impurity = if (level > 0 ) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares)
417
417
@@ -463,9 +463,9 @@ object DecisionTree extends Serializable with Logging {
463
463
leftNodeAgg(featureIndex)(2 * splitIndex + 1 )
464
464
= binData(shift + 2 * splitIndex + 1 ) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1 )
465
465
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex))
466
- = binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
466
+ = binData(shift + (2 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
467
467
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1 )
468
- = binData(shift + (2 * (numBins - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1 )
468
+ = binData(shift + (2 * (numBins - 2 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1 )
469
469
}
470
470
}
471
471
(leftNodeAgg, rightNodeAgg)
@@ -490,11 +490,11 @@ object DecisionTree extends Serializable with Logging {
490
490
leftNodeAgg(featureIndex)(3 * splitIndex + 2 )
491
491
= binData(shift + 3 * splitIndex + 2 ) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2 )
492
492
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex))
493
- = binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
493
+ = binData(shift + (3 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
494
494
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1 )
495
- = binData(shift + (3 * (numBins - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1 )
495
+ = binData(shift + (3 * (numBins - 2 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1 )
496
496
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2 )
497
- = binData(shift + (3 * (numBins - 1 - splitIndex) + 2 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2 )
497
+ = binData(shift + (3 * (numBins - 2 - splitIndex) + 2 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2 )
498
498
}
499
499
}
500
500
(leftNodeAgg, rightNodeAgg)
@@ -508,9 +508,9 @@ object DecisionTree extends Serializable with Logging {
508
508
val gains = Array .ofDim[InformationGainStats ](numFeatures, numBins - 1 )
509
509
510
510
for (featureIndex <- 0 until numFeatures) {
511
- for (index <- 0 until numBins - 1 ) {
511
+ for (splitIndex <- 0 until numBins - 1 ) {
512
512
// logDebug("splitIndex = " + index)
513
- gains(featureIndex)(index ) = calculateGainForSplit(leftNodeAgg, featureIndex, index , rightNodeAgg, nodeImpurity)
513
+ gains(featureIndex)(splitIndex ) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex , rightNodeAgg, nodeImpurity)
514
514
}
515
515
}
516
516
gains
@@ -544,6 +544,8 @@ object DecisionTree extends Serializable with Logging {
544
544
(bestFeatureIndex,bestSplitIndex,bestGainStats)
545
545
}
546
546
547
+ logDebug(" best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
548
+ logDebug(" best split bin = " + splits(bestFeatureIndex)(bestSplitIndex))
547
549
(splits(bestFeatureIndex)(bestSplitIndex),gainStats)
548
550
}
549
551
@@ -614,13 +616,14 @@ object DecisionTree extends Serializable with Logging {
614
616
615
617
// Find all splits
616
618
for (featureIndex <- 0 until numFeatures){
617
- val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
618
- if (isFeatureContinous ) {
619
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
620
+ if (isFeatureContinuous ) {
619
621
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
620
622
621
623
val stride : Double = numSamples.toDouble/ numBins
622
624
logDebug(" stride = " + stride)
623
625
for (index <- 0 until numBins- 1 ) {
626
+ // TODO: Investigate this
624
627
val sampleIndex = (index+ 1 )* stride.toInt
625
628
val split = new Split (featureIndex,featureSamples(sampleIndex),Continuous , List ())
626
629
splits(featureIndex)(index) = split
0 commit comments