@@ -547,16 +547,18 @@ object DecisionTree extends Serializable with Logging {
547
547
548
548
/**
549
549
* Sequential search helper method to find bin for categorical feature in multiclass
550
- * classification. Dummy value of 0 used since it is not used in future calculation
550
+ * classification. The category is returned since each category can belong to multiple
551
+ * splits. The actual left/right child allocation per split is performed in the
552
+ * sequential phase of the bin aggregate operation.
551
553
*/
552
- def sequentialBinSearchForCategoricalFeatureInBinaryClassification (): Int = {
554
+ def sequentialBinSearchForCategoricalFeatureInMulticlassClassification (): Int = {
553
555
labeledPoint.features(featureIndex).toInt
554
556
}
555
557
556
558
/**
557
559
* Sequential search helper method to find bin for categorical feature (non-multiclass case).
558
560
*/
559
- def sequentialBinSearchForCategoricalFeatureInMultiClassClassification (): Int = {
561
+ def sequentialBinSearchForCategoricalFeatureInBinaryClassification (): Int = {
560
562
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
561
563
val numCategoricalBins = math.pow(2.0 , featureCategories - 1 ).toInt - 1
562
564
var binIndex = 0
@@ -583,9 +585,9 @@ object DecisionTree extends Serializable with Logging {
583
585
// Perform sequential search to find bin for categorical features.
584
586
val binIndex = {
585
587
if (isMulticlassClassification) {
586
- sequentialBinSearchForCategoricalFeatureInBinaryClassification ()
588
+ sequentialBinSearchForCategoricalFeatureInMulticlassClassification ()
587
589
} else {
588
- sequentialBinSearchForCategoricalFeatureInMultiClassClassification ()
590
+ sequentialBinSearchForCategoricalFeatureInBinaryClassification ()
589
591
}
590
592
}
591
593
if (binIndex == - 1 ){
@@ -684,7 +686,7 @@ object DecisionTree extends Serializable with Logging {
684
686
* @return Array[Double] storing aggregate calculation of size
685
687
* 2 * numSplits * numFeatures * numNodes for classification
686
688
*/
687
- def binaryClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) = {
689
+ def orderedClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) = {
688
690
// Iterate over all nodes.
689
691
var nodeIndex = 0
690
692
while (nodeIndex < numNodes) {
@@ -716,7 +718,7 @@ object DecisionTree extends Serializable with Logging {
716
718
* @return Array[Double] storing aggregate calculation of size
717
719
* 2 * numClasses * numSplits * numFeatures * numNodes for classification
718
720
*/
719
- def multiClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) = {
721
+ def unorderedClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) = {
720
722
// Iterate over all nodes.
721
723
var nodeIndex = 0
722
724
while (nodeIndex < numNodes) {
@@ -789,9 +791,9 @@ object DecisionTree extends Serializable with Logging {
789
791
strategy.algo match {
790
792
case Classification =>
791
793
if (isMulticlassClassificationWithCategoricalFeatures) {
792
- multiClassificationBinSeqOp (arr, agg)
794
+ unorderedClassificationBinSeqOp (arr, agg)
793
795
} else {
794
- binaryClassificationBinSeqOp (arr, agg)
796
+ orderedClassificationBinSeqOp (arr, agg)
795
797
}
796
798
case Regression => regressionBinSeqOp(arr, agg)
797
799
}
0 commit comments