Skip to content

Commit 42c192a

Browse files
committed
Merge branch 'rfs' into dt-opt3alt
2 parents d3cc46b + 00e4404 commit 42c192a

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,30 +55,40 @@ private[tree] object TreePoint {
5555
input: RDD[LabeledPoint],
5656
bins: Array[Array[Bin]],
5757
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
58+
// Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
59+
val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
60+
val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures)
61+
var featureIndex = 0
62+
while (featureIndex < metadata.numFeatures) {
63+
featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
64+
isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
65+
featureIndex += 1
66+
}
5867
input.map { x =>
59-
TreePoint.labeledPointToTreePoint(x, bins, metadata)
68+
TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered)
6069
}
6170
}
6271

6372
/**
6473
* Convert one LabeledPoint into its TreePoint representation.
6574
* @param bins Bins for features, of size (numFeatures, numBins).
66-
* @param metadata DecisionTree training info, used for dataset metadata.
75+
* @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
76+
* for categorical features.
77+
* @param isUnordered Array indexed by feature, with value true for unordered categorical features.
6778
*/
6879
private def labeledPointToTreePoint(
6980
labeledPoint: LabeledPoint,
7081
bins: Array[Array[Bin]],
71-
metadata: DecisionTreeMetadata): TreePoint = {
82+
featureArity: Array[Int],
83+
isUnordered: Array[Boolean]): TreePoint = {
7284
val numFeatures = labeledPoint.features.size
7385
val arr = new Array[Int](numFeatures)
7486
var featureIndex = 0
7587
while (featureIndex < numFeatures) {
76-
val featureArity = metadata.featureArity.getOrElse(featureIndex, 0)
77-
arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity,
78-
metadata.isUnordered(featureIndex), bins)
88+
arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
89+
isUnordered(featureIndex), bins)
7990
featureIndex += 1
8091
}
81-
8292
new TreePoint(labeledPoint.label, arr)
8393
}
8494

0 commit comments

Comments
 (0)