@@ -152,17 +152,15 @@ object DecisionTree extends Serializable with Logging {
152
152
// Find the number of features by looking at the first sample
153
153
val numFeatures = input.take(1 )(0 ).features.length
154
154
logDebug(" numFeatures = " + numFeatures)
155
- val numSplits = strategy.numBins
156
- logDebug(" numSplits = " + numSplits )
155
+ val numBins = strategy.numBins
156
+ logDebug(" numBins = " + numBins )
157
157
158
158
/*
 Find the filters used before reaching the current node.

 At level 0 there are no parent filters, so an empty list is returned. Otherwise the
 entry of `filters` at position 2^level - 1 + nodeIndex is returned. `level` and
 `filters` are not parameters; they are read from outside this function.
*/
159
159
def findParentFilters (nodeIndex : Int ): List [Filter ] = {
160
160
if (level == 0 ) {
161
161
List [Filter ]()
162
162
} else {
163
163
val nodeFilterIndex = scala.math.pow(2 , level).toInt - 1 + nodeIndex
164
- // val parentFilterIndex = nodeFilterIndex / 2
165
- // TODO: Check left or right filter
166
164
filters(nodeFilterIndex)
167
165
}
168
166
}
@@ -204,9 +202,9 @@ object DecisionTree extends Serializable with Logging {
204
202
}
205
203
206
204
/* Finds the right bin for the given feature*/
207
- def findBin (featureIndex : Int , labeledPoint : LabeledPoint , isFeatureContinous : Boolean ) : Int = {
205
+ def findBin (featureIndex : Int , labeledPoint : LabeledPoint , isFeatureContinuous : Boolean ) : Int = {
208
206
209
- if (isFeatureContinous ){
207
+ if (isFeatureContinuous ){
210
208
// TODO: Do binary search
211
209
for (binIndex <- 0 until strategy.numBins) {
212
210
val bin = bins(featureIndex)(binIndex)
@@ -245,11 +243,11 @@ object DecisionTree extends Serializable with Logging {
245
243
// calculating bin index and label per feature per node
246
244
val arr = new Array [Double ](1 + (numFeatures * numNodes))
247
245
arr(0 ) = labeledPoint.label
248
- for (index <- 0 until numNodes) {
249
- val parentFilters = findParentFilters(index )
246
+ for (nodeIndex <- 0 until numNodes) {
247
+ val parentFilters = findParentFilters(nodeIndex )
250
248
// Find out whether the sample qualifies for the particular node
251
249
val sampleValid = isSampleValid(parentFilters, labeledPoint)
252
- val shift = 1 + numFeatures * index
250
+ val shift = 1 + numFeatures * nodeIndex
253
251
if (! sampleValid) {
254
252
// Add to invalid bin index -1
255
253
for (featureIndex <- 0 until numFeatures) {
@@ -274,11 +272,11 @@ object DecisionTree extends Serializable with Logging {
274
272
val isSampleValidForNode = if (arr(validSignalIndex) != - 1 ) true else false
275
273
if (isSampleValidForNode) {
276
274
val label = arr(0 )
277
- for (feature <- 0 until numFeatures) {
275
+ for (featureIndex <- 0 until numFeatures) {
278
276
val arrShift = 1 + numFeatures * node
279
- val aggShift = 2 * numSplits * numFeatures * node
280
- val arrIndex = arrShift + feature
281
- val aggIndex = aggShift + 2 * feature * numSplits + arr(arrIndex).toInt * 2
277
+ val aggShift = 2 * numBins * numFeatures * node
278
+ val arrIndex = arrShift + featureIndex
279
+ val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
282
280
label match {
283
281
case (0.0 ) => agg(aggIndex) = agg(aggIndex) + 1
284
282
case (1.0 ) => agg(aggIndex + 1 ) = agg(aggIndex + 1 ) + 1
@@ -296,9 +294,9 @@ object DecisionTree extends Serializable with Logging {
296
294
val label = arr(0 )
297
295
for (feature <- 0 until numFeatures) {
298
296
val arrShift = 1 + numFeatures * node
299
- val aggShift = 3 * numSplits * numFeatures * node
297
+ val aggShift = 3 * numBins * numFeatures * node
300
298
val arrIndex = arrShift + feature
301
- val aggIndex = aggShift + 3 * feature * numSplits + arr(arrIndex).toInt * 3
299
+ val aggIndex = aggShift + 3 * feature * numBins + arr(arrIndex).toInt * 3
302
300
// count, sum, sum^2
303
301
agg(aggIndex) = agg(aggIndex) + 1
304
302
agg(aggIndex + 1 ) = agg(aggIndex + 1 ) + label
@@ -318,7 +316,6 @@ object DecisionTree extends Serializable with Logging {
318
316
@return Array[Double] storing aggregate calculation of size 2*numBins*numFeatures*numNodes for classification
319
317
*/
320
318
def binSeqOp (agg : Array [Double ], arr : Array [Double ]) : Array [Double ] = {
321
- // TODO: Requires logic for regressions
322
319
strategy.algo match {
323
320
case Classification => classificationBinSeqOp(arr, agg)
324
321
// TODO: Implement this
@@ -327,10 +324,9 @@ object DecisionTree extends Serializable with Logging {
327
324
agg
328
325
}
329
326
330
- // TODO: This length is different for regression
331
327
val binAggregateLength = strategy.algo match {
332
- case Classification => 2 * numSplits * numFeatures * numNodes
333
- case Regression => 3 * numSplits * numFeatures * numNodes
328
+ case Classification => 2 * numBins * numFeatures * numNodes
329
+ case Regression => 3 * numBins * numFeatures * numNodes
334
330
}
335
331
logDebug(" binAggregateLength = " + binAggregateLength)
336
332
@@ -453,52 +449,52 @@ object DecisionTree extends Serializable with Logging {
453
449
strategy.algo match {
454
450
case Classification => {
455
451
456
- val leftNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numSplits - 1 ))
457
- val rightNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numSplits - 1 ))
452
+ val leftNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numBins - 1 ))
453
+ val rightNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numBins - 1 ))
458
454
for (featureIndex <- 0 until numFeatures) {
459
- val shift = 2 * featureIndex* numSplits
455
+ val shift = 2 * featureIndex* numBins
460
456
leftNodeAgg(featureIndex)(0 ) = binData(shift + 0 )
461
457
leftNodeAgg(featureIndex)(1 ) = binData(shift + 1 )
462
- rightNodeAgg(featureIndex)(2 * (numSplits - 2 )) = binData(shift + (2 * (numSplits - 1 )))
463
- rightNodeAgg(featureIndex)(2 * (numSplits - 2 ) + 1 ) = binData(shift + (2 * (numSplits - 1 )) + 1 )
464
- for (splitIndex <- 1 until numSplits - 1 ) {
458
+ rightNodeAgg(featureIndex)(2 * (numBins - 2 )) = binData(shift + (2 * (numBins - 1 )))
459
+ rightNodeAgg(featureIndex)(2 * (numBins - 2 ) + 1 ) = binData(shift + (2 * (numBins - 1 )) + 1 )
460
+ for (splitIndex <- 1 until numBins - 1 ) {
465
461
leftNodeAgg(featureIndex)(2 * splitIndex)
466
462
= binData(shift + 2 * splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 )
467
463
leftNodeAgg(featureIndex)(2 * splitIndex + 1 )
468
464
= binData(shift + 2 * splitIndex + 1 ) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1 )
469
- rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex))
470
- = binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex))
471
- rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1 )
472
- = binData(shift + (2 * (numSplits - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1 )
465
+ rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex))
466
+ = binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
467
+ rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1 )
468
+ = binData(shift + (2 * (numBins - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1 )
473
469
}
474
470
}
475
471
(leftNodeAgg, rightNodeAgg)
476
472
}
477
473
case Regression => {
478
474
479
- val leftNodeAgg = Array .ofDim[Double ](numFeatures, 3 * (numSplits - 1 ))
480
- val rightNodeAgg = Array .ofDim[Double ](numFeatures, 3 * (numSplits - 1 ))
475
+ val leftNodeAgg = Array .ofDim[Double ](numFeatures, 3 * (numBins - 1 ))
476
+ val rightNodeAgg = Array .ofDim[Double ](numFeatures, 3 * (numBins - 1 ))
481
477
for (featureIndex <- 0 until numFeatures) {
482
- val shift = 3 * featureIndex* numSplits
478
+ val shift = 3 * featureIndex* numBins
483
479
leftNodeAgg(featureIndex)(0 ) = binData(shift + 0 )
484
480
leftNodeAgg(featureIndex)(1 ) = binData(shift + 1 )
485
481
leftNodeAgg(featureIndex)(2 ) = binData(shift + 2 )
486
- rightNodeAgg(featureIndex)(3 * (numSplits - 2 )) = binData(shift + (3 * (numSplits - 1 )))
487
- rightNodeAgg(featureIndex)(3 * (numSplits - 2 ) + 1 ) = binData(shift + (3 * (numSplits - 1 )) + 1 )
488
- rightNodeAgg(featureIndex)(3 * (numSplits - 2 ) + 2 ) = binData(shift + (3 * (numSplits - 1 )) + 2 )
489
- for (splitIndex <- 1 until numSplits - 1 ) {
482
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 )) = binData(shift + (3 * (numBins - 1 )))
483
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 ) + 1 ) = binData(shift + (3 * (numBins - 1 )) + 1 )
484
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 ) + 2 ) = binData(shift + (3 * (numBins - 1 )) + 2 )
485
+ for (splitIndex <- 1 until numBins - 1 ) {
490
486
leftNodeAgg(featureIndex)(3 * splitIndex)
491
487
= binData(shift + 3 * splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 )
492
488
leftNodeAgg(featureIndex)(3 * splitIndex + 1 )
493
489
= binData(shift + 3 * splitIndex + 1 ) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1 )
494
490
leftNodeAgg(featureIndex)(3 * splitIndex + 2 )
495
491
= binData(shift + 3 * splitIndex + 2 ) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2 )
496
- rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex))
497
- = binData(shift + (3 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex))
498
- rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 1 )
499
- = binData(shift + (3 * (numSplits - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 1 )
500
- rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 2 )
501
- = binData(shift + (3 * (numSplits - 1 - splitIndex) + 2 )) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 2 )
492
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex))
493
+ = binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
494
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1 )
495
+ = binData(shift + (3 * (numBins - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1 )
496
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2 )
497
+ = binData(shift + (3 * (numBins - 1 - splitIndex) + 2 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2 )
502
498
}
503
499
}
504
500
(leftNodeAgg, rightNodeAgg)
@@ -509,10 +505,10 @@ object DecisionTree extends Serializable with Logging {
509
505
def calculateGainsForAllNodeSplits (leftNodeAgg : Array [Array [Double ]], rightNodeAgg : Array [Array [Double ]], nodeImpurity : Double )
510
506
: Array [Array [InformationGainStats ]] = {
511
507
512
- val gains = Array .ofDim[InformationGainStats ](numFeatures, numSplits - 1 )
508
+ val gains = Array .ofDim[InformationGainStats ](numFeatures, numBins - 1 )
513
509
514
510
for (featureIndex <- 0 until numFeatures) {
515
- for (index <- 0 until numSplits - 1 ) {
511
+ for (index <- 0 until numBins - 1 ) {
516
512
// logDebug("splitIndex = " + index)
517
513
gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
518
514
}
@@ -521,10 +517,10 @@ object DecisionTree extends Serializable with Logging {
521
517
}
522
518
523
519
/*
524
- Find the best split for a node given bin aggregate data
520
+ Find the best split for a node given bin aggregate data
525
521
526
- @param binData Array[Double] of size 2*numSplits*numFeatures
527
- */
522
+ @param binData Array[Double] of size 2*numBins*numFeatures
523
+ */
528
524
def binsToBestSplit (binData : Array [Double ], nodeImpurity : Double ) : (Split , InformationGainStats ) = {
529
525
logDebug(" node impurity = " + nodeImpurity)
530
526
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
@@ -536,7 +532,7 @@ object DecisionTree extends Serializable with Logging {
536
532
// Initialization with infeasible values
537
533
var bestGainStats = new InformationGainStats (Double .MinValue ,- 1.0 ,- 1.0 ,- 1.0 ,- 1 )
538
534
for (featureIndex <- 0 until numFeatures) {
539
- for (splitIndex <- 0 until numSplits - 1 ){
535
+ for (splitIndex <- 0 until numBins - 1 ){
540
536
val gainStats = gains(featureIndex)(splitIndex)
541
537
if (gainStats.gain > bestGainStats.gain) {
542
538
bestGainStats = gainStats
@@ -556,13 +552,13 @@ object DecisionTree extends Serializable with Logging {
556
552
def getBinDataForNode (node : Int ): Array [Double ] = {
557
553
strategy.algo match {
558
554
case Classification => {
559
- val shift = 2 * node * numSplits * numFeatures
560
- val binsForNode = binAggregates.slice(shift, shift + 2 * numSplits * numFeatures)
555
+ val shift = 2 * node * numBins * numFeatures
556
+ val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
561
557
binsForNode
562
558
}
563
559
case Regression => {
564
- val shift = 3 * node * numSplits * numFeatures
565
- val binsForNode = binAggregates.slice(shift, shift + 3 * numSplits * numFeatures)
560
+ val shift = 3 * node * numBins * numFeatures
561
+ val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
566
562
binsForNode
567
563
}
568
564
}
0 commit comments