@@ -681,36 +681,47 @@ object DecisionTree extends Serializable with Logging {
681
681
topImpurity : Double ): InformationGainStats = {
682
682
strategy.algo match {
683
683
case Classification =>
684
- // TODO: Modify here
685
- val left0Count = leftNodeAgg(featureIndex)(splitIndex)(0 )
686
- val left1Count = leftNodeAgg(featureIndex)(splitIndex)(1 )
687
- val leftCount = left0Count + left1Count
688
-
689
- val right0Count = rightNodeAgg(featureIndex)(splitIndex)(0 )
690
- val right1Count = rightNodeAgg(featureIndex)(splitIndex)(1 )
691
- val rightCount = right0Count + right1Count
684
+ var classIndex = 0
685
+ val leftCounts : Array [Double ] = new Array [Double ](numClasses)
686
+ val rightCounts : Array [Double ] = new Array [Double ](numClasses)
687
+ var leftTotalCount = 0.0
688
+ var rightTotalCount = 0.0
689
+ while (classIndex < numClasses) {
690
+ val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex)
691
+ val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex)
692
+ leftCounts(classIndex) = leftClassCount
693
+ leftTotalCount += leftClassCount
694
+ rightCounts(classIndex) = rightClassCount
695
+ rightTotalCount += rightClassCount
696
+ classIndex += 1
697
+ }
692
698
693
699
val impurity = {
694
700
if (level > 0 ) {
695
701
topImpurity
696
702
} else {
697
703
// Calculate impurity for root node.
698
- strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
704
+ val rootNodeCounts = new Array [Double ](numClasses)
705
+ var classIndex = 0
706
+ while (classIndex < numClasses) {
707
+ rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex)
708
+ }
709
+ strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
699
710
}
700
711
}
701
712
702
- if (leftCount == 0 ) {
713
+ if (leftTotalCount == 0 ) {
703
714
return new InformationGainStats (0 , topImpurity, Double .MinValue , topImpurity,1 )
704
715
}
705
- if (rightCount == 0 ) {
716
+ if (rightTotalCount == 0 ) {
706
717
return new InformationGainStats (0 , topImpurity, topImpurity, Double .MinValue ,0 )
707
718
}
708
719
709
- val leftImpurity = strategy.impurity.calculate(left0Count, left1Count )
710
- val rightImpurity = strategy.impurity.calculate(right0Count, right1Count )
720
+ val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount )
721
+ val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount )
711
722
712
- val leftWeight = leftCount .toDouble / (leftCount + rightCount )
713
- val rightWeight = rightCount .toDouble / (leftCount + rightCount )
723
+ val leftWeight = leftTotalCount .toDouble / (leftTotalCount + rightTotalCount )
724
+ val rightWeight = rightTotalCount .toDouble / (leftTotalCount + rightTotalCount )
714
725
715
726
val gain = {
716
727
if (level > 0 ) {
@@ -720,7 +731,8 @@ object DecisionTree extends Serializable with Logging {
720
731
}
721
732
}
722
733
723
- val predict = (left1Count + right1Count) / (leftCount + rightCount)
734
+ // TODO: Make modification here
735
+ val predict = (leftCounts(1 ) + rightCounts(1 )) / (leftTotalCount + rightTotalCount)
724
736
725
737
new InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict)
726
738
case Regression =>
@@ -782,7 +794,6 @@ object DecisionTree extends Serializable with Logging {
782
794
binData : Array [Double ]): (Array [Array [Array [Double ]]], Array [Array [Array [Double ]]]) = {
783
795
strategy.algo match {
784
796
case Classification =>
785
- // TODO: Multiclass modification here
786
797
787
798
// Initialize left and right split aggregates.
788
799
val leftNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , numClasses)
@@ -793,17 +804,19 @@ object DecisionTree extends Serializable with Logging {
793
804
while (featureIndex < numFeatures){
794
805
val numCategories = strategy.categoricalFeaturesInfo(featureIndex)
795
806
val maxSplits = math.pow(2 , numCategories) - 1
796
- var i = 0
797
- // TODO: Add multiclass case here
798
- while (i < maxSplits) {
807
+ var splitIndex = 0
808
+ while (splitIndex < maxSplits) {
799
809
var classIndex = 0
800
810
while (classIndex < numClasses) {
801
811
// shift for this featureIndex
802
812
val shift = numClasses * featureIndex * numBins
803
-
813
+ leftNodeAgg(featureIndex)(splitIndex)(classIndex)
814
+ = binData(shift + classIndex)
815
+ rightNodeAgg(featureIndex)(splitIndex)(classIndex)
816
+ = binData(shift + numClasses + classIndex)
804
817
classIndex += 1
805
818
}
806
- i += 1
819
+ splitIndex += 1
807
820
}
808
821
featureIndex += 1
809
822
}
@@ -931,8 +944,6 @@ object DecisionTree extends Serializable with Logging {
931
944
binData : Array [Double ],
932
945
nodeImpurity : Double ): (Split , InformationGainStats ) = {
933
946
934
- // TODO: Multiclass modification here
935
-
936
947
logDebug(" node impurity = " + nodeImpurity)
937
948
938
949
// Extract left right node aggregates.
@@ -977,9 +988,8 @@ object DecisionTree extends Serializable with Logging {
977
988
def getBinDataForNode (node : Int ): Array [Double ] = {
978
989
strategy.algo match {
979
990
case Classification =>
980
- // TODO: Multiclass modification here
981
- val shift = 2 * node * numBins * numFeatures
982
- val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
991
+ val shift = numClasses * node * numBins * numFeatures
992
+ val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
983
993
binsForNode
984
994
case Regression =>
985
995
val shift = 3 * node * numBins * numFeatures
0 commit comments