@@ -81,7 +81,7 @@ The test error is calculated to measure the algorithm accuracy.

<div data-lang="scala">
{% highlight scala %}
-import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
@@ -90,16 +90,18 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

-// Train a DecisionTree model.
+// Train a RandomForest model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
+val numTrees = 3 // Use more in practice.
+val featureSubsetStrategy = "auto" // Let the algorithm choose.
val impurity = "gini"
-val maxDepth = 5
+val maxDepth = 4
val maxBins = 32

-val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
-  impurity, maxDepth, maxBins)
+val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
+  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
@@ -108,26 +110,26 @@ val labelAndPreds = testData.map { point =>
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
-println("Learned classification tree model:\n" + model.toDebugString)
+println("Learned classification forest model:\n" + model.toDebugString)
{% endhighlight %}
</div>

<div data-lang="java">
{% highlight java %}
-import java.util.HashMap;
import scala.Tuple2;
+import java.util.HashMap;
+import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.mllib.tree.RandomForest;
+import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.SparkConf;

-SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassification");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
@@ -138,17 +140,20 @@ JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

-// Set parameters.
+// Train a RandomForest model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Integer numClasses = 2;
-Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+Integer numTrees = 3; // Use more in practice.
+String featureSubsetStrategy = "auto"; // Let the algorithm choose.
String impurity = "gini";
Integer maxDepth = 5;
Integer maxBins = 32;
+Integer seed = 12345;

-// Train a DecisionTree model for classification.
-final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
-  categoricalFeaturesInfo, impurity, maxDepth, maxBins);
+final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
+  categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
+  seed);

// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
@@ -166,38 +171,36 @@ Double testErr =
}
}).count() / testData.count();
System.out.println("Test Error: " + testErr);
-System.out.println("Learned classification tree model:\n" + model.toDebugString());
+System.out.println("Learned classification forest model:\n" + model.toDebugString());
{% endhighlight %}
</div>

<div data-lang="python">
{% highlight python %}
-from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.tree import RandomForest
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

-# Train a DecisionTree model.
+# Train a RandomForest model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
-model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
-                                     impurity='gini', maxDepth=5, maxBins=32)
+# Note: Use larger numTrees in practice.
+# Setting featureSubsetStrategy="auto" lets the algorithm choose.
+model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
+                                     numTrees=3, featureSubsetStrategy="auto",
+                                     impurity='gini', maxDepth=4, maxBins=32)

# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())
print('Test Error = ' + str(testErr))
-print('Learned classification tree model:')
+print('Learned classification forest model:')
print(model.toDebugString())
{% endhighlight %}
-
-Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
-than separately calling `predict` on each data point. This is because the Python code makes calls
-to an underlying `DecisionTree` model in Scala.
</div>

</div>
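
As an aside on the new `featureSubsetStrategy` parameter introduced above: `"auto"` lets MLlib pick the per-split feature-sampling rule, and the other accepted values ("all", "sqrt", "log2", "onethird") fix it explicitly. The sketch below is not part of the patch; it reuses the variables from the Scala classification example to compare strategies, and the loop and error computation are purely illustrative.

{% highlight scala %}
// Sketch only: assumes trainingData, testData, numClasses, categoricalFeaturesInfo,
// numTrees, impurity, maxDepth, and maxBins from the example above are in scope.
val strategies = Seq("auto", "all", "sqrt", "log2", "onethird")
for (strategy <- strategies) {
  val m = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
    numTrees, strategy, impurity, maxDepth, maxBins)
  // Fraction of misclassified test points for this strategy.
  val err = testData.map { p =>
    if (m.predict(p.features) != p.label) 1.0 else 0.0
  }.mean()
  println(s"featureSubsetStrategy=$strategy: test error = $err")
}
{% endhighlight %}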
@@ -215,7 +218,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate

<div data-lang="scala">
{% highlight scala %}
-import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
@@ -224,15 +227,18 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

-// Train a DecisionTree model.
+// Train a RandomForest model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
+val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
+val numTrees = 3 // Use more in practice.
+val featureSubsetStrategy = "auto" // Let the algorithm choose.
val impurity = "variance"
-val maxDepth = 5
+val maxDepth = 4
val maxBins = 32

-val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
-  maxDepth, maxBins)
+val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo,
+  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelsAndPredictions = testData.map { point =>
@@ -241,7 +247,7 @@ val labelsAndPredictions = testData.map { point =>
}
val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Test Mean Squared Error = " + testMSE)
-println("Learned regression tree model:\n" + model.toDebugString)
+println("Learned regression forest model:\n" + model.toDebugString)
{% endhighlight %}
</div>

@@ -256,12 +262,12 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.mllib.tree.RandomForest;
+import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;

-SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForest");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
@@ -276,11 +282,11 @@ JavaRDD<LabeledPoint> testData = splits[1];
// Empty categoricalFeaturesInfo indicates all features are continuous.
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+Integer numTrees = 3; // Use more in practice.
+String featureSubsetStrategy = "auto"; // Let the algorithm choose.
String impurity = "variance";
-Integer maxDepth = 5;
+Integer maxDepth = 4;
Integer maxBins = 32;
+Integer seed = 12345;

-// Train a DecisionTree model.
-final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
-  categoricalFeaturesInfo, impurity, maxDepth, maxBins);
+// Train a RandomForest model.
+final RandomForestModel model = RandomForest.trainRegressor(trainingData,
+  categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
+  seed);

// Evaluate model on test instances and compute test error
@@ -305,38 +311,36 @@ Double testMSE =
}
}) / data.count();
System.out.println("Test Mean Squared Error: " + testMSE);
-System.out.println("Learned regression tree model:\n" + model.toDebugString());
+System.out.println("Learned regression forest model:\n" + model.toDebugString());
{% endhighlight %}
</div>

<div data-lang="python">
{% highlight python %}
-from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.tree import RandomForest
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

-# Train a DecisionTree model.
+# Train a RandomForest model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
-model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
-                                    impurity='variance', maxDepth=5, maxBins=32)
+# Note: Use larger numTrees in practice.
+# Setting featureSubsetStrategy="auto" lets the algorithm choose.
+model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={},
+                                    numTrees=3, featureSubsetStrategy="auto",
+                                    impurity='variance', maxDepth=4, maxBins=32)

# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
-print('Learned regression tree model:')
+print('Learned regression forest model:')
print(model.toDebugString())
{% endhighlight %}
-
-Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
-than separately calling `predict` on each data point. This is because the Python code makes calls
-to an underlying `DecisionTree` model in Scala.
</div>

</div>
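
One more note on the `seed` argument the Java examples add: the Scala and Python examples rely on the default random seed, so each run trains a different forest. The Scala API accepts the same optional seed as a trailing argument, which is useful when reproducible forests matter (e.g. in tests). A minimal sketch, not part of the patch, reusing the regression example's variables and assuming the same data and partitioning across runs:

{% highlight scala %}
// Sketch only: assumes trainingData, categoricalFeaturesInfo, numTrees,
// featureSubsetStrategy, impurity, maxDepth, and maxBins from above.
// A fixed seed makes the per-tree bootstrap samples and feature subsets
// deterministic, so both calls should train the same forest.
val seed = 12345
val modelA = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo,
  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
val modelB = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo,
  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
assert(modelA.toDebugString == modelB.toDebugString) // identical forests
{% endhighlight %}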