---
layout: global
title: Gradient-Boosted Trees - MLlib
displayTitle: <a href="mllib-guide.html">MLlib</a> - Gradient-Boosted Trees
---

* Table of contents
{:toc}

[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting)
are ensembles of [decision trees](mllib-decision-tree.html).
GBTs iteratively train decision trees in order to minimize a loss function.
Like decision trees, GBTs handle categorical features,
extend to the multiclass classification setting, do not require
feature scaling, and are able to capture non-linearities and feature interactions.

MLlib supports GBTs for binary classification and for regression,
using both continuous and categorical features.
MLlib implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees.

## Basic algorithm

Gradient boosting iteratively trains a sequence of decision trees.
On each iteration, the algorithm uses the current ensemble to predict the label of each training instance and then compares the prediction with the true label. The dataset is re-labeled to put more weight on training instances with poor predictions. Thus, in the next iteration, the new decision tree helps correct previous mistakes.

The specific re-weighting mechanism is defined by a loss function (discussed below). With each iteration, GBTs further reduce this loss function on the training data.

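For the squared error loss used in regression, for example, this procedure amounts to fitting each new tree to the residuals of the current ensemble. The following is a minimal sketch of that idea, not MLlib code: `boostWithSquaredError` and `fitRegressionTree` are hypothetical helpers standing in for the real implementation, and `learningRate` shrinks each tree's contribution (see the usage guide below).

{% highlight scala %}
// Sketch of gradient boosting with squared error loss (not MLlib code).
// fitRegressionTree is a hypothetical single-tree learner: given (features, target)
// pairs, it returns a function from features to a predicted value.
def boostWithSquaredError(
    data: Seq[(Vector[Double], Double)],
    numIterations: Int,
    learningRate: Double,
    fitRegressionTree: Seq[(Vector[Double], Double)] => (Vector[Double] => Double))
  : Vector[Double] => Double = {
  // F_0: the initial ensemble predicts 0 everywhere.
  var ensemble: Vector[Double] => Double = _ => 0.0
  for (_ <- 1 to numIterations) {
    // Re-label each instance with its residual under the current ensemble;
    // for squared error, the residual is the negative gradient of the loss.
    val residuals = data.map { case (x, y) => (x, y - ensemble(x)) }
    // Fit the next tree to the residuals and add its (shrunken) predictions.
    val tree = fitRegressionTree(residuals)
    val previous = ensemble
    ensemble = x => previous(x) + learningRate * tree(x)
  }
  ensemble
}
{% endhighlight %}
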
### Comparison with Random Forests

Both GBTs and [Random Forests](mllib-random-forest.html) are algorithms for learning ensembles of trees, but the training processes are different. There are several practical trade-offs:

 * GBTs may be able to achieve the same accuracy using fewer trees, so the model produced may be smaller (and therefore faster for test-time prediction).
 * GBTs train one tree at a time, so they can take longer to train than Random Forests, which can train multiple trees in parallel.
 * On the other hand, it is often reasonable to use smaller trees with GBTs than with Random Forests, and training smaller trees takes less time.
 * Random Forests can be less prone to overfitting: training more trees in a Random Forest reduces the likelihood of overfitting, whereas training more trees with GBTs increases it.

In short, both algorithms can be effective. GBTs may be more useful when test-time prediction speed is important, while Random Forests are arguably more widely used in industry.

### Losses

The table below lists the losses currently supported by GBTs in MLlib.
Note that each loss is applicable to one of classification or regression, not both.

Notation: $N$ = number of instances. $y_i$ = label of instance $i$. $x_i$ = features of instance $i$. $F(x_i)$ = model's predicted label for instance $i$.

<table class="table">
  <thead>
    <tr><th>Loss</th><th>Task</th><th>Formula</th><th>Description</th></tr>
  </thead>
  <tbody>
    <tr>
      <td>Log Loss</td>
      <td>Classification</td>
      <td>$2 \sum_{i=1}^{N} \log(1+\exp(-2 y_i F(x_i)))$</td><td>Twice binomial negative log likelihood.</td>
    </tr>
    <tr>
      <td>Squared Error</td>
      <td>Regression</td>
      <td>$\sum_{i=1}^{N} \frac{1}{2} (y_i - F(x_i))^2$</td><td>Also called L2 loss. Default loss for regression tasks.</td>
    </tr>
    <tr>
      <td>Absolute Error</td>
      <td>Regression</td>
      <td>$\sum_{i=1}^{N} |y_i - F(x_i)|$</td><td>Also called L1 loss. Can be more robust to outliers than Squared Error.</td>
    </tr>
  </tbody>
</table>

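For concreteness, the formulas in the table can be written out directly. The sketch below computes each loss over a collection of `(label, prediction)` pairs; the function names are illustrative only and are not part of the MLlib API. For Log Loss, labels are assumed to be encoded as $\pm 1$.

{% highlight scala %}
// Illustrative implementations of the losses in the table above (not MLlib API).
// Each function takes (y_i, F(x_i)) pairs, i.e. (true label, predicted label).

// Log Loss: assumes labels are encoded as -1.0 or +1.0.
def logLoss(labelsAndPredictions: Seq[(Double, Double)]): Double =
  2.0 * labelsAndPredictions.map { case (y, f) => math.log1p(math.exp(-2.0 * y * f)) }.sum

// Squared Error (L2 loss).
def squaredError(labelsAndPredictions: Seq[(Double, Double)]): Double =
  labelsAndPredictions.map { case (y, f) => 0.5 * (y - f) * (y - f) }.sum

// Absolute Error (L1 loss).
def absoluteError(labelsAndPredictions: Seq[(Double, Double)]): Double =
  labelsAndPredictions.map { case (y, f) => math.abs(y - f) }.sum
{% endhighlight %}
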
## Usage guide

We include a few guidelines for using GBTs by discussing the various parameters below; a configuration sketch follows the list.
We omit some decision tree parameters since those are covered in the [decision tree guide](mllib-decision-tree.html).

* **loss**: See the section above for information on losses and their applicability to tasks (classification vs. regression). Different losses can give significantly different results, depending on the dataset.

* **numIterations**: This sets the number of trees in the ensemble. Each iteration produces one tree. Increasing this number makes the model more expressive, improving training data accuracy. However, test-time accuracy may suffer if this is too large.

* **learningRate**: This parameter should not need to be tuned. If the algorithm behavior seems unstable, decreasing this value may improve stability.

* **algo**: The algorithm or task (classification vs. regression) is set using the tree `Strategy` parameter.

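As a rough sketch of how these parameters fit together, the snippet below configures a boosting strategy and trains a model. The `GradientBoostedTrees` and `BoostingStrategy` names and fields shown here are assumptions based on the `org.apache.spark.mllib.tree` package and may differ across Spark versions; consult the API docs for the exact interface.

{% highlight scala %}
// Sketch only: class and field names are assumptions, not guaranteed API.
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// algo: start from the defaults for the task ("Classification" or "Regression").
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
// numIterations: the number of trees in the ensemble.
boostingStrategy.numIterations = 10
// learningRate: usually left at its default; lower it if training seems unstable.
boostingStrategy.learningRate = 0.1
// Per-tree parameters are set on the underlying tree Strategy.
boostingStrategy.treeStrategy.maxDepth = 4
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val model = GradientBoostedTrees.train(data, boostingStrategy)
{% endhighlight %}

The trained model can then be evaluated on held-out data in the same way as in the examples below.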

## Examples

TODO

### Classification

The example below demonstrates how to load a
[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),
parse it as an RDD of `LabeledPoint`, and then
perform classification using a decision tree with Gini impurity as the impurity measure and a
maximum tree depth of 5. The training error is computed to measure the algorithm's accuracy.

<div class="codetabs">

<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity,
  maxDepth, maxBins)

// Evaluate model on training instances and compute training error
val labelAndPreds = data.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count
println("Training Error = " + trainErr)
println("Learned classification tree model:\n" + model)
{% endhighlight %}
</div>

<div data-lang="java">
{% highlight java %}
import java.util.HashMap;
import scala.Tuple2;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();

// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Integer numClasses = 2;
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "gini";
Integer maxDepth = 5;
Integer maxBins = 32;

// Train a DecisionTree model for classification.
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
  data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
    @Override public Tuple2<Double, Double> call(LabeledPoint p) {
      return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
    }
  });
Double trainErr =
  1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
    @Override public Boolean call(Tuple2<Double, Double> pl) {
      return !pl._1().equals(pl._2());
    }
  }).count() / data.count();
System.out.println("Training error: " + trainErr);
System.out.println("Learned classification tree model:\n" + model);
{% endhighlight %}
</div>

<div data-lang="python">
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
# Cache the data since we will use it again to compute training error.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()

# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
                                     impurity='gini', maxDepth=5, maxBins=32)

# Evaluate model on training instances and compute training error
predictions = model.predict(data.map(lambda x: x.features))
labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
trainErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(data.count())
print('Training Error = ' + str(trainErr))
print('Learned classification tree model:')
print(model)
{% endhighlight %}

Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
than separately calling `predict` on each data point. This is because the Python code makes calls
to an underlying `DecisionTree` model in Scala.
</div>

</div>

### Regression

The example below demonstrates how to load a
[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),
parse it as an RDD of `LabeledPoint`, and then
perform regression using a decision tree with variance as the impurity measure and a maximum tree
depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).

<div class="codetabs">

<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity,
  maxDepth, maxBins)

// Evaluate model on training instances and compute training error
val labelsAndPredictions = data.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val trainMSE = labelsAndPredictions.map { case (v, p) => math.pow(v - p, 2) }.mean()
println("Training Mean Squared Error = " + trainMSE)
println("Learned regression tree model:\n" + model)
{% endhighlight %}
</div>

<div data-lang="java">
{% highlight java %}
import java.util.HashMap;
import scala.Tuple2;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
// Cache the data since we will use it again to compute training error.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();

// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "variance";
Integer maxDepth = 5;
Integer maxBins = 32;

// Train a DecisionTree model.
final DecisionTreeModel model = DecisionTree.trainRegressor(data,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
  data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
    @Override public Tuple2<Double, Double> call(LabeledPoint p) {
      return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
    }
  });
Double trainMSE =
  predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
    @Override public Double call(Tuple2<Double, Double> pl) {
      Double diff = pl._1() - pl._2();
      return diff * diff;
    }
  }).reduce(new Function2<Double, Double, Double>() {
    @Override public Double call(Double a, Double b) {
      return a + b;
    }
  }) / data.count();
System.out.println("Training Mean Squared Error: " + trainMSE);
System.out.println("Learned regression tree model:\n" + model);
{% endhighlight %}
</div>

<div data-lang="python">
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
# Cache the data since we will use it again to compute training error.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()

# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
                                    impurity='variance', maxDepth=5, maxBins=32)

# Evaluate model on training instances and compute training error
predictions = model.predict(data.map(lambda x: x.features))
labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
trainMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(data.count())
print('Training Mean Squared Error = ' + str(trainMSE))
print('Learned regression tree model:')
print(model)
{% endhighlight %}

Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
than separately calling `predict` on each data point. This is because the Python code makes calls
to an underlying `DecisionTree` model in Scala.
</div>

</div>