Skip to content

Commit d504eb1

Browse files
committed
more tests for categorical features
Signed-off-by: Manish Amde <[email protected]>
1 parent dbb7ac1 commit d504eb1

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,12 @@ object DecisionTree extends Serializable with Logging {
204204
}
205205

206206
/*Finds the right bin for the given feature*/
207-
def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
208-
//logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
207+
def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinous : Boolean) : Int = {
209208

210-
val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
211209
if (isFeatureContinous){
212210
//TODO: Do binary search
213211
for (binIndex <- 0 until strategy.numBins) {
214212
val bin = bins(featureIndex)(binIndex)
215-
//TODO: Remove this requirement post basic functional
216213
val lowThreshold = bin.lowSplit.threshold
217214
val highThreshold = bin.highSplit.threshold
218215
val features = labeledPoint.features
@@ -222,9 +219,9 @@ object DecisionTree extends Serializable with Logging {
222219
}
223220
throw new UnknownError("no bin was found for continuous variable.")
224221
} else {
222+
225223
for (binIndex <- 0 until strategy.numBins) {
226224
val bin = bins(featureIndex)(binIndex)
227-
//TODO: Remove this requirement post basic functional
228225
val category = bin.category
229226
val features = labeledPoint.features
230227
if (category == features(featureIndex)) {
@@ -262,7 +259,8 @@ object DecisionTree extends Serializable with Logging {
262259
} else {
263260
for (featureIndex <- 0 until numFeatures) {
264261
//logDebug("shift+featureIndex =" + (shift+featureIndex))
265-
arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
262+
val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
263+
arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinous)
266264
}
267265
}
268266

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
7575
println(splits(1)(0))
7676
println(splits(1)(1))
7777
println(bins(1)(0))
78+
//TODO: Add asserts
79+
7880
}
7981

8082
test("split and bin calculations for categorical variables with no sample for one category"){
@@ -100,12 +102,28 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
100102
println(bins(1)(1))
101103
println(bins(0)(2))
102104
println(bins(0)(3))
105+
//TODO: Add asserts
106+
103107
}
104108

105109
//TODO: Test max feature value > num bins
106110

107111

108-
test("stump with fixed label 0 for Gini"){
112+
// Exercises split/bin discovery and the best-split search on the data returned by
// DecisionTreeSuite.generateCategoricalDataPoints(), with features 0 and 1 registered in
// categoricalFeaturesInfo. The test currently only prints the first best-split entry; it makes
// no assertion about the computed split (see the TODO at the end).
test("stump with all categorical variables"){
113+
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
114+
// Sanity check on the fixture size before it is used.
assert(arr.length == 1000)
115+
val rdd = sc.parallelize(arr)
116+
// Features 0 and 1 are each mapped to 3 in categoricalFeaturesInfo.
// NOTE(review): presumably 3 is the number of categories per feature, and the positional
// 3 and 100 are max depth and max bins - confirm against Strategy's constructor.
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
117+
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
118+
// NOTE(review): numBins is assigned only after findSplitsBins has already run with this
// strategy - confirm that this ordering is intentional.
strategy.numBins = 100
119+
// NOTE(review): the meaning of the 7-element array, the literal 0 and the empty
// Array[List[Filter]] is not visible from this test - verify against findBestSplits' signature.
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
120+
// Print both components of the first returned entry for manual inspection.
println(bestSplits(0)._1)
121+
println(bestSplits(0)._2)
122+
//TODO: Add asserts
123+
}
124+
125+
126+
test("stump with fixed label 0 for Gini"){
109127
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
110128
assert(arr.length == 1000)
111129
val rdd = sc.parallelize(arr)

0 commit comments

Comments
 (0)