Skip to content

Commit b0e3e76

Browse files
committed
adding enum for feature type
Signed-off-by: Manish Amde <[email protected]>
1 parent 154aa77 commit b0e3e76

File tree

5 files changed

+40
-31
lines changed

5 files changed

+40
-31
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.tree.model.Split
2626
import scala.util.control.Breaks._
2727
import org.apache.spark.mllib.tree.configuration.Strategy
2828
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
29+
import org.apache.spark.mllib.tree.configuration.FeatureType._
2930

3031

3132
class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
@@ -353,21 +354,13 @@ object DecisionTree extends Serializable with Logging {
353354
def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
354355
val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
355356
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
356-
//logDebug("binData.length = " + binData.length)
357-
//logDebug("binData.sum = " + binData.sum)
358357
for (featureIndex <- 0 until numFeatures) {
359-
//logDebug("featureIndex = " + featureIndex)
360358
val shift = 2*featureIndex*numSplits
361359
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
362-
//logDebug("binData(shift + 0) = " + binData(shift + 0))
363360
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
364-
//logDebug("binData(shift + 1) = " + binData(shift + 1))
365361
rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1)))
366-
//logDebug(binData(shift + (2 * (numSplits - 1))))
367362
rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1)
368-
//logDebug(binData(shift + (2 * (numSplits - 1)) + 1))
369363
for (splitIndex <- 1 until numSplits - 1) {
370-
//logDebug("splitIndex = " + splitIndex)
371364
leftNodeAgg(featureIndex)(2 * splitIndex)
372365
= binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2)
373366
leftNodeAgg(featureIndex)(2 * splitIndex + 1)
@@ -479,33 +472,43 @@ object DecisionTree extends Serializable with Logging {
479472

480473
//Find all splits
481474
for (featureIndex <- 0 until numFeatures){
482-
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
483-
484-
val stride : Double = numSamples.toDouble/numBins
485-
logDebug("stride = " + stride)
486-
for (index <- 0 until numBins-1) {
487-
val sampleIndex = (index+1)*stride.toInt
488-
val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")
489-
splits(featureIndex)(index) = split
475+
val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
476+
if (isFeatureContinous) {
477+
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
478+
479+
val stride : Double = numSamples.toDouble/numBins
480+
logDebug("stride = " + stride)
481+
for (index <- 0 until numBins-1) {
482+
val sampleIndex = (index+1)*stride.toInt
483+
val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous)
484+
splits(featureIndex)(index) = split
485+
}
486+
} else {
487+
val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
488+
for (index <- 0 until maxFeatureValue){
489+
//TODO: Sort by centroid
490+
val split = new Split(featureIndex,index,Categorical)
491+
splits(featureIndex)(index) = split
492+
}
490493
}
491494
}
492495

493496
//Find all bins
494497
for (featureIndex <- 0 until numFeatures){
495498
bins(featureIndex)(0)
496-
= new Bin(new DummyLowSplit("continuous"),splits(featureIndex)(0),"continuous")
499+
= new Bin(new DummyLowSplit(Continuous),splits(featureIndex)(0),Continuous)
497500
for (index <- 1 until numBins - 1){
498-
val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),"continuous")
501+
val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous)
499502
bins(featureIndex)(index) = bin
500503
}
501504
bins(featureIndex)(numBins-1)
502-
= new Bin(splits(featureIndex)(numBins-3),new DummyHighSplit("continuous"),"continuous")
505+
= new Bin(splits(featureIndex)(numBins-3),new DummyHighSplit(Continuous),Continuous)
503506
}
504507

505508
(splits,bins)
506509
}
507510
case MinMax => {
508-
(Array.ofDim[Split](numFeatures,numBins),Array.ofDim[Bin](numFeatures,numBins+2))
511+
throw new UnsupportedOperationException("minmax not supported yet.")
509512
}
510513
case ApproxHist => {
511514
throw new UnsupportedOperationException("approximate histogram not supported yet.")

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ class Strategy (
2525
val impurity : Impurity,
2626
val maxDepth : Int,
2727
val maxBins : Int,
28-
val quantileCalculationStrategy : QuantileStrategy = Sort) extends Serializable {
28+
val quantileCalculationStrategy : QuantileStrategy = Sort,
29+
val categoricalFeaturesInfo : Map[Int,Int] = Map[Int,Int]()) extends Serializable {
2930

3031
var numBins : Int = Int.MinValue
3132

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
*/
1717
package org.apache.spark.mllib.tree.model
1818

19-
case class Bin(lowSplit : Split, highSplit : Split, kind : String) {
19+
import org.apache.spark.mllib.tree.configuration.FeatureType._
20+
21+
case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType) {
2022

2123
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
*/
1717
package org.apache.spark.mllib.tree.model
1818

19-
case class Split(feature: Int, threshold : Double, kind : String){
20-
override def toString = "Feature = " + feature + ", threshold = " + threshold + ", kind = " + kind
19+
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
20+
21+
case class Split(feature: Int, threshold : Double, featureType : FeatureType){
22+
override def toString = "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType
2123
}
2224

23-
class DummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind)
25+
class DummyLowSplit(kind : FeatureType) extends Split(Int.MinValue, Double.MinValue, kind)
2426

25-
class DummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind)
27+
class DummyHighSplit(kind : FeatureType) extends Split(Int.MaxValue, Double.MaxValue, kind)
2628

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
3030
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
3131
import org.apache.spark.mllib.tree.model.Filter
3232
import org.apache.spark.mllib.tree.configuration.Strategy
33+
import org.apache.spark.mllib.tree.configuration.Algo._
3334

3435
class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
3536

@@ -48,7 +49,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
4849
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
4950
assert(arr.length == 1000)
5051
val rdd = sc.parallelize(arr)
51-
val strategy = new Strategy("regression",Gini,3,100,"sort")
52+
val strategy = new Strategy(Regression,Gini,3,100)
5253
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
5354
assert(splits.length==2)
5455
assert(bins.length==2)
@@ -61,7 +62,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
6162
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
6263
assert(arr.length == 1000)
6364
val rdd = sc.parallelize(arr)
64-
val strategy = new Strategy("regression",Gini,3,100,"sort")
65+
val strategy = new Strategy(Regression,Gini,3,100)
6566
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
6667
assert(splits.length==2)
6768
assert(splits(0).length==99)
@@ -87,7 +88,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
8788
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
8889
assert(arr.length == 1000)
8990
val rdd = sc.parallelize(arr)
90-
val strategy = new Strategy("regression",Gini,3,100,"sort")
91+
val strategy = new Strategy(Regression,Gini,3,100)
9192
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
9293
assert(splits.length==2)
9394
assert(splits(0).length==99)
@@ -113,7 +114,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
113114
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
114115
assert(arr.length == 1000)
115116
val rdd = sc.parallelize(arr)
116-
val strategy = new Strategy("regression",Entropy,3,100,"sort")
117+
val strategy = new Strategy(Regression,Entropy,3,100)
117118
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
118119
assert(splits.length==2)
119120
assert(splits(0).length==99)
@@ -138,7 +139,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
138139
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
139140
assert(arr.length == 1000)
140141
val rdd = sc.parallelize(arr)
141-
val strategy = new Strategy("regression",Entropy,3,100,"sort")
142+
val strategy = new Strategy(Regression,Entropy,3,100)
142143
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
143144
assert(splits.length==2)
144145
assert(splits(0).length==99)

0 commit comments

Comments
 (0)