Skip to content

Commit 5c78e1a

Browse files
committed
added multiclass support
1 parent 6c7af22 commit 5c78e1a

File tree

5 files changed

+68
-56
lines changed

5 files changed

+68
-56
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -681,36 +681,47 @@ object DecisionTree extends Serializable with Logging {
681681
topImpurity: Double): InformationGainStats = {
682682
strategy.algo match {
683683
case Classification =>
684-
// TODO: Modify here
685-
val left0Count = leftNodeAgg(featureIndex)(splitIndex)(0)
686-
val left1Count = leftNodeAgg(featureIndex)(splitIndex)(1)
687-
val leftCount = left0Count + left1Count
688-
689-
val right0Count = rightNodeAgg(featureIndex)(splitIndex)(0)
690-
val right1Count = rightNodeAgg(featureIndex)(splitIndex)(1)
691-
val rightCount = right0Count + right1Count
684+
var classIndex = 0
685+
val leftCounts: Array[Double] = new Array[Double](numClasses)
686+
val rightCounts: Array[Double] = new Array[Double](numClasses)
687+
var leftTotalCount = 0.0
688+
var rightTotalCount = 0.0
689+
while (classIndex < numClasses) {
690+
val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex)
691+
val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex)
692+
leftCounts(classIndex) = leftClassCount
693+
leftTotalCount += leftClassCount
694+
rightCounts(classIndex) = rightClassCount
695+
rightTotalCount += rightClassCount
696+
classIndex += 1
697+
}
692698

693699
val impurity = {
694700
if (level > 0) {
695701
topImpurity
696702
} else {
697703
// Calculate impurity for root node.
698-
strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
704+
// Root node: there is no parent impurity to reuse, so rebuild the per-class
// counts of the whole node by adding the left and right sides of this split.
val rootNodeCounts = new Array[Double](numClasses)
// Separate index name so the outer `classIndex` is not shadowed.
var rootClassIndex = 0
while (rootClassIndex < numClasses) {
  rootNodeCounts(rootClassIndex) = leftCounts(rootClassIndex) + rightCounts(rootClassIndex)
  // The increment was missing, which made this loop spin forever at level 0.
  rootClassIndex += 1
}
strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
699710
}
700711
}
701712

702-
if (leftCount == 0) {
713+
if (leftTotalCount == 0) {
703714
return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1)
704715
}
705-
if (rightCount == 0) {
716+
if (rightTotalCount == 0) {
706717
return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0)
707718
}
708719

709-
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
710-
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
720+
val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
721+
val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount)
711722

712-
val leftWeight = leftCount.toDouble / (leftCount + rightCount)
713-
val rightWeight = rightCount.toDouble / (leftCount + rightCount)
723+
val leftWeight = leftTotalCount.toDouble / (leftTotalCount + rightTotalCount)
724+
val rightWeight = rightTotalCount.toDouble / (leftTotalCount + rightTotalCount)
714725

715726
val gain = {
716727
if (level > 0) {
@@ -720,7 +731,8 @@ object DecisionTree extends Serializable with Logging {
720731
}
721732
}
722733

723-
val predict = (left1Count + right1Count) / (leftCount + rightCount)
734+
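// TODO: generalize `predict` to multiclass. The expression below is still the binary fraction of label-1 instances and indexes class 1 directly, so it assumes numClasses >= 2.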
735+
val predict = (leftCounts(1) + rightCounts(1)) / (leftTotalCount + rightTotalCount)
724736

725737
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
726738
case Regression =>
@@ -782,7 +794,6 @@ object DecisionTree extends Serializable with Logging {
782794
binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
783795
strategy.algo match {
784796
case Classification =>
785-
// TODO: Multiclass modification here
786797

787798
// Initialize left and right split aggregates.
788799
val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
@@ -793,17 +804,19 @@ object DecisionTree extends Serializable with Logging {
793804
while (featureIndex < numFeatures){
794805
val numCategories = strategy.categoricalFeaturesInfo(featureIndex)
795806
val maxSplits = math.pow(2, numCategories) - 1
796-
var i = 0
797-
// TODO: Add multiclass case here
798-
while (i < maxSplits) {
807+
var splitIndex = 0
808+
while (splitIndex < maxSplits) {
799809
var classIndex = 0
800810
while (classIndex < numClasses) {
801811
// shift for this featureIndex
802812
val shift = numClasses * featureIndex * numBins
803-
813+
leftNodeAgg(featureIndex)(splitIndex)(classIndex)
814+
= binData(shift + classIndex)
815+
rightNodeAgg(featureIndex)(splitIndex)(classIndex)
816+
= binData(shift + numClasses + classIndex)
804817
classIndex += 1
805818
}
806-
i += 1
819+
splitIndex += 1
807820
}
808821
featureIndex += 1
809822
}
@@ -931,8 +944,6 @@ object DecisionTree extends Serializable with Logging {
931944
binData: Array[Double],
932945
nodeImpurity: Double): (Split, InformationGainStats) = {
933946

934-
// TODO: Multiclass modification here
935-
936947
logDebug("node impurity = " + nodeImpurity)
937948

938949
// Extract left right node aggregates.
@@ -977,9 +988,8 @@ object DecisionTree extends Serializable with Logging {
977988
def getBinDataForNode(node: Int): Array[Double] = {
978989
strategy.algo match {
979990
case Classification =>
980-
// TODO: Multiclass modification here
981-
val shift = 2 * node * numBins * numFeatures
982-
val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
991+
val shift = numClasses * node * numBins * numFeatures
992+
val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
983993
binsForNode
984994
case Regression =>
985995
val shift = 3 * node * numBins * numFeatures

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,22 @@ object Entropy extends Impurity {
3131

3232
/**
3333
* :: DeveloperApi ::
34-
* entropy calculation
35-
* @param c0 count of instances with label 0
36-
* @param c1 count of instances with label 1
37-
* @return entropy value
34+
* entropy calculation for multiclass classification
35+
* @param counts Array[Double] with counts for each label
36+
* @param totalCount sum of counts for all labels
37+
* @return entropy value
3838
*/
3939
@DeveloperApi
40-
override def calculate(c0: Double, c1: Double): Double = {
41-
if (c0 == 0 || c1 == 0) {
42-
0
43-
} else {
44-
val total = c0 + c1
45-
val f0 = c0 / total
46-
val f1 = c1 / total
47-
-(f0 * log2(f0)) - (f1 * log2(f1))
40+
override def calculate(counts: Array[Double], totalCount: Double): Double = {
  // Accumulates -freq * log2(freq) over all labels, where freq = count / totalCount.
  // `log2` is defined elsewhere in this file (presumably the base-2 logarithm).
  if (totalCount == 0) {
    // No instances at all: every frequency would be 0 / 0 = NaN. Report zero
    // impurity instead, matching what the two-count version returned for empty input.
    0.0
  } else {
    val numClasses = counts.length
    var impurity = 0.0
    var classIndex = 0
    while (classIndex < numClasses) {
      val classCount = counts(classIndex)
      // Skip labels with no instances: freq would be 0 and 0 * log2(0) is NaN in
      // floating point, which would poison the whole sum. An absent label
      // contributes nothing to the entropy.
      if (classCount != 0) {
        val freq = classCount / totalCount
        impurity -= freq * log2(freq)
      }
      classIndex += 1
    }
    impurity
  }
}
5051

5152
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,22 @@ object Gini extends Impurity {
3030

3131
/**
3232
* :: DeveloperApi ::
33-
* Gini coefficient calculation
34-
* @param c0 count of instances with label 0
35-
* @param c1 count of instances with label 1
36-
* @return Gini coefficient value
33+
* Gini coefficient calculation for multiclass classification
34+
* @param counts Array[Double] with counts for each label
35+
* @param totalCount sum of counts for all labels
36+
* @return Gini coefficient value
3737
*/
3838
@DeveloperApi
39-
override def calculate(c0: Double, c1: Double): Double = {
40-
if (c0 == 0 || c1 == 0) {
41-
0
42-
} else {
43-
val total = c0 + c1
44-
val f0 = c0 / total
45-
val f1 = c1 / total
46-
1 - f0 * f0 - f1 * f1
39+
override def calculate(counts: Array[Double], totalCount: Double): Double = {
  // Computes 1 - sum(freq * freq) over all labels, where freq = count / totalCount.
  if (totalCount == 0) {
    // No instances at all: every frequency would be 0 / 0 = NaN. Report zero
    // impurity instead, matching what the two-count version returned for empty input.
    0.0
  } else {
    val numClasses = counts.length
    var impurity = 1.0
    var classIndex = 0
    while (classIndex < numClasses) {
      val freq = counts(classIndex) / totalCount
      impurity -= freq * freq
      classIndex += 1
    }
    impurity
  }
}
4950

5051
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ trait Impurity extends Serializable {
2828

2929
/**
3030
* :: DeveloperApi ::
31-
* information calculation for binary classification
32-
* @param c0 count of instances with label 0
33-
* @param c1 count of instances with label 1
31+
* information calculation for multiclass classification
32+
* @param counts Array[Double] with counts for each label
33+
* @param totalCount sum of counts for all labels
3434
* @return information value
3535
*/
3636
@DeveloperApi
37-
def calculate(c0 : Double, c1 : Double): Double
37+
// Abstract count-based overload: takes one count per label plus their total.
// Each Impurity object supplies its own body; Variance's version throws
// UnsupportedOperationException.
def calculate(counts: Array[Double], totalCount: Double): Double
3838

3939
/**
4040
* :: DeveloperApi ::

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
2525
*/
2626
@Experimental
2727
object Variance extends Impurity {
28-
override def calculate(c0: Double, c1: Double): Double =
28+
override def calculate(counts: Array[Double], totalCounts: Double): Double = {
  // The count-based overload is not supported by Variance: it ignores both
  // arguments and always throws.
  throw new UnsupportedOperationException("Variance.calculate")
}
3030

3131
/**

0 commit comments

Comments
 (0)