File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed
mllib/src/main/scala/org/apache/spark/mllib/tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -816,7 +816,15 @@ object DecisionTree extends Serializable with Logging {
816
816
817
817
val maxBins = strategy.maxBins
818
818
val numBins = if (maxBins <= count) maxBins else count.toInt
819
- logDebug(" maxBins = " + numBins)
819
+ logDebug(" numBins = " + numBins)
820
+
821
+ // I will also add a require statement ensuring #bins is always greater than the categories
822
+ // It's a limitation of the current implementation but a reasonable tradeoff since features
823
+ // with large number of categories get favored over continuous features.
824
+ if (strategy.categoricalFeaturesInfo.size > 0 ){
825
+ val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
826
+ require(numBins >= maxCategoriesForFeatures)
827
+ }
820
828
821
829
// Calculate the number of sample for approximate quantile calculation
822
830
val requiredSamples = numBins* numBins
You can’t perform that action at this time.
0 commit comments