@@ -55,30 +55,40 @@ private[tree] object TreePoint {
55
55
input : RDD [LabeledPoint ],
56
56
bins : Array [Array [Bin ]],
57
57
metadata : DecisionTreeMetadata ): RDD [TreePoint ] = {
58
+ // Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
59
+ val featureArity : Array [Int ] = new Array [Int ](metadata.numFeatures)
60
+ val isUnordered : Array [Boolean ] = new Array [Boolean ](metadata.numFeatures)
61
+ var featureIndex = 0
62
+ while (featureIndex < metadata.numFeatures) {
63
+ featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0 )
64
+ isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
65
+ featureIndex += 1
66
+ }
58
67
input.map { x =>
59
- TreePoint .labeledPointToTreePoint(x, bins, metadata )
68
+ TreePoint .labeledPointToTreePoint(x, bins, featureArity, isUnordered )
60
69
}
61
70
}
62
71
63
72
/**
64
73
* Convert one LabeledPoint into its TreePoint representation.
65
74
* @param bins Bins for features, of size (numFeatures, numBins).
66
- * @param metadata DecisionTree training info, used for dataset metadata.
75
+ * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
76
+ * for categorical features.
77
+ * @param isUnordered Array indexed by feature, with value true for unordered categorical features.
67
78
*/
68
79
private def labeledPointToTreePoint(
    labeledPoint: LabeledPoint,
    bins: Array[Array[Bin]],
    featureArity: Array[Int],
    isUnordered: Array[Boolean]): TreePoint = {
  // One binned value per feature of this point.
  val binnedFeatures = new Array[Int](labeledPoint.features.size)
  // Plain while loop over arrays: this method is invoked once per input point,
  // so per-feature lookups are simple array reads.
  var idx = 0
  while (idx < binnedFeatures.length) {
    binnedFeatures(idx) =
      findBin(idx, labeledPoint, featureArity(idx), isUnordered(idx), bins)
    idx += 1
  }
  new TreePoint(labeledPoint.label, binnedFeatures)
}
84
94
0 commit comments