@@ -204,15 +204,12 @@ object DecisionTree extends Serializable with Logging {
204
204
}
205
205
206
206
/* Finds the right bin for the given feature*/
207
- def findBin (featureIndex : Int , labeledPoint : LabeledPoint ) : Int = {
208
- // logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
207
+ def findBin (featureIndex : Int , labeledPoint : LabeledPoint , isFeatureContinous : Boolean ) : Int = {
209
208
210
- val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
211
209
if (isFeatureContinous){
212
210
// TODO: Do binary search
213
211
for (binIndex <- 0 until strategy.numBins) {
214
212
val bin = bins(featureIndex)(binIndex)
215
- // TODO: Remove this requirement post basic functional
216
213
val lowThreshold = bin.lowSplit.threshold
217
214
val highThreshold = bin.highSplit.threshold
218
215
val features = labeledPoint.features
@@ -222,9 +219,9 @@ object DecisionTree extends Serializable with Logging {
222
219
}
223
220
throw new UnknownError (" no bin was found for continuous variable." )
224
221
} else {
222
+
225
223
for (binIndex <- 0 until strategy.numBins) {
226
224
val bin = bins(featureIndex)(binIndex)
227
- // TODO: Remove this requirement post basic functional
228
225
val category = bin.category
229
226
val features = labeledPoint.features
230
227
if (category == features(featureIndex)) {
@@ -262,7 +259,8 @@ object DecisionTree extends Serializable with Logging {
262
259
} else {
263
260
for (featureIndex <- 0 until numFeatures) {
264
261
// logDebug("shift+featureIndex =" + (shift+featureIndex))
265
- arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
262
+ val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
263
+ arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinous)
266
264
}
267
265
}
268
266
0 commit comments