
Commit 07fc11d

Renamed numClassesForClassification to numClasses everywhere in trees and ensembles.
This is a breaking API change, but it was necessary to correct an API inconsistency in Spark 1.1 (where Python DecisionTree used numClasses but Scala used numClassesForClassification). Added examples to the programming guide for all ensembles.
1 parent cdfdfbc commit 07fc11d
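For callers, the rename is mechanical. A minimal before/after sketch in Scala (the named-parameter style follows DecisionTreeRunner above; Gini, maxDepth = 5, and 3 classes are placeholder choices, not values from this commit):

import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Spark 1.1 spelling:
// val strategy = new Strategy(Algo.Classification, Gini, maxDepth = 5,
//   numClassesForClassification = 3)

// After this commit, only the parameter name changes:
val strategy = new Strategy(Algo.Classification, Gini, maxDepth = 5, numClasses = 3)

Positional call sites compile unchanged; only code using the named parameter, the Strategy field, or the Java setter needs updating.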

14 files changed (+128, -124 lines)

docs/mllib-random-forest.md

Lines changed: 54 additions & 50 deletions
@@ -81,7 +81,7 @@ The test error is calculated to measure the algorithm accuracy.

 <div data-lang="scala">
 {% highlight scala %}
-import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.RandomForest
 import org.apache.spark.mllib.util.MLUtils

 // Load and parse the data file.
@@ -90,16 +90,18 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
 val splits = data.randomSplit(Array(0.7, 0.3))
 val (trainingData, testData) = (splits(0), splits(1))

-// Train a DecisionTree model.
+// Train a RandomForest model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
 val numClasses = 2
 val categoricalFeaturesInfo = Map[Int, Int]()
+val numTrees = 3 // Use more in practice.
+val featureSubsetStrategy = "auto" // Let the algorithm choose.
 val impurity = "gini"
-val maxDepth = 5
+val maxDepth = 4
 val maxBins = 32

-val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
-  impurity, maxDepth, maxBins)
+val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
+  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

 // Evaluate model on test instances and compute test error
 val labelAndPreds = testData.map { point =>
@@ -108,26 +110,26 @@ val labelAndPreds = testData.map { point =>
 }
 val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
 println("Test Error = " + testErr)
-println("Learned classification tree model:\n" + model.toDebugString)
+println("Learned classification forest model:\n" + model.toDebugString)
 {% endhighlight %}
 </div>

 <div data-lang="java">
 {% highlight java %}
-import java.util.HashMap;
 import scala.Tuple2;
+import java.util.HashMap;
+import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.mllib.tree.RandomForest;
+import org.apache.spark.mllib.tree.model.RandomForestModel;
 import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.SparkConf;

-SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassification");
 JavaSparkContext sc = new JavaSparkContext(sparkConf);

 // Load and parse the data file.
@@ -138,17 +140,20 @@ JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
 JavaRDD<LabeledPoint> trainingData = splits[0];
 JavaRDD<LabeledPoint> testData = splits[1];

-// Set parameters.
+// Train a RandomForest model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
 Integer numClasses = 2;
-Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+Integer numTrees = 3; // Use more in practice.
+String featureSubsetStrategy = "auto"; // Let the algorithm choose.
 String impurity = "gini";
 Integer maxDepth = 5;
 Integer maxBins = 32;
+Integer seed = 12345;

-// Train a DecisionTree model for classification.
-final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
-  categoricalFeaturesInfo, impurity, maxDepth, maxBins);
+final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
+  categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
+  seed);

 // Evaluate model on test instances and compute test error
 JavaPairRDD<Double, Double> predictionAndLabel =
@@ -166,38 +171,36 @@ Double testErr =
     }
   }).count() / testData.count();
 System.out.println("Test Error: " + testErr);
-System.out.println("Learned classification tree model:\n" + model.toDebugString());
+System.out.println("Learned classification forest model:\n" + model.toDebugString());
 {% endhighlight %}
 </div>

 <div data-lang="python">
 {% highlight python %}
-from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.tree import RandomForest
 from pyspark.mllib.util import MLUtils

 # Load and parse the data file into an RDD of LabeledPoint.
 data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
 # Split the data into training and test sets (30% held out for testing)
 (trainingData, testData) = data.randomSplit([0.7, 0.3])

-# Train a DecisionTree model.
+# Train a RandomForest model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
-model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
-                                     impurity='gini', maxDepth=5, maxBins=32)
+# Note: Use larger numTrees in practice.
+# Setting featureSubsetStrategy="auto" lets the algorithm choose.
+model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
+                                     numTrees=3, featureSubsetStrategy="auto",
+                                     impurity='gini', maxDepth=4, maxBins=32)

 # Evaluate model on test instances and compute test error
 predictions = model.predict(testData.map(lambda x: x.features))
 labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
 testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())
 print('Test Error = ' + str(testErr))
-print('Learned classification tree model:')
+print('Learned classification forest model:')
 print(model.toDebugString())
 {% endhighlight %}
-
-Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
-than separately calling `predict` on each data point. This is because the Python code makes calls
-to an underlying `DecisionTree` model in Scala.
 </div>

 </div>
@@ -215,7 +218,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate

 <div data-lang="scala">
 {% highlight scala %}
-import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.RandomForest
 import org.apache.spark.mllib.util.MLUtils

 // Load and parse the data file.
@@ -224,15 +227,18 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
 val splits = data.randomSplit(Array(0.7, 0.3))
 val (trainingData, testData) = (splits(0), splits(1))

-// Train a DecisionTree model.
+// Train a RandomForest model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
+val numClasses = 2
 val categoricalFeaturesInfo = Map[Int, Int]()
+val numTrees = 3 // Use more in practice.
+val featureSubsetStrategy = "auto" // Let the algorithm choose.
 val impurity = "variance"
-val maxDepth = 5
+val maxDepth = 4
 val maxBins = 32

-val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
-  maxDepth, maxBins)
+val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo,
+  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

 // Evaluate model on test instances and compute test error
 val labelsAndPredictions = testData.map { point =>
@@ -241,7 +247,7 @@ val labelsAndPredictions = testData.map { point =>
 }
 val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
 println("Test Mean Squared Error = " + testMSE)
-println("Learned regression tree model:\n" + model.toDebugString)
+println("Learned regression forest model:\n" + model.toDebugString)
 {% endhighlight %}
 </div>

@@ -256,12 +262,12 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.mllib.tree.RandomForest;
+import org.apache.spark.mllib.tree.model.RandomForestModel;
 import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.SparkConf;

-SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForest");
 JavaSparkContext sc = new JavaSparkContext(sparkConf);

 // Load and parse the data file.
@@ -276,11 +282,11 @@ JavaRDD<LabeledPoint> testData = splits[1];
 // Empty categoricalFeaturesInfo indicates all features are continuous.
 Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
 String impurity = "variance";
-Integer maxDepth = 5;
+Integer maxDepth = 4;
 Integer maxBins = 32;

-// Train a DecisionTree model.
-final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
+// Train a RandomForest model.
+final RandomForestModel model = RandomForest.trainRegressor(trainingData,
   categoricalFeaturesInfo, impurity, maxDepth, maxBins);

 // Evaluate model on test instances and compute test error
@@ -305,38 +311,36 @@ Double testMSE =
   }
 }) / data.count();
 System.out.println("Test Mean Squared Error: " + testMSE);
-System.out.println("Learned regression tree model:\n" + model.toDebugString());
+System.out.println("Learned regression forest model:\n" + model.toDebugString());
 {% endhighlight %}
 </div>

 <div data-lang="python">
 {% highlight python %}
-from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.tree import RandomForest
 from pyspark.mllib.util import MLUtils

 # Load and parse the data file into an RDD of LabeledPoint.
 data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
 # Split the data into training and test sets (30% held out for testing)
 (trainingData, testData) = data.randomSplit([0.7, 0.3])

-# Train a DecisionTree model.
+# Train a RandomForest model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
-model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
-                                    impurity='variance', maxDepth=5, maxBins=32)
+# Note: Use larger numTrees in practice.
+# Setting featureSubsetStrategy="auto" lets the algorithm choose.
+model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={},
+                                    numTrees=3, featureSubsetStrategy="auto",
+                                    impurity='variance', maxDepth=4, maxBins=32)

 # Evaluate model on test instances and compute test error
 predictions = model.predict(testData.map(lambda x: x.features))
 labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
 testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count())
 print('Test Mean Squared Error = ' + str(testMSE))
-print('Learned regression tree model:')
+print('Learned regression forest model:')
 print(model.toDebugString())
 {% endhighlight %}
-
-Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
-than separately calling `predict` on each data point. This is because the Python code makes calls
-to an underlying `DecisionTree` model in Scala.
 </div>

 </div>

examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ public static void main(String[] args) {
     return p.label();
   }
 }).countByValue().size();
-boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
+boostingStrategy.treeStrategy().setNumClasses(numClasses);

 // Train a GradientBoosting model for classification.
 final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ object DecisionTreeRunner {
         impurity = impurityCalculator,
         maxDepth = params.maxDepth,
         maxBins = params.maxBins,
-        numClassesForClassification = numClasses,
+        numClasses = numClasses,
         minInstancesPerNode = params.minInstancesPerNode,
         minInfoGain = params.minInfoGain,
         useNodeIdCache = params.useNodeIdCache,

examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ object GradientBoostedTreesRunner {
       params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)

     val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
-    boostingStrategy.treeStrategy.numClassesForClassification = numClasses
+    boostingStrategy.treeStrategy.numClasses = numClasses
     boostingStrategy.numIterations = params.numIterations
     boostingStrategy.treeStrategy.maxDepth = params.maxDepth

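The same rename applies on the mutable configuration path used for boosting (the Java runner above uses the corresponding setter, setNumClasses). A minimal sketch of the new Scala spelling, assuming trainingData: RDD[LabeledPoint] is already loaded:

import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy

val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.treeStrategy.numClasses = 2  // was numClassesForClassification
boostingStrategy.numIterations = 3
val model = GradientBoostedTrees.train(trainingData, boostingStrategy)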
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 2 additions & 2 deletions
@@ -477,7 +477,7 @@ class PythonMLLibAPI extends Serializable {
       algo = algo,
       impurity = impurity,
       maxDepth = maxDepth,
-      numClassesForClassification = numClasses,
+      numClasses = numClasses,
       maxBins = maxBins,
       categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
       minInstancesPerNode = minInstancesPerNode,
@@ -513,7 +513,7 @@ class PythonMLLibAPI extends Serializable {
       algo = algo,
       impurity = impurity,
       maxDepth = maxDepth,
-      numClassesForClassification = numClasses,
+      numClasses = numClasses,
       maxBins = maxBins,
       categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
     val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 11 additions & 11 deletions
@@ -136,16 +136,16 @@ object DecisionTree extends Serializable with Logging {
    * @param impurity impurity criterion used for information gain calculation
    * @param maxDepth Maximum depth of the tree.
    *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   * @param numClassesForClassification number of classes for classification. Default value of 2.
+   * @param numClasses number of classes for classification. Default value of 2.
    * @return DecisionTreeModel that can be used for prediction
    */
   def train(
       input: RDD[LabeledPoint],
       algo: Algo,
       impurity: Impurity,
       maxDepth: Int,
-      numClassesForClassification: Int): DecisionTreeModel = {
-    val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
+      numClasses: Int): DecisionTreeModel = {
+    val strategy = new Strategy(algo, impurity, maxDepth, numClasses)
     new DecisionTree(strategy).run(input)
   }

@@ -164,7 +164,7 @@ object DecisionTree extends Serializable with Logging {
    * @param impurity criterion used for information gain calculation
    * @param maxDepth Maximum depth of the tree.
    *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   * @param numClassesForClassification number of classes for classification. Default value of 2.
+   * @param numClasses number of classes for classification. Default value of 2.
    * @param maxBins maximum number of bins used for splitting features
    * @param quantileCalculationStrategy algorithm for calculating quantiles
    * @param categoricalFeaturesInfo Map storing arity of categorical features.
@@ -177,11 +177,11 @@ object DecisionTree extends Serializable with Logging {
       algo: Algo,
       impurity: Impurity,
       maxDepth: Int,
-      numClassesForClassification: Int,
+      numClasses: Int,
       maxBins: Int,
       quantileCalculationStrategy: QuantileStrategy,
       categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
-    val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
+    val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
       quantileCalculationStrategy, categoricalFeaturesInfo)
     new DecisionTree(strategy).run(input)
   }
@@ -191,7 +191,7 @@ object DecisionTree extends Serializable with Logging {
    *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              Labels should take values {0, 1, ..., numClasses-1}.
-   * @param numClassesForClassification number of classes for classification.
+   * @param numClasses number of classes for classification.
    * @param categoricalFeaturesInfo Map storing arity of categorical features.
    *                                E.g., an entry (n -> k) indicates that feature n is categorical
    *                                with k categories indexed from 0: {0, 1, ..., k-1}.
@@ -206,13 +206,13 @@ object DecisionTree extends Serializable with Logging {
    */
   def trainClassifier(
       input: RDD[LabeledPoint],
-      numClassesForClassification: Int,
+      numClasses: Int,
       categoricalFeaturesInfo: Map[Int, Int],
       impurity: String,
       maxDepth: Int,
       maxBins: Int): DecisionTreeModel = {
     val impurityType = Impurities.fromString(impurity)
-    train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort,
+    train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
       categoricalFeaturesInfo)
   }

@@ -221,12 +221,12 @@ object DecisionTree extends Serializable with Logging {
    */
   def trainClassifier(
       input: JavaRDD[LabeledPoint],
-      numClassesForClassification: Int,
+      numClasses: Int,
       categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
       impurity: String,
       maxDepth: Int,
       maxBins: Int): DecisionTreeModel = {
-    trainClassifier(input.rdd, numClassesForClassification,
+    trainClassifier(input.rdd, numClasses,
       categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
       impurity, maxDepth, maxBins)
   }