@@ -37,7 +37,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
37
37
// Cache input RDD for speedup during multiple passes
38
38
input.cache()
39
39
40
- val (splits, bins) = DecisionTree .find_splits_bins (input, strategy)
40
+ val (splits, bins) = DecisionTree .findSplitsBins (input, strategy)
41
41
logDebug(" numSplits = " + bins(0 ).length)
42
42
strategy.numBins = bins(0 ).length
43
43
@@ -54,8 +54,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
54
54
55
55
logDebug(" algo = " + strategy.algo)
56
56
57
-
58
-
59
57
breakable {
60
58
for (level <- 0 until maxDepth){
61
59
@@ -185,10 +183,21 @@ object DecisionTree extends Serializable with Logging {
185
183
val featureIndex = filter.split.feature
186
184
val threshold = filter.split.threshold
187
185
val comparison = filter.comparison
188
- comparison match {
189
- case (- 1 ) => if (features(featureIndex) > threshold) return false
190
- case (0 ) => if (features(featureIndex) != threshold) return false
191
- case (1 ) => if (features(featureIndex) <= threshold) return false
186
+ val categories = filter.split.categories
187
+ val isFeatureContinuous = filter.split.featureType == Continuous
188
+ val feature = features(featureIndex)
189
+ if (isFeatureContinuous){
190
+ comparison match {
191
+ case (- 1 ) => if (feature > threshold) return false
192
+ case (1 ) => if (feature <= threshold) return false
193
+ }
194
+ } else {
195
+ val containsFeature = categories.contains(feature)
196
+ comparison match {
197
+ case (- 1 ) => if (! containsFeature) return false
198
+ case (1 ) => if (containsFeature) return false
199
+ }
200
+
192
201
}
193
202
}
194
203
true
@@ -197,18 +206,34 @@ object DecisionTree extends Serializable with Logging {
197
206
/* Finds the right bin for the given feature*/
198
207
def findBin (featureIndex : Int , labeledPoint : LabeledPoint ) : Int = {
199
208
// logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
200
- // TODO: Do binary search
201
- for (binIndex <- 0 until strategy.numBins) {
202
- val bin = bins(featureIndex)(binIndex)
203
- // TODO: Remove this requirement post basic functional
204
- val lowThreshold = bin.lowSplit.threshold
205
- val highThreshold = bin.highSplit.threshold
206
- val features = labeledPoint.features
207
- if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
208
- return binIndex
209
+
210
+ val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
211
+ if (isFeatureContinous){
212
+ // TODO: Do binary search
213
+ for (binIndex <- 0 until strategy.numBins) {
214
+ val bin = bins(featureIndex)(binIndex)
215
+ // TODO: Remove this requirement post basic functional
216
+ val lowThreshold = bin.lowSplit.threshold
217
+ val highThreshold = bin.highSplit.threshold
218
+ val features = labeledPoint.features
219
+ if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
220
+ return binIndex
221
+ }
222
+ }
223
+ throw new UnknownError (" no bin was found for continuous variable." )
224
+ } else {
225
+ for (binIndex <- 0 until strategy.numBins) {
226
+ val bin = bins(featureIndex)(binIndex)
227
+ // TODO: Remove this requirement post basic functional
228
+ val category = bin.category
229
+ val features = labeledPoint.features
230
+ if (category == features(featureIndex)) {
231
+ return binIndex
232
+ }
209
233
}
234
+ throw new UnknownError (" no bin was found for categorical variable." )
235
+
210
236
}
211
- throw new UnknownError (" no bin was found." )
212
237
213
238
}
214
239
@@ -565,7 +590,7 @@ object DecisionTree extends Serializable with Logging {
565
590
@return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an
566
591
Array[Array[Bin]] of size (numFeatures,numSplits1)
567
592
*/
568
- def find_splits_bins (input : RDD [LabeledPoint ], strategy : Strategy ) : (Array [Array [Split ]], Array [Array [Bin ]]) = {
593
+ def findSplitsBins (input : RDD [LabeledPoint ], strategy : Strategy ) : (Array [Array [Split ]], Array [Array [Bin ]]) = {
569
594
570
595
val count = input.count()
571
596
@@ -603,31 +628,71 @@ object DecisionTree extends Serializable with Logging {
603
628
logDebug(" stride = " + stride)
604
629
for (index <- 0 until numBins- 1 ) {
605
630
val sampleIndex = (index+ 1 )* stride.toInt
606
- val split = new Split (featureIndex,featureSamples(sampleIndex),Continuous )
631
+ val split = new Split (featureIndex,featureSamples(sampleIndex),Continuous , List () )
607
632
splits(featureIndex)(index) = split
608
633
}
609
634
} else {
610
635
val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
611
- for (index <- 0 until maxFeatureValue){
612
- // TODO: Sort by centriod
613
- val split = new Split (featureIndex,index,Categorical )
614
- splits(featureIndex)(index) = split
636
+
637
+ require(maxFeatureValue < numBins, " number of categories should be less than number of bins" )
638
+
639
+ val centriodForCategories
640
+ = sampledInput.map(lp => (lp.features(featureIndex),lp.label))
641
+ .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length)
642
+
643
+ // Checking for missing categorical variables
644
+ val fullCentriodForCategories = scala.collection.mutable.Map [Double ,Double ]()
645
+ for (i <- 0 until maxFeatureValue){
646
+ if (centriodForCategories.contains(i)){
647
+ fullCentriodForCategories(i) = centriodForCategories(i)
648
+ } else {
649
+ fullCentriodForCategories(i) = Double .MaxValue
650
+ }
651
+ }
652
+
653
+ val categoriesSortedByCentriod
654
+ = fullCentriodForCategories.toList sortBy {_._2}
655
+
656
+ logDebug(" centriod for categorical variable = " + categoriesSortedByCentriod)
657
+
658
+ var categoriesForSplit = List [Double ]()
659
+ categoriesSortedByCentriod.iterator.zipWithIndex foreach {
660
+ case ((key, value), index) => {
661
+ categoriesForSplit = key :: categoriesForSplit
662
+ splits(featureIndex)(index) = new Split (featureIndex,Double .MinValue ,Categorical ,categoriesForSplit)
663
+ bins(featureIndex)(index) = {
664
+ if (index == 0 ) {
665
+ new Bin (new DummyCategoricalSplit (featureIndex,Categorical ),splits(featureIndex)(0 ),Categorical ,key)
666
+ }
667
+ else {
668
+ new Bin (splits(featureIndex)(index- 1 ),splits(featureIndex)(index),Categorical ,key)
669
+ }
670
+ }
671
+ }
615
672
}
616
673
}
617
674
}
618
675
619
676
// Find all bins
620
677
for (featureIndex <- 0 until numFeatures){
621
- bins(featureIndex)(0 )
622
- = new Bin (new DummyLowSplit (Continuous ),splits(featureIndex)(0 ),Continuous )
623
- for (index <- 1 until numBins - 1 ){
624
- val bin = new Bin (splits(featureIndex)(index- 1 ),splits(featureIndex)(index),Continuous )
625
- bins(featureIndex)(index) = bin
678
+ val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
679
+ if (isFeatureContinous) { // bins for categorical variables are already assigned
680
+ bins(featureIndex)(0 )
681
+ = new Bin (new DummyLowSplit (featureIndex, Continuous ),splits(featureIndex)(0 ),Continuous ,Double .MinValue )
682
+ for (index <- 1 until numBins - 1 ){
683
+ val bin = new Bin (splits(featureIndex)(index- 1 ),splits(featureIndex)(index),Continuous ,Double .MinValue )
684
+ bins(featureIndex)(index) = bin
685
+ }
686
+ bins(featureIndex)(numBins- 1 )
687
+ = new Bin (splits(featureIndex)(numBins- 2 ),new DummyHighSplit (featureIndex, Continuous ),Continuous ,Double .MinValue )
688
+ } else {
689
+ val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
690
+ for (i <- maxFeatureValue until numBins){
691
+ bins(featureIndex)(i)
692
+ = new Bin (new DummyCategoricalSplit (featureIndex,Categorical ),new DummyCategoricalSplit (featureIndex,Categorical ),Categorical ,Double .MaxValue )
693
+ }
626
694
}
627
- bins(featureIndex)(numBins- 1 )
628
- = new Bin (splits(featureIndex)(numBins- 2 ),new DummyHighSplit (Continuous ),Continuous )
629
695
}
630
-
631
696
(splits,bins)
632
697
}
633
698
case MinMax => {
0 commit comments