@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
28
28
import org .apache .spark .mllib .tree .configuration .QuantileStrategy ._
29
29
import org .apache .spark .mllib .tree .configuration .FeatureType ._
30
30
import org .apache .spark .mllib .tree .configuration .Algo ._
31
+ import org .apache .spark .mllib .tree .impurity .Impurity
31
32
32
33
/**
33
34
A class that implements a decision tree algorithm for classification and regression.
@@ -38,7 +39,7 @@ algorithm (classification,
38
39
regression, etc.), feature type (continuous, categorical), depth of the tree,
39
40
quantile calculation strategy, etc.
40
41
*/
41
- class DecisionTree (val strategy : Strategy ) extends Serializable with Logging {
42
+ class DecisionTree private (val strategy : Strategy ) extends Serializable with Logging {
42
43
43
44
/**
44
45
Method to train a decision tree model over an RDD
@@ -157,6 +158,70 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
157
158
158
159
object DecisionTree extends Serializable with Logging {
159
160
161
+ /**
162
+ Method to train a decision tree model over an RDD
163
+
164
+ @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
165
+ for DecisionTree
166
+ @param strategy The configuration parameters for the tree algorithm which specify the type of algorithm
167
+ (classification, regression, etc.), feature type (continuous, categorical),
168
+ depth of the tree, quantile calculation strategy, etc.
169
+ @return a DecisionTreeModel that can be used for prediction
170
+ */
171
+ def train (input : RDD [LabeledPoint ], strategy : Strategy ) : DecisionTreeModel = {
172
+ new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
173
+ }
174
+
175
+ /**
176
+ Method to train a decision tree model over an RDD
177
+
178
+ @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
179
+ for DecisionTree
180
+ @param algo classification or regression
181
+ @param impurity criterion used for information gain calculation
182
+ @param maxDepth maximum depth of the tree
183
+ @return a DecisionTreeModel that can be used for prediction
184
+ */
185
+ def train (
186
+ input : RDD [LabeledPoint ],
187
+ algo : Algo ,
188
+ impurity : Impurity ,
189
+ maxDepth : Int
190
+ ) : DecisionTreeModel = {
191
+ val strategy = new Strategy (algo,impurity,maxDepth)
192
+ new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
193
+ }
194
+
195
+
196
+ /**
197
+ Method to train a decision tree model over an RDD
198
+
199
+ @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
200
+ for DecisionTree
201
+ @param algo classification or regression
202
+ @param impurity criterion used for information gain calculation
203
+ @param maxDepth maximum depth of the tree
204
+ @param maxBins maximum number of bins used for splitting features
205
+ @param quantileCalculationStrategy algorithm for calculating quantiles
206
+ @param categoricalFeaturesInfo A map storing information about the categorical variables and the number of discrete
207
+ values they take. For example, an entry (n -> k) implies the feature n is
208
+ categorical with k categories 0, 1, 2, ... , k-1. It's important to note that
209
+ features are zero-indexed.
210
+ @return a DecisionTreeModel that can be used for prediction
211
+ */
212
+ def train (
213
+ input : RDD [LabeledPoint ],
214
+ algo : Algo ,
215
+ impurity : Impurity ,
216
+ maxDepth : Int ,
217
+ maxBins : Int ,
218
+ quantileCalculationStrategy : QuantileStrategy ,
219
+ categoricalFeaturesInfo : Map [Int ,Int ]
220
+ ) : DecisionTreeModel = {
221
+ val strategy = new Strategy (algo,impurity,maxDepth,maxBins,quantileCalculationStrategy,categoricalFeaturesInfo)
222
+ new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
223
+ }
224
+
160
225
/**
161
226
Returns an Array[Split] of optimal splits for all nodes at a given level
162
227
@@ -717,13 +782,13 @@ object DecisionTree extends Serializable with Logging {
717
782
for DecisionTree
718
783
@param strategy [[org.apache.spark.mllib.tree.configuration.Strategy ]] instance containing
719
784
parameters for construction the DecisionTree
720
- @return a tuple of (splits,bins) where Split is an Array[Array[ Split]] of size (numFeatures,
721
- numSplits-1) and bins is an
722
- Array[Array[Bin]] of size (numFeatures,numSplits1)
785
+ @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree.model. Split] of
786
+ size (numFeatures, numSplits-1) and bins is an Array of [org.apache.spark.mllib.tree.model.Bin] of
787
+ size (numFeatures,numSplits1)
723
788
*/
724
789
def findSplitsBins (
725
790
input : RDD [LabeledPoint ],
726
- strategy : Strategy ) : (Array [Array [Split ]], Array [Array [Bin ]]) = {
791
+ strategy : Strategy ): (Array [Array [Split ]], Array [Array [Bin ]]) = {
727
792
728
793
val count = input.count()
729
794
0 commit comments