Commit ad3e695

added gbt and random forest to programming guide. still need to update their examples

1 parent 6cf5076 commit ad3e695

4 files changed, +680 -6 lines changed


docs/mllib-decision-tree.md

Lines changed: 8 additions & 5 deletions
@@ -11,14 +11,17 @@ displayTitle: <a href="mllib-guide.html">MLlib</a> - Decision Tree
 and their ensembles are popular methods for the machine learning tasks of
 classification and regression. Decision trees are widely used since they are easy to interpret,
 handle categorical features, extend to the multiclass classification setting, do not require
-feature scaling and are able to capture nonlinearities and feature interactions. Tree ensemble
+feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble
 algorithms such as random forests and boosting are among the top performers for classification and
 regression tasks.
 
 MLlib supports decision trees for binary and multiclass classification and for regression,
 using both continuous and categorical features. The implementation partitions data by rows,
 allowing distributed training with millions of instances.
 
+Ensembles of trees are described in [random forests](mllib-random-forest.html) and
+[gradient-boosted trees](mllib-gbt.html).
+
 ## Basic algorithm
 
 The decision tree is a greedy algorithm that performs a recursive binary partitioning of the feature
@@ -42,18 +45,18 @@ impurity measure for regression (variance).
 <tr>
 <td>Gini impurity</td>
 <td>Classification</td>
-<td>$\sum_{i=1}^{M} f_i(1-f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $M$ is the number of unique labels.</td>
+<td>$\sum_{i=1}^{C} f_i(1-f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $C$ is the number of unique labels.</td>
 </tr>
 <tr>
 <td>Entropy</td>
 <td>Classification</td>
-<td>$\sum_{i=1}^{M} -f_ilog(f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $M$ is the number of unique labels.</td>
+<td>$\sum_{i=1}^{C} -f_i \log(f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $C$ is the number of unique labels.</td>
 </tr>
 <tr>
 <td>Variance</td>
 <td>Regression</td>
-<td>$\frac{1}{n} \sum_{i=1}^{N} (x_i - \mu)^2$</td><td>$y_i$ is label for an instance,
-$N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^n x_i$.</td>
+<td>$\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$</td><td>$y_i$ is the label for an instance,
+$N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N y_i$.</td>
 </tr>
 </tbody>
 </table>
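
As a quick check of the impurity formulas in the table above, here is a small, self-contained Scala sketch. It is illustrative only: plain Scala collections rather than MLlib's API, and the helper names are ours.

```scala
// Illustrative only: plain Scala versions of the impurity formulas above
// (not part of MLlib's public API).
object ImpurityCheck {

  // Gini impurity: sum_i f_i * (1 - f_i), where f_i is the frequency of label i at the node.
  def gini(labels: Seq[Int]): Double = {
    val n = labels.size.toDouble
    labels.groupBy(identity).values.map(_.size / n).map(f => f * (1.0 - f)).sum
  }

  // Entropy: sum_i -f_i * log(f_i).  (Natural log here; the base only rescales the value.)
  def entropy(labels: Seq[Int]): Double = {
    val n = labels.size.toDouble
    labels.groupBy(identity).values.map(_.size / n).map(f => -f * math.log(f)).sum
  }

  // Variance (regression impurity): (1/N) * sum_i (y_i - mu)^2.
  def variance(labels: Seq[Double]): Double = {
    val mu = labels.sum / labels.size
    labels.map(y => (y - mu) * (y - mu)).sum / labels.size
  }

  def main(args: Array[String]): Unit = {
    println(gini(Seq(0, 0, 1, 1)))         // 0.5 for a perfectly mixed binary node
    println(entropy(Seq(0, 0, 1, 1)))      // log(2) ~= 0.693
    println(variance(Seq(1.0, 2.0, 3.0)))  // 2/3
  }
}
```
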

docs/mllib-gbt.md

Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
---
layout: global
title: Gradient-Boosted Trees - MLlib
displayTitle: <a href="mllib-guide.html">MLlib</a> - Gradient-Boosted Trees
---

* Table of contents
{:toc}

[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting)
are ensembles of [decision trees](mllib-decision-tree.html).
GBTs iteratively train decision trees in order to minimize a loss function.
Like decision trees, GBTs handle categorical features,
extend to the multiclass classification setting, do not require
feature scaling, and are able to capture non-linearities and feature interactions.

MLlib supports GBTs for binary classification and for regression,
using both continuous and categorical features.
MLlib implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees.

## Basic algorithm

Gradient boosting iteratively trains a sequence of decision trees.
On each iteration, the algorithm uses the current ensemble to predict the label of each training instance and then compares the prediction with the true label. The dataset is re-labeled to put more weight on training instances with poor predictions. Thus, in the next iteration, the decision tree will help correct for previous mistakes.

The specific weight mechanism is defined by a loss function (discussed below). With each iteration, GBTs further reduce this loss function on the training data.

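
To make this loop concrete, here is a small, self-contained Scala sketch of gradient boosting with squared error loss on one-dimensional data, using regression stumps as the weak learners. It is illustrative only: plain Scala rather than the MLlib implementation, and every name in it is made up for the example.

```scala
// A minimal sketch of gradient boosting with squared error loss on 1-D data,
// using regression stumps as weak learners. NOT MLlib code; it only illustrates
// the iterative scheme described above.
object GradientBoostingSketch {

  // A regression stump: predict leftValue if x <= threshold, else rightValue.
  case class Stump(threshold: Double, leftValue: Double, rightValue: Double) {
    def predict(x: Double): Double = if (x <= threshold) leftValue else rightValue
  }

  // Fit a stump to (x, target) pairs by trying each x as a split threshold
  // and picking the split with the smallest squared error.
  def fitStump(data: Seq[(Double, Double)]): Stump = {
    def stumpFor(t: Double): Stump = {
      val (left, right) = data.partition(_._1 <= t)
      def mean(s: Seq[(Double, Double)]) = if (s.isEmpty) 0.0 else s.map(_._2).sum / s.size
      Stump(t, mean(left), mean(right))
    }
    def sse(s: Stump) = data.map { case (x, y) => val d = y - s.predict(x); d * d }.sum
    data.map(_._1).map(stumpFor).minBy(sse)
  }

  // Gradient boosting: each new tree is fit to the residuals (the negative gradient
  // of the squared error loss) of the current ensemble, then added with a learning rate.
  def boost(data: Seq[(Double, Double)], numIterations: Int, learningRate: Double): Double => Double = {
    var trees = Vector.empty[Stump]
    def ensemble(x: Double): Double = trees.map(t => learningRate * t.predict(x)).sum
    for (_ <- 1 to numIterations) {
      val residuals = data.map { case (x, y) => (x, y - ensemble(x)) }
      trees = trees :+ fitStump(residuals)
    }
    (x: Double) => ensemble(x)
  }

  def main(args: Array[String]): Unit = {
    val data = Seq(1.0 -> 1.0, 2.0 -> 2.0, 3.0 -> 2.5, 4.0 -> 4.0)
    val model = boost(data, numIterations = 50, learningRate = 0.1)
    data.foreach { case (x, y) => println(s"x=$x label=$y prediction=${model(x)}") }
  }
}
```
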
### Comparison with Random Forests

Both GBTs and [Random Forests](mllib-random-forest.html) are algorithms for learning ensembles of trees, but the training processes are different. There are several practical trade-offs:

* GBTs may be able to achieve the same accuracy using fewer trees, so the model produced may be smaller (faster for test time prediction).
* GBTs train one tree at a time, so they can take longer to train than random forests. Random Forests can train multiple trees in parallel.
  * On the other hand, it is often reasonable to use smaller trees with GBTs than with Random Forests, and training smaller trees takes less time.
* Random Forests can be less prone to overfitting. Training more trees in a Random Forest reduces the likelihood of overfitting, but training more trees with GBTs increases the likelihood of overfitting.

In short, both algorithms can be effective. GBTs may be more useful if test time prediction speed is important. Random Forests are arguably more successful in industry.

### Losses

The table below lists the losses currently supported by GBTs in MLlib.
Note that each loss is applicable to one of classification or regression, not both.

Notation: $N$ = number of instances. $y_i$ = label of instance $i$. $x_i$ = features of instance $i$. $F(x_i)$ = model's predicted label for instance $i$.

<table class="table">
<thead>
<tr><th>Loss</th><th>Task</th><th>Formula</th><th>Description</th></tr>
</thead>
<tbody>
<tr>
<td>Log Loss</td>
<td>Classification</td>
<td>$2 \sum_{i=1}^{N} \log(1+\exp(-2 y_i F(x_i)))$</td><td>Twice binomial negative log likelihood.</td>
</tr>
<tr>
<td>Squared Error</td>
<td>Regression</td>
<td>$\sum_{i=1}^{N} \frac{1}{2} (y_i - F(x_i))^2$</td><td>Also called L2 loss. Default loss for regression tasks.</td>
</tr>
<tr>
<td>Absolute Error</td>
<td>Regression</td>
<td>$\sum_{i=1}^{N} |y_i - F(x_i)|$</td><td>Also called L1 loss. Can be more robust to outliers than Squared Error.</td>
</tr>
</tbody>
</table>

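
For reference, each loss above is easy to evaluate given (label, prediction) pairs. The following self-contained Scala sketch is illustrative only (plain Scala, not MLlib's loss classes); for log loss it assumes labels in {-1, +1}, matching the formula above.

```scala
// Illustrative only: evaluating the GBT losses above for (label, prediction) pairs.
// Not the MLlib loss implementations.
object GbtLossExample {

  // Log loss: 2 * sum_i log(1 + exp(-2 * y_i * F(x_i))), with labels y_i in {-1, +1}.
  def logLoss(labelAndPred: Seq[(Double, Double)]): Double =
    2.0 * labelAndPred.map { case (y, f) => math.log1p(math.exp(-2.0 * y * f)) }.sum

  // Squared error (L2 loss): sum_i 0.5 * (y_i - F(x_i))^2.
  def squaredError(labelAndPred: Seq[(Double, Double)]): Double =
    labelAndPred.map { case (y, f) => 0.5 * (y - f) * (y - f) }.sum

  // Absolute error (L1 loss): sum_i |y_i - F(x_i)|.
  def absoluteError(labelAndPred: Seq[(Double, Double)]): Double =
    labelAndPred.map { case (y, f) => math.abs(y - f) }.sum

  def main(args: Array[String]): Unit = {
    val classificationPairs = Seq((1.0, 0.8), (-1.0, -1.2), (1.0, -0.3))  // (y_i, F(x_i))
    val regressionPairs = Seq((3.0, 2.5), (1.0, 1.5), (4.0, 4.0))
    println("Log loss:       " + logLoss(classificationPairs))
    println("Squared error:  " + squaredError(regressionPairs))
    println("Absolute error: " + absoluteError(regressionPairs))
  }
}
```
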
## Usage guide

We include a few guidelines for using GBTs by discussing the various parameters.
We omit some decision tree parameters since those are covered in the [decision tree guide](mllib-decision-tree.html).

* **loss**: See the section above for information on losses and their applicability to tasks (classification vs. regression). Different losses can give significantly different results, depending on the dataset.

* **numIterations**: This sets the number of trees in the ensemble. Each iteration produces one tree. Increasing this number makes the model more expressive, improving training data accuracy. However, test-time accuracy may suffer if this is too large.

* **learningRate**: This parameter should not need to be tuned. If the algorithm behavior seems unstable, decreasing this value may improve stability.

* **algo**: The algorithm or task (classification vs. regression) is set using the tree [Strategy] parameter. A configuration sketch putting these parameters together follows this list.

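
The following Scala sketch shows how the parameters above fit together. It is a sketch only, assuming the `GradientBoostedTrees` and `BoostingStrategy` API of later MLlib releases; the class names, field names, and values here are assumptions rather than the final examples for this guide.

```scala
// A hypothetical configuration sketch for the parameters above.
// Assumes the GradientBoostedTrees / BoostingStrategy API of later MLlib releases;
// the exact class and field names at the time of this commit may differ.
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils

// Load training data in LIBSVM format (same dataset used by the examples below).
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// algo: choose the task via the default boosting strategy ("Classification" or "Regression").
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
// numIterations: number of trees in the ensemble (one tree per iteration).
boostingStrategy.numIterations = 10
// Decision tree parameters are set on the nested tree strategy.
boostingStrategy.treeStrategy.maxDepth = 5
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
// learningRate and loss keep their defaults here; see the notes above before changing them.

val model = GradientBoostedTrees.train(data, boostingStrategy)
```
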
## Examples

TODO

### Classification

The example below demonstrates how to load a
[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),
parse it as an RDD of `LabeledPoint` and then
perform classification using a decision tree with Gini impurity as an impurity measure and a
maximum tree depth of 5. The training error is calculated to measure the algorithm accuracy.

<div class="codetabs">

<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity,
  maxDepth, maxBins)

// Evaluate model on training instances and compute training error
val labelAndPreds = data.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count
println("Training Error = " + trainErr)
println("Learned classification tree model:\n" + model)
{% endhighlight %}
</div>

<div data-lang="java">
{% highlight java %}
import java.util.HashMap;
import scala.Tuple2;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();

// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Integer numClasses = 2;
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "gini";
Integer maxDepth = 5;
Integer maxBins = 32;

// Train a DecisionTree model for classification.
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
  data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
    @Override public Tuple2<Double, Double> call(LabeledPoint p) {
      return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
    }
  });
Double trainErr =
  1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
    @Override public Boolean call(Tuple2<Double, Double> pl) {
      return !pl._1().equals(pl._2());
    }
  }).count() / data.count();
System.out.println("Training error: " + trainErr);
System.out.println("Learned classification tree model:\n" + model);
{% endhighlight %}
</div>

<div data-lang="python">
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
# Cache the data since we will use it again to compute training error.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()

# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
                                     impurity='gini', maxDepth=5, maxBins=32)

# Evaluate model on training instances and compute training error
predictions = model.predict(data.map(lambda x: x.features))
labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
trainErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(data.count())
print('Training Error = ' + str(trainErr))
print('Learned classification tree model:')
print(model)
{% endhighlight %}

Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
than separately calling `predict` on each data point. This is because the Python code makes calls
to an underlying `DecisionTree` model in Scala.
</div>

</div>

### Regression

The example below demonstrates how to load a
[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),
parse it as an RDD of `LabeledPoint` and then
perform regression using a decision tree with variance as an impurity measure and a maximum tree
depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).

<div class="codetabs">

<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity,
  maxDepth, maxBins)

// Evaluate model on training instances and compute training error
val labelsAndPredictions = data.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val trainMSE = labelsAndPredictions.map { case (v, p) => math.pow(v - p, 2) }.mean()
println("Training Mean Squared Error = " + trainMSE)
println("Learned regression tree model:\n" + model)
{% endhighlight %}
</div>

<div data-lang="java">
{% highlight java %}
import java.util.HashMap;
import scala.Tuple2;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();

// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "variance";
Integer maxDepth = 5;
Integer maxBins = 32;

// Train a DecisionTree model.
final DecisionTreeModel model = DecisionTree.trainRegressor(data,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
  data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
    @Override public Tuple2<Double, Double> call(LabeledPoint p) {
      return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
    }
  });
Double trainMSE =
  predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
    @Override public Double call(Tuple2<Double, Double> pl) {
      Double diff = pl._1() - pl._2();
      return diff * diff;
    }
  }).reduce(new Function2<Double, Double, Double>() {
    @Override public Double call(Double a, Double b) {
      return a + b;
    }
  }) / data.count();
System.out.println("Training Mean Squared Error: " + trainMSE);
System.out.println("Learned regression tree model:\n" + model);
{% endhighlight %}
</div>

<div data-lang="python">
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
# Cache the data since we will use it again to compute training error.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()

# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
                                    impurity='variance', maxDepth=5, maxBins=32)

# Evaluate model on training instances and compute training error
predictions = model.predict(data.map(lambda x: x.features))
labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
trainMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(data.count())
print('Training Mean Squared Error = ' + str(trainMSE))
print('Learned regression tree model:')
print(model)
{% endhighlight %}

Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
than separately calling `predict` on each data point. This is because the Python code makes calls
to an underlying `DecisionTree` model in Scala.
</div>

</div>

docs/mllib-guide.md

Lines changed: 3 additions & 1 deletion
@@ -16,8 +16,10 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv
 * random data generation
 * [Classification and regression](mllib-classification-regression.html)
 * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html)
-* [decision trees](mllib-decision-tree.html)
 * [naive Bayes](mllib-naive-bayes.html)
+* [decision trees](mllib-decision-tree.html)
+* [random forests](mllib-random-forest.html)
+* [gradient-boosted trees](mllib-gbt.html)
 * [Collaborative filtering](mllib-collaborative-filtering.html)
 * alternating least squares (ALS)
 * [Clustering](mllib-clustering.html)
