Commit b32f21a

add doc for elastic net in sparkml
1 parent 937eef1 commit b32f21a

2 files changed, +82 -25 lines changed

docs/ml-guide.md

Lines changed: 54 additions & 0 deletions
@@ -157,6 +157,60 @@ There are now several algorithms in the Pipelines API which are not in the lower
* [Feature Extraction, Transformation, and Selection](ml-features.html)
* [Ensembles](ml-ensembles.html)

## Linear Methods with Elastic Net Regularization

[Elastic net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf) is a hybrid of L1 and L2 regularization. Mathematically, it is defined as a linear combination of the L1 norm and the squared L2 norm:
`\[
\alpha \lambda_1 \|v\|_1 + (1-\alpha) \frac{\lambda_2}{2} \|v\|_2^2, \quad \alpha \in [0, 1].
\]`
By setting `$\alpha$` properly, elastic net contains both L1 and L2 regularization as special cases: `$\alpha = 1$` gives pure L1 regularization, and `$\alpha = 0$` gives pure L2 regularization. We implement both linear regression and logistic regression with elastic net regularization.

**Examples**

<div class="codetabs">

<div data-lang="scala" markdown="1">
The following code snippet illustrates how to load a sample dataset, execute a
training algorithm on this training data using a static method in the algorithm
object, and make predictions with the resulting model to compute the training
error.

{% highlight scala %}
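// NOTE: this block is empty in the commit itself; the following is an
// illustrative sketch only. It assumes an existing SparkContext `sc` and the
// spark.ml `LogisticRegression` estimator, whose `regParam` sets the overall
// regularization strength and `elasticNetParam` sets the mixing parameter alpha.
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.util.MLUtils

// Load training data in LIBSVM format and convert it to a DataFrame.
val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()

// elasticNetParam = 0.8 mixes the penalty: 80% L1, 20% L2.
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

// Fit the model on the training data.
val lrModel = lr.fit(training)

// Inspect the fitted weights and intercept.
println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}")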
{% endhighlight %}

</div>

<div data-lang="java" markdown="1">
All of MLlib's methods use Java-friendly types, so you can import and call them there the same
way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the
Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
calling `.rdd()` on your `JavaRDD` object. A self-contained application example
that is equivalent to the provided example in Scala is given below:

{% highlight java %}
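// NOTE: this block is empty in the commit itself; the following is an
// illustrative sketch only, assuming the spark.ml LogisticRegression estimator.
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class LogisticRegressionWithElasticNet {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("Logistic Regression with Elastic Net");
    SparkContext sc = new SparkContext(conf);
    SQLContext sqlContext = new SQLContext(sc);

    // Load training data in LIBSVM format and convert it to a DataFrame.
    String path = "data/mllib/sample_libsvm_data.txt";
    DataFrame training = sqlContext.createDataFrame(
        MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class);

    // elasticNetParam = 0.8 mixes the penalty: 80% L1, 20% L2.
    LogisticRegression lr = new LogisticRegression()
        .setMaxIter(10)
        .setRegParam(0.3)
        .setElasticNetParam(0.8);

    // Fit the model on the training data.
    LogisticRegressionModel lrModel = lr.fit(training);

    // Print the fitted weights and intercept.
    System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept());
  }
}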
{% endhighlight %}
</div>

<div data-lang="python" markdown="1">
The following example shows how to load a sample dataset, build a logistic regression model,
and make predictions with the resulting model to compute the training error.

Note that the Python API does not yet support model save/load but will in the future.

{% highlight python %}
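# NOTE: this block is empty in the commit itself; the following is an
# illustrative sketch only, assuming an existing SparkContext `sc` and the
# pyspark.ml LogisticRegression estimator.
from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.util import MLUtils

# Load training data in LIBSVM format and convert it to a DataFrame.
training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()

# elasticNetParam=0.8 mixes the penalty: 80% L1, 20% L2.
lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

# Fit the model and inspect predictions on the training data.
lrModel = lr.fit(training)
lrModel.transform(training).select("label", "prediction").show(5)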
{% endhighlight %}

</div>

</div>

### Optimization

The optimization algorithm underlying the implementation is called [Orthant-Wise Limited-memory Quasi-Newton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net.
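
As a rough sketch of the idea (not part of this commit): the non-smooth L1 piece of the elastic net penalty is delegated to the OWL-QN solver, while the smooth L2 piece is folded into the differentiable objective. The snippet below assumes the `breeze.optimize.OWLQN` class with a `(maxIter, m, l1reg, tolerance)` constructor, which MLlib builds on:

{% highlight scala %}
import breeze.linalg.DenseVector
import breeze.optimize.{DiffFunction, OWLQN}

// Toy data: (features, label) pairs for a least-squares objective.
val data = Seq((DenseVector(1.0, 2.0), 1.0), (DenseVector(2.0, 0.5), 0.0))
val (alpha, lambda) = (0.8, 0.3) // elastic net mixing and overall strength

// Smooth part of the objective: average squared loss plus the L2 piece.
// OWL-QN itself handles the non-smooth L1 piece via its `l1reg` argument.
val smooth = new DiffFunction[DenseVector[Double]] {
  def calculate(w: DenseVector[Double]): (Double, DenseVector[Double]) = {
    var loss = 0.0
    val grad = DenseVector.zeros[Double](w.length)
    for ((x, y) <- data) {
      val err = (w dot x) - y
      loss += 0.5 * err * err
      grad += x * err
    }
    val l2 = (1 - alpha) * lambda
    (loss / data.size + 0.5 * l2 * (w dot w),
      grad / data.size.toDouble + w * l2)
  }
}

// The L1 strength per coordinate is alpha * lambda.
val owlqn = new OWLQN[Int, DenseVector[Double]](100, 10, (_: Int) => alpha * lambda, 1e-6)
val wOpt = owlqn.minimize(smooth, DenseVector.zeros[Double](2))
{% endhighlight %}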

# Code Examples

docs/mllib-linear-methods.md

Lines changed: 28 additions & 25 deletions
@@ -10,26 +10,26 @@ displayTitle: <a href="mllib-guide.html">MLlib</a> - Linear Methods

`\[
\newcommand{\R}{\mathbb{R}}
\newcommand{\E}{\mathbb{E}}
\newcommand{\x}{\mathbf{x}}
\newcommand{\y}{\mathbf{y}}
\newcommand{\wv}{\mathbf{w}}
\newcommand{\av}{\mathbf{\alpha}}
\newcommand{\bv}{\mathbf{b}}
\newcommand{\N}{\mathbb{N}}
\newcommand{\id}{\mathbf{I}}
\newcommand{\ind}{\mathbf{1}}
\newcommand{\0}{\mathbf{0}}
\newcommand{\unit}{\mathbf{e}}
\newcommand{\one}{\mathbf{1}}
\newcommand{\zero}{\mathbf{0}}
\]`

## Mathematical formulation

Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e.
the task of finding a minimizer of a convex function `$f$` that depends on a variable vector
`$\wv$` (called `weights` in the code), which has `$d$` entries.
Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where
the objective function is of the form
`\begin{equation}
@@ -39,7 +39,7 @@ the objective function is of the form
\ .
\end{equation}`
Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and
`$y_i\in\R$` are their corresponding labels, which we want to predict.
We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$.
Several of MLlib's classification and regression algorithms fall into this category,
and are discussed here.
@@ -99,6 +99,9 @@ regularizers in MLlib:
<tr>
<td>L1</td><td>$\|\wv\|_1$</td><td>$\mathrm{sign}(\wv)$</td>
</tr>
<tr>
<td>elastic net</td><td>$\alpha \lambda_1\|\wv\|_1 + (1-\alpha)\frac{\lambda_2}{2}\|\wv\|_2^2$</td><td>$\alpha \lambda_1 \mathrm{sign}(\wv) + (1-\alpha)\lambda_2 \wv$</td>
</tr>
</tbody>
</table>
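
As a quick illustration (not part of this commit), the elastic net row above can be computed directly; the helper below is hypothetical and uses Breeze vector types:

{% highlight scala %}
import breeze.linalg.DenseVector
import breeze.numerics.signum

// Sub-gradient of the elastic net regularizer
//   alpha * lambda1 * ||w||_1 + (1 - alpha) * (lambda2 / 2) * ||w||_2^2
// evaluated at w, matching the table entry above.
def elasticNetSubgradient(w: DenseVector[Double], alpha: Double,
                          lambda1: Double, lambda2: Double): DenseVector[Double] =
  signum(w) * (alpha * lambda1) + w * ((1 - alpha) * lambda2)
{% endhighlight %}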

@@ -107,7 +110,7 @@ of `$\wv$`.

L2-regularized problems are generally easier to solve than L1-regularized ones due to smoothness.
However, L1 regularization can help promote sparsity in weights, leading to smaller and more interpretable models, which can be useful for feature selection.
[Elastic net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf) is a combination of L1 and L2 regularization. It is not recommended to train models without any regularization,
especially when the number of training examples is small.

### Optimization
@@ -531,16 +534,16 @@ print("Training Error = " + str(trainErr))
### Linear least squares, Lasso, and ridge regression

Linear least squares is the most common formulation for regression problems.
It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss
function in the formulation given by the squared loss:
`\[
L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2.
\]`

Various related regression methods are derived by using different types of regularization:
[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or
[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses
no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2
regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1
regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is
@@ -552,7 +555,7 @@ known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_erro

<div data-lang="scala" markdown="1">
The following example demonstrates how to load training data and parse it as an RDD of LabeledPoint.
The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
values. We compute the mean squared error at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).

@@ -614,7 +617,7 @@ public class LinearRegression {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("Linear Regression Example");
    JavaSparkContext sc = new JavaSparkContext(conf);

    // Load and parse the data
    String path = "data/mllib/ridge-data/lpsa.data";
    JavaRDD<String> data = sc.textFile(path);
@@ -634,7 +637,7 @@ public class LinearRegression {

    // Building the model
    int numIterations = 100;
    final LinearRegressionModel model =
      LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations);

    // Evaluate model on training examples and compute training error
@@ -665,7 +668,7 @@ public class LinearRegression {

<div data-lang="python" markdown="1">
The following example demonstrates how to load training data and parse it as an RDD of LabeledPoint.
The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
values. We compute the mean squared error at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).

@@ -702,8 +705,8 @@ a dependency.

### Streaming linear regression

When data arrive in a streaming fashion, it is useful to fit regression models online,
updating the parameters of the model as new data arrives. MLlib currently supports
streaming linear regression using ordinary least squares. The fitting is similar
to that performed offline, except fitting occurs on each batch of data, so that
the model continually updates to reflect the data from the stream.
@@ -718,7 +721,7 @@ online to the first stream, and make predictions on the second stream.

<div data-lang="scala" markdown="1">

First, we import the necessary classes for parsing our input data and creating the model.

{% highlight scala %}

@@ -730,7 +733,7 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD

Then we make input streams for training and testing data. We assume a StreamingContext `ssc`
has already been created; see the [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing)
for more info. For this example, we use labeled points in training and testing streams,
but in practice you will likely want to use unlabeled vectors for test data.

{% highlight scala %}
@@ -750,7 +753,7 @@ val model = new StreamingLinearRegressionWithSGD()

{% endhighlight %}

Now we register the streams for training and testing and start the job.
Printing predictions alongside true labels lets us easily see the result.

{% highlight scala %}
@@ -760,14 +763,14 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

ssc.start()
ssc.awaitTermination()

{% endhighlight %}

We can now save text files with data to the training or testing folders.
Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label
and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir`
the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions.
As you feed more data to the training directory, the predictions
will get better!
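
For instance, a file dropped into the training directory might contain lines like the following (hypothetical values):

{% highlight text %}
(1.0,[2.0,0.5,3.3])
(0.0,[1.5,2.2,0.1])
{% endhighlight %}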

</div>
