
Commit dbb7ac1

categorical feature support
Signed-off-by: Manish Amde <[email protected]>
1 parent 6df35b9 commit dbb7ac1

File tree: 5 files changed (+185, -53 lines)
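
Note: the DecisionTree changes below read strategy.categoricalFeaturesInfo as a map from feature index to that feature's number of categories (findSplitsBins requires the count to stay below numBins). A minimal, hypothetical sketch of that map — the Strategy constructor itself is not part of this diff, so only the map is shown:

```scala
// Hypothetical example: features 2 and 5 are categorical with 4 and 10 categories;
// any feature absent from the map is treated as continuous.
val categoricalFeaturesInfo: Map[Int, Int] = Map(2 -> 4, 5 -> 10)
```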

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 96 additions & 31 deletions
@@ -37,7 +37,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
     //Cache input RDD for speedup during multiple passes
     input.cache()
 
-    val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
     logDebug("numSplits = " + bins(0).length)
     strategy.numBins = bins(0).length
 
@@ -54,8 +54,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
 
     logDebug("algo = " + strategy.algo)
 
-
-
     breakable {
       for (level <- 0 until maxDepth){
 
@@ -185,10 +183,21 @@ object DecisionTree extends Serializable with Logging {
         val featureIndex = filter.split.feature
         val threshold = filter.split.threshold
         val comparison = filter.comparison
-        comparison match {
-          case(-1) => if (features(featureIndex) > threshold) return false
-          case(0) => if (features(featureIndex) != threshold) return false
-          case(1) => if (features(featureIndex) <= threshold) return false
+        val categories = filter.split.categories
+        val isFeatureContinuous = filter.split.featureType == Continuous
+        val feature = features(featureIndex)
+        if (isFeatureContinuous){
+          comparison match {
+            case(-1) => if (feature > threshold) return false
+            case(1) => if (feature <= threshold) return false
+          }
+        } else {
+          val containsFeature = categories.contains(feature)
+          comparison match {
+            case(-1) => if (!containsFeature) return false
+            case(1) => if (containsFeature) return false
+          }
+
         }
       }
       true
@@ -197,18 +206,34 @@ object DecisionTree extends Serializable with Logging {
     /*Finds the right bin for the given feature*/
     def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
       //logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
-      //TODO: Do binary search
-      for (binIndex <- 0 until strategy.numBins) {
-        val bin = bins(featureIndex)(binIndex)
-        //TODO: Remove this requirement post basic functional
-        val lowThreshold = bin.lowSplit.threshold
-        val highThreshold = bin.highSplit.threshold
-        val features = labeledPoint.features
-        if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
-          return binIndex
+
+      val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+      if (isFeatureContinous){
+        //TODO: Do binary search
+        for (binIndex <- 0 until strategy.numBins) {
+          val bin = bins(featureIndex)(binIndex)
+          //TODO: Remove this requirement post basic functional
+          val lowThreshold = bin.lowSplit.threshold
+          val highThreshold = bin.highSplit.threshold
+          val features = labeledPoint.features
+          if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
+            return binIndex
+          }
+        }
+        throw new UnknownError("no bin was found for continuous variable.")
+      } else {
+        for (binIndex <- 0 until strategy.numBins) {
+          val bin = bins(featureIndex)(binIndex)
+          //TODO: Remove this requirement post basic functional
+          val category = bin.category
+          val features = labeledPoint.features
+          if (category == features(featureIndex)) {
+            return binIndex
+          }
         }
+        throw new UnknownError("no bin was found for categorical variable.")
+
       }
-      throw new UnknownError("no bin was found.")
 
     }
 
@@ -565,7 +590,7 @@ object DecisionTree extends Serializable with Logging {
      @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an
              Array[Array[Bin]] of size (numFeatures,numSplits1)
      */
-    def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
+    def findSplitsBins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
 
       val count = input.count()
 
@@ -603,31 +628,71 @@ object DecisionTree extends Serializable with Logging {
             logDebug("stride = " + stride)
             for (index <- 0 until numBins-1) {
               val sampleIndex = (index+1)*stride.toInt
-              val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous)
+              val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List())
               splits(featureIndex)(index) = split
             }
           } else {
             val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
-            for (index <- 0 until maxFeatureValue){
-              //TODO: Sort by centriod
-              val split = new Split(featureIndex,index,Categorical)
-              splits(featureIndex)(index) = split
+
+            require(maxFeatureValue < numBins, "number of categories should be less than number of bins")
+
+            val centriodForCategories
+              = sampledInput.map(lp => (lp.features(featureIndex),lp.label))
+                .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+
+            //Checking for missing categorical variables
+            val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]()
+            for (i <- 0 until maxFeatureValue){
+              if (centriodForCategories.contains(i)){
+                fullCentriodForCategories(i) = centriodForCategories(i)
+              } else {
+                fullCentriodForCategories(i) = Double.MaxValue
+              }
+            }
+
+            val categoriesSortedByCentriod
+              = fullCentriodForCategories.toList sortBy {_._2}
+
+            logDebug("centriod for categorical variable = " + categoriesSortedByCentriod)
+
+            var categoriesForSplit = List[Double]()
+            categoriesSortedByCentriod.iterator.zipWithIndex foreach {
+              case((key, value), index) => {
+                categoriesForSplit = key :: categoriesForSplit
+                splits(featureIndex)(index) = new Split(featureIndex,Double.MinValue,Categorical,categoriesForSplit)
+                bins(featureIndex)(index) = {
+                  if(index == 0) {
+                    new Bin(new DummyCategoricalSplit(featureIndex,Categorical),splits(featureIndex)(0),Categorical,key)
+                  }
+                  else {
+                    new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Categorical,key)
+                  }
+                }
+              }
             }
           }
         }
 
         //Find all bins
         for (featureIndex <- 0 until numFeatures){
-          bins(featureIndex)(0)
-            = new Bin(new DummyLowSplit(Continuous),splits(featureIndex)(0),Continuous)
-          for (index <- 1 until numBins - 1){
-            val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous)
-            bins(featureIndex)(index) = bin
+          val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          if (isFeatureContinous) { //bins for categorical variables are already assigned
+            bins(featureIndex)(0)
+              = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0),Continuous,Double.MinValue)
+            for (index <- 1 until numBins - 1){
+              val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous,Double.MinValue)
+              bins(featureIndex)(index) = bin
+            }
+            bins(featureIndex)(numBins-1)
+              = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, Continuous),Continuous,Double.MinValue)
+          } else {
+            val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
+            for (i <- maxFeatureValue until numBins){
+              bins(featureIndex)(i)
+                = new Bin(new DummyCategoricalSplit(featureIndex,Categorical),new DummyCategoricalSplit(featureIndex,Categorical),Categorical,Double.MaxValue)
+            }
           }
-          bins(featureIndex)(numBins-1)
-            = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(Continuous),Continuous)
         }
-
         (splits,bins)
       }
       case MinMax => {
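
For reference, a standalone sketch of the centroid ordering that findSplitsBins now applies to categorical features: group samples by category, average the labels, pad missing categories with Double.MaxValue, and sort categories by that average before forming splits and bins. This is not the Spark code above; the object name and sample data are made up for illustration.

```scala
object CentroidOrderingSketch extends App {
  // Hypothetical (featureValue, label) pairs for a single categorical feature.
  val samples = Seq((0.0, 1.0), (0.0, 1.0), (1.0, 0.0), (2.0, 1.0), (2.0, 0.0))
  val numCategories = 3

  // Average label per category observed in the sample.
  val centroidForCategories: Map[Double, Double] =
    samples.groupBy(_._1).mapValues(xs => xs.map(_._2).sum / xs.length).toMap

  // Categories missing from the sample sort last, mirroring the Double.MaxValue
  // padding in the commit.
  val fullCentroids = (0 until numCategories).map { i =>
    i.toDouble -> centroidForCategories.getOrElse(i.toDouble, Double.MaxValue)
  }

  val categoriesSortedByCentroid = fullCentroids.sortBy(_._2).map(_._1)
  println(categoriesSortedByCentroid) // Vector(1.0, 2.0, 0.0)
}
```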

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala

Lines changed: 1 addition & 1 deletion
@@ -18,6 +18,6 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 
-case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType) {
+case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) {
 
 }
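
With the extra category field added above, a categorical bin built the way findSplitsBins does it might look like the following hypothetical snippet (feature index and category values are invented for illustration):

```scala
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.model.{Bin, DummyCategoricalSplit, Split}

// First bin of feature 3: bounded below by a dummy categorical split and above by
// the split covering the single best-centroid category 2.0, keyed by that category.
val firstSplit = new Split(3, Double.MinValue, Categorical, List(2.0))
val firstBin = new Bin(new DummyCategoricalSplit(3, Categorical), firstSplit, Categorical, 2.0)
```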

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 12 additions & 3 deletions
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.FeatureType._
 
 class Node ( val id : Int,
              val predict : Double,
@@ -49,10 +50,18 @@
     if (isLeaf) {
       predict
     } else{
-      if (feature(split.get.feature) <= split.get.threshold) {
-        leftNode.get.predictIfLeaf(feature)
+      if (split.get.featureType == Continuous) {
+        if (feature(split.get.feature) <= split.get.threshold) {
+          leftNode.get.predictIfLeaf(feature)
+        } else {
+          rightNode.get.predictIfLeaf(feature)
+        }
       } else {
-        rightNode.get.predictIfLeaf(feature)
+        if (split.get.categories.contains(feature(split.get.feature))) {
+          leftNode.get.predictIfLeaf(feature)
+        } else {
+          rightNode.get.predictIfLeaf(feature)
+        }
       }
     }
   }
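
The net effect of the predictIfLeaf change is a two-way routing rule. Here is a minimal standalone sketch (a hypothetical helper, not part of the commit) of that rule and how it behaves:

```scala
// Continuous splits route on the threshold; categorical splits route on membership
// in the split's category list (left child when the category is in the list).
def goesLeft(featureValue: Double, isContinuous: Boolean,
             threshold: Double, categories: List[Double]): Boolean =
  if (isContinuous) featureValue <= threshold
  else categories.contains(featureValue)

assert(goesLeft(2.0, isContinuous = true, threshold = 3.0, categories = Nil))
assert(goesLeft(2.0, isContinuous = false, threshold = 0.0, categories = List(0.0, 2.0)))
assert(!goesLeft(1.0, isContinuous = false, threshold = 0.0, categories = List(0.0, 2.0)))
```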

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala

Lines changed: 7 additions & 4 deletions
@@ -18,11 +18,14 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
 
-case class Split(feature: Int, threshold : Double, featureType : FeatureType){
-  override def toString = "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType
+case class Split(feature: Int, threshold : Double, featureType : FeatureType, categories : List[Double]){
+  override def toString =
+    "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + ", categories = " + categories
 }
 
-class DummyLowSplit(kind : FeatureType) extends Split(Int.MinValue, Double.MinValue, kind)
+class DummyLowSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MinValue, kind, List())
 
-class DummyHighSplit(kind : FeatureType) extends Split(Int.MaxValue, Double.MaxValue, kind)
+class DummyHighSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MaxValue, kind, List())
+
+class DummyCategoricalSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MaxValue, kind, List())
 
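
As a quick illustration of the widened Split signature above (values are hypothetical): continuous splits keep their threshold and carry an empty category list, while categorical splits use a placeholder threshold and carry the categories routed to the left child, as findSplitsBins constructs them.

```scala
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.model.Split

val continuousSplit = new Split(0, 1.5, Continuous, List())
val categoricalSplit = new Split(2, Double.MinValue, Categorical, List(1.0, 3.0))
println(categoricalSplit) // prints the feature index, threshold, feature type, and category list
```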
