
Commit 41ad9b1

Added examples for spark.ml: SimpleParamsExample + Java version, CrossValidatorExample + Java version. CrossValidatorExample not working yet. Added programming guide for spark.ml, but need to add CrossValidatorExample to it once CrossValidatorExample works.
1 parent 2b233f5 commit 41ad9b1

File tree

12 files changed: +990 -6 lines


docs/img/ml-Pipeline.png (72.3 KB)

docs/img/ml-PipelineModel.png (74.2 KB)

docs/img/ml-Pipelines.pptx (55.4 KB, binary file not shown)

docs/ml-guide.md

Lines changed: 524 additions & 0 deletions
Large diffs are not rendered by default.

docs/mllib-guide.md

Lines changed: 12 additions & 1 deletion
@@ -1,6 +1,6 @@
 ---
 layout: global
-title: Machine Learning Library (MLlib)
+title: Machine Learning Library (MLlib) Programming Guide
 ---
 
 MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities,
@@ -34,6 +34,17 @@ MLlib is under active development.
 The APIs marked `Experimental`/`DeveloperApi` may change in future releases,
 and the migration guide below will explain all changes between releases.
 
+# spark.ml: The New ML Package
+
+Spark 1.2 includes a new machine learning package called `spark.ml`, currently an alpha component but potentially a successor to `spark.mllib`. The `spark.ml` package aims to replace the old APIs with a cleaner, more uniform set of APIs which will help users create full machine learning pipelines.
+
+See the **[spark.ml programming guide](ml-guide.html)** for more information on this package.
+
+Users can use algorithms from either of the two packages, but APIs may differ. Currently, `spark.ml` offers a subset of the algorithms from `spark.mllib`.
+
+Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`.
+See the `spark.ml` programming guide linked above for more details.
+
 # Dependencies
 
 MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/),
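
As a quick orientation before the full examples added in this commit, here is a minimal sketch of the Pipeline concept described in the guide section above. It is written in Scala against the Spark 1.2-era spark.ml and SchemaRDD APIs used by this commit's CrossValidatorExample, and it assumes `training` and `test` SchemaRDDs of LabeledDocument/Document rows prepared as in that example:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// Chain two Transformers (Tokenizer, HashingTF) and an Estimator (LogisticRegression)
// into a single Pipeline, which is itself an Estimator.
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol(tokenizer.getOutputCol).setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

// Fitting the Pipeline on labeled training data produces a PipelineModel (a Transformer),
// which can then transform unlabeled test data, appending score and prediction columns.
val model = pipeline.fit(training)  // training: SchemaRDD, assumed prepared as in CrossValidatorExample
model.transform(test).select('id, 'text, 'score, 'prediction)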

examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

import java.util.ArrayList;
import java.util.List;

import com.google.common.collect.Lists;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.Row;

/**
 * A simple example demonstrating model selection using CrossValidator.
 * This example also demonstrates how Pipelines are Estimators.
 *
 * This example uses the Java bean classes {@link org.apache.spark.examples.ml.LabeledDocument} and
 * {@link org.apache.spark.examples.ml.Document} defined in the Scala example
 * {@link org.apache.spark.examples.ml.SimpleTextClassificationPipeline}.
 *
 * Run with
 * <pre>
 *   bin/run-example ml.JavaCrossValidatorExample
 * </pre>
 */
public class JavaCrossValidatorExample {

  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    JavaSQLContext jsql = new JavaSQLContext(jsc);

    // Prepare training documents, which are labeled.
    List<LabeledDocument> localTraining = Lists.newArrayList(
      new LabeledDocument(0L, "a b c d e spark", 1.0),
      new LabeledDocument(1L, "b d", 0.0),
      new LabeledDocument(2L, "spark f g h", 1.0),
      new LabeledDocument(3L, "hadoop mapreduce", 0.0));
    JavaSchemaRDD training =
      jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    Tokenizer tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words");
    HashingTF hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol())
      .setOutputCol("features");
    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01);
    Pipeline pipeline = new Pipeline()
      .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

    // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
    // This will allow us to jointly choose parameters for all Pipeline stages.
    // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    CrossValidator crossval = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator());
    // We use a ParamGridBuilder to construct a grid of parameters to search over.
    // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
    // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
    ParamMap[] paramGrid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000})
      .addGrid(lr.regParam(), new double[]{0.1, 0.01})
      .build();
    crossval.setEstimatorParamMaps(paramGrid);
    crossval.setNumFolds(2);

    // Run cross-validation, and choose the best set of parameters.
    CrossValidatorModel cvModel = crossval.fit(training);
    // Get the best LogisticRegression model (with the best set of parameters from paramGrid).
    Model lrModel = cvModel.bestModel();

    // Prepare test documents, which are unlabeled.
    List<Document> localTest = Lists.newArrayList(
      new Document(4L, "spark i j k"),
      new Document(5L, "l m n"),
      new Document(6L, "mapreduce spark"),
      new Document(7L, "apache hadoop"));
    JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);

    // Make predictions on test documents.
    lrModel.transform(test).registerAsTable("prediction");
    JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
    for (Row r: predictions.collect()) {
      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
        + ", prediction=" + r.get(3));
    }
  }
}

examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

import java.util.List;

import com.google.common.collect.Lists;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.Row;

/**
 * A simple example demonstrating ways to specify parameters for Estimators and Transformers.
 * Run with
 * {{{
 *   bin/run-example ml.JavaSimpleParamsExample
 * }}}
 */
public class JavaSimpleParamsExample {

  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    JavaSQLContext jsql = new JavaSQLContext(jsc);

    // Prepare training data.
    // We use LabeledPoint, which is a case class.  Spark SQL can convert RDDs of case classes
    // into SchemaRDDs, where it uses the case class metadata to infer the schema.
    List<LabeledPoint> localTraining = Lists.newArrayList(
      new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
      new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
      new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
      new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
    JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);

    // Create a LogisticRegression instance.  This instance is an Estimator.
    LogisticRegression lr = new LogisticRegression();
    // Print out the parameters, documentation, and any default values.
    System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n");

    // We may set parameters using setter methods.
    lr.setMaxIter(10)
      .setRegParam(0.01);

    // Learn a LogisticRegression model.  This uses the parameters stored in lr.
    LogisticRegressionModel model1 = lr.fit(training);
    // Since model1 is a Model (i.e., a Transformer produced by an Estimator),
    // we can view the parameters it used during fit().
    // This prints the parameter (name: value) pairs, where names are unique IDs for this
    // LogisticRegression instance.
    System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap());

    // We may alternatively specify parameters using a ParamMap.
    ParamMap paramMap = new ParamMap();
    paramMap.put(lr.maxIter(), 20); // Specify 1 Param.
    paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
    paramMap.put(lr.regParam(), 0.1);

    // One can also combine ParamMaps.
    ParamMap paramMap2 = new ParamMap();
    paramMap2.put(lr.scoreCol(), "probability"); // Changes output column name.
    ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);

    // Now learn a new model using the paramMapCombined parameters.
    // paramMapCombined overrides all parameters set earlier via lr.set* methods.
    LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
    System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap());

    // Prepare test documents.
    List<LabeledPoint> localTest = Lists.newArrayList(
      new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
      new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
      new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
    JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);

    // Make predictions on test documents using the Transformer.transform() method.
    // LogisticRegression.transform will only use the 'features' column.
    // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
    // column since we renamed the lr.scoreCol parameter previously.
    model2.transform(test).registerAsTable("results");
    JavaSchemaRDD results =
      jsql.sql("SELECT features, label, probability, prediction FROM results");
    for (Row r: results.collect()) {
      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
        + ", prediction=" + r.get(3));
    }
  }
}

examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java

Lines changed: 3 additions & 3 deletions
@@ -80,14 +80,14 @@ public static void main(String[] args) {
       new Document(5L, "l m n"),
       new Document(6L, "mapreduce spark"),
       new Document(7L, "apache hadoop"));
-    JavaSchemaRDD test =
-      jsql.applySchema(jsc.parallelize(localTest), Document.class);
+    JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
 
     // Make predictions on test documents.
     model.transform(test).registerAsTable("prediction");
     JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
     for (Row r: predictions.collect()) {
-      System.out.println(r);
+      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+        + ", prediction=" + r.get(3));
     }
   }
 }

examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.sql.{Row, SQLContext}

/**
 * A simple example demonstrating model selection using CrossValidator.
 * This example also demonstrates how Pipelines are Estimators.
 *
 * This example uses the [[LabeledDocument]] and [[Document]] case classes from
 * [[SimpleTextClassificationPipeline]].
 *
 * Run with
 * {{{
 *   bin/run-example ml.CrossValidatorExample
 * }}}
 */
object CrossValidatorExample {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("CrossValidatorExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext._

    // Prepare training documents, which are labeled.
    val training = sparkContext.parallelize(Seq(
      LabeledDocument(0L, "a b c d e spark", 1.0),
      LabeledDocument(1L, "b d", 0.0),
      LabeledDocument(2L, "spark f g h", 1.0),
      LabeledDocument(3L, "hadoop mapreduce", 0.0)))

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
    // This will allow us to jointly choose parameters for all Pipeline stages.
    // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    val crossval = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator)
    // We use a ParamGridBuilder to construct a grid of parameters to search over.
    // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
    // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
    val paramGrid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()
    crossval.setEstimatorParamMaps(paramGrid)
    crossval.setNumFolds(2)

    // Run cross-validation, and choose the best set of parameters.
    val cvModel = try {
      crossval.fit(training)
    } catch {
      case e: Exception =>
        println("\nSTACK TRACE\n")
        println(e.getStackTraceString)
        println("\nSTACK TRACE OF CAUSE\n")
        println(e.getCause.getStackTraceString)
        throw e
    }
    // Get the best LogisticRegression model (with the best set of parameters from paramGrid).
    val lrModel = cvModel.bestModel

    // Prepare test documents, which are unlabeled.
    val test = sparkContext.parallelize(Seq(
      Document(4L, "spark i j k"),
      Document(5L, "l m n"),
      Document(6L, "mapreduce spark"),
      Document(7L, "apache hadoop")))

    // Make predictions on test documents using the best LogisticRegression model.
    lrModel.transform(test)
      .select('id, 'text, 'score, 'prediction)
      .collect()
      .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
        println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
      }
  }
}

0 commit comments
