Skip to content

Commit 6068356

Browse files
committed
ensuring num bins is always greater than max number of categories
1 parent 62c2562 commit 6068356

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,15 @@ object DecisionTree extends Serializable with Logging {
816816

817817
val maxBins = strategy.maxBins
818818
val numBins = if (maxBins <= count) maxBins else count.toInt
819-
logDebug("maxBins = " + numBins)
819+
logDebug("numBins = " + numBins)
820+
821+
// I will also add a require statement ensuring #bins is always greater than the categories
822+
// It's a limitation of the current implementation but a reasonable tradeoff since features
823+
// with large number of categories get favored over continuous features.
824+
if (strategy.categoricalFeaturesInfo.size > 0){
825+
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
826+
require(numBins >= maxCategoriesForFeatures)
827+
}
820828

821829
// Calculate the number of sample for approximate quantile calculation
822830
val requiredSamples = numBins*numBins

0 commit comments

Comments
 (0)