Skip to content

Commit 63e786b

Browse files
committed
added multiple train methods for Java compatibility
1 parent d3023b3 commit 63e786b

File tree

3 files changed

+72
-7
lines changed

3 files changed

+72
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
2828
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
2929
import org.apache.spark.mllib.tree.configuration.FeatureType._
3030
import org.apache.spark.mllib.tree.configuration.Algo._
31+
import org.apache.spark.mllib.tree.impurity.Impurity
3132

3233
/**
3334
A class that implements a decision tree algorithm for classification and regression.
@@ -38,7 +39,7 @@ algorithm (classification,
3839
regression, etc.), feature type (continuous, categorical), depth of the tree,
3940
quantile calculation strategy, etc.
4041
*/
41-
class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
42+
class DecisionTree private (val strategy : Strategy) extends Serializable with Logging {
4243

4344
/**
4445
Method to train a decision tree model over an RDD
@@ -157,6 +158,70 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
157158

158159
object DecisionTree extends Serializable with Logging {
159160

161+
/**
162+
Method to train a decision tree model over an RDD
163+
164+
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
165+
for DecisionTree
166+
@param strategy The configuration parameters for the tree algorithm which specify the type of algorithm
167+
(classification, regression, etc.), feature type (continuous, categorical),
168+
depth of the tree, quantile calculation strategy, etc.
169+
@return a DecisionTreeModel that can be used for prediction
170+
*/
171+
def train(input : RDD[LabeledPoint], strategy : Strategy) : DecisionTreeModel = {
172+
new DecisionTree(strategy).train(input : RDD[LabeledPoint])
173+
}
174+
175+
/**
176+
Method to train a decision tree model over an RDD
177+
178+
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
179+
for DecisionTree
180+
@param algo classification or regression
181+
@param impurity criterion used for information gain calculation
182+
@param maxDepth maximum depth of the tree
183+
@return a DecisionTreeModel that can be used for prediction
184+
*/
185+
def train(
186+
input : RDD[LabeledPoint],
187+
algo : Algo,
188+
impurity : Impurity,
189+
maxDepth : Int
190+
) : DecisionTreeModel = {
191+
val strategy = new Strategy(algo,impurity,maxDepth)
192+
new DecisionTree(strategy).train(input : RDD[LabeledPoint])
193+
}
194+
195+
196+
/**
197+
Method to train a decision tree model over an RDD
198+
199+
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
200+
for DecisionTree
201+
@param algo classification or regression
202+
@param impurity criterion used for information gain calculation
203+
@param maxDepth maximum depth of the tree
204+
@param maxBins maximum number of bins used for splitting features
205+
@param quantileCalculationStrategy algorithm for calculating quantiles
206+
@param categoricalFeaturesInfo A map storing information about the categorical variables and the number of discrete
207+
values they take. For example, an entry (n -> k) implies the feature n is
208+
categorical with k categories 0, 1, 2, ... , k-1. It's important to note that
209+
features are zero-indexed.
210+
@return a DecisionTreeModel that can be used for prediction
211+
*/
212+
def train(
213+
input : RDD[LabeledPoint],
214+
algo : Algo,
215+
impurity : Impurity,
216+
maxDepth : Int,
217+
maxBins : Int,
218+
quantileCalculationStrategy : QuantileStrategy,
219+
categoricalFeaturesInfo : Map[Int,Int]
220+
) : DecisionTreeModel = {
221+
val strategy = new Strategy(algo,impurity,maxDepth,maxBins,quantileCalculationStrategy,categoricalFeaturesInfo)
222+
new DecisionTree(strategy).train(input : RDD[LabeledPoint])
223+
}
224+
160225
/**
161226
Returns an Array[Split] of optimal splits for all nodes at a given level
162227
@@ -717,13 +782,13 @@ object DecisionTree extends Serializable with Logging {
717782
for DecisionTree
718783
@param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
719784
parameters for constructing the DecisionTree
720-
@return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,
721-
numSplits-1) and bins is an
722-
Array[Array[Bin]] of size (numFeatures,numSplits1)
785+
@return a tuple of (splits,bins) where splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] of
786+
size (numFeatures,numSplits-1) and bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] of
787+
size (numFeatures,numSplits1)
723788
*/
724789
def findSplitsBins(
725790
input : RDD[LabeledPoint],
726-
strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
791+
strategy : Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
727792

728793
val count = input.count()
729794

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ object DecisionTreeRunner extends Logging {
8787
val maxBins = options.getOrElse('maxBins,"100").toString.toInt
8888

8989
val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth, maxBins = maxBins)
90-
val model = new DecisionTree(strategy).train(trainData)
90+
val model = DecisionTree.train(trainData,strategy)
9191

9292
//Load test data
9393
val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class Strategy (
3737
val algo : Algo,
3838
val impurity : Impurity,
3939
val maxDepth : Int,
40-
val maxBins : Int,
40+
val maxBins : Int = 100,
4141
val quantileCalculationStrategy : QuantileStrategy = Sort,
4242
val categoricalFeaturesInfo : Map[Int,Int] = Map[Int,Int]()) extends Serializable {
4343

0 commit comments

Comments
 (0)