@@ -26,6 +26,7 @@ import org.apache.spark.mllib.tree.model.Split
26
26
import scala .util .control .Breaks ._
27
27
import org .apache .spark .mllib .tree .configuration .Strategy
28
28
import org .apache .spark .mllib .tree .configuration .QuantileStrategy ._
29
+ import org .apache .spark .mllib .tree .configuration .FeatureType ._
29
30
30
31
31
32
class DecisionTree (val strategy : Strategy ) extends Serializable with Logging {
@@ -353,21 +354,13 @@ object DecisionTree extends Serializable with Logging {
353
354
def extractLeftRightNodeAggregates (binData : Array [Double ]): (Array [Array [Double ]], Array [Array [Double ]]) = {
354
355
val leftNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numSplits - 1 ))
355
356
val rightNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numSplits - 1 ))
356
- // logDebug("binData.length = " + binData.length)
357
- // logDebug("binData.sum = " + binData.sum)
358
357
for (featureIndex <- 0 until numFeatures) {
359
- // logDebug("featureIndex = " + featureIndex)
360
358
val shift = 2 * featureIndex* numSplits
361
359
leftNodeAgg(featureIndex)(0 ) = binData(shift + 0 )
362
- // logDebug("binData(shift + 0) = " + binData(shift + 0))
363
360
leftNodeAgg(featureIndex)(1 ) = binData(shift + 1 )
364
- // logDebug("binData(shift + 1) = " + binData(shift + 1))
365
361
rightNodeAgg(featureIndex)(2 * (numSplits - 2 )) = binData(shift + (2 * (numSplits - 1 )))
366
- // logDebug(binData(shift + (2 * (numSplits - 1))))
367
362
rightNodeAgg(featureIndex)(2 * (numSplits - 2 ) + 1 ) = binData(shift + (2 * (numSplits - 1 )) + 1 )
368
- // logDebug(binData(shift + (2 * (numSplits - 1)) + 1))
369
363
for (splitIndex <- 1 until numSplits - 1 ) {
370
- // logDebug("splitIndex = " + splitIndex)
371
364
leftNodeAgg(featureIndex)(2 * splitIndex)
372
365
= binData(shift + 2 * splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 )
373
366
leftNodeAgg(featureIndex)(2 * splitIndex + 1 )
@@ -479,33 +472,43 @@ object DecisionTree extends Serializable with Logging {
479
472
480
473
// Find all splits
481
474
for (featureIndex <- 0 until numFeatures){
482
- val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
483
-
484
- val stride : Double = numSamples.toDouble/ numBins
485
- logDebug(" stride = " + stride)
486
- for (index <- 0 until numBins- 1 ) {
487
- val sampleIndex = (index+ 1 )* stride.toInt
488
- val split = new Split (featureIndex,featureSamples(sampleIndex)," continuous" )
489
- splits(featureIndex)(index) = split
475
+ val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
476
+ if (isFeatureContinous) {
477
+ val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
478
+
479
+ val stride : Double = numSamples.toDouble/ numBins
480
+ logDebug(" stride = " + stride)
481
+ for (index <- 0 until numBins- 1 ) {
482
+ val sampleIndex = (index+ 1 )* stride.toInt
483
+ val split = new Split (featureIndex,featureSamples(sampleIndex),Continuous )
484
+ splits(featureIndex)(index) = split
485
+ }
486
+ } else {
487
+ val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
488
+ for (index <- 0 until maxFeatureValue){
489
+ // TODO: Sort by centroid
490
+ val split = new Split (featureIndex,index,Categorical )
491
+ splits(featureIndex)(index) = split
492
+ }
490
493
}
491
494
}
492
495
493
496
// Find all bins
494
497
for (featureIndex <- 0 until numFeatures){
495
498
bins(featureIndex)(0 )
496
- = new Bin (new DummyLowSplit (" continuous " ),splits(featureIndex)(0 )," continuous " )
499
+ = new Bin (new DummyLowSplit (Continuous ),splits(featureIndex)(0 ),Continuous )
497
500
for (index <- 1 until numBins - 1 ){
498
- val bin = new Bin (splits(featureIndex)(index- 1 ),splits(featureIndex)(index)," continuous " )
501
+ val bin = new Bin (splits(featureIndex)(index- 1 ),splits(featureIndex)(index),Continuous )
499
502
bins(featureIndex)(index) = bin
500
503
}
501
504
bins(featureIndex)(numBins- 1 )
502
- = new Bin (splits(featureIndex)(numBins- 3 ),new DummyHighSplit (" continuous " ), " continuous " )
505
+ = new Bin (splits(featureIndex)(numBins- 3 ),new DummyHighSplit (Continuous ), Continuous )
503
506
}
504
507
505
508
(splits,bins)
506
509
}
507
510
case MinMax => {
508
- ( Array .ofDim[ Split ](numFeatures,numBins), Array .ofDim[ Bin ](numFeatures,numBins + 2 ) )
511
+ throw new UnsupportedOperationException ( " minmax not supported yet. " )
509
512
}
510
513
case ApproxHist => {
511
514
throw new UnsupportedOperationException (" approximate histogram not supported yet." )
0 commit comments