
Commit 4b736db

[SPARK-3530][MLLIB] pipeline and parameters with examples
This PR adds package "org.apache.spark.ml" with pipeline and parameters, as discussed on the JIRA. This is joint work with jkbradley, etrain, shivaram, and many others who helped on the design, with additional help from marmbrus and liancheng on the Spark SQL side. The design doc can be found at:
https://docs.google.com/document/d/1rVwXRjWKfIb-7PI6b86ipytwbUH7irSNLF1_6dLmh8o/edit?usp=sharing

**org.apache.spark.ml**

This is a new package with a new set of ML APIs that addresses practical machine learning pipelines. (Sorry for taking so long!) It will be an alpha component, so this is definitely not something set in stone. The new set of APIs, inspired by the MLI project from AMPLab and by scikit-learn, leverages Spark SQL's schema support and execution plan optimization. It introduces the following components that help build a practical pipeline:

1. Transformer, which transforms a dataset into another dataset
2. Estimator, which fits models to data, where models are transformers
3. Evaluator, which evaluates model output and returns a scalar metric
4. Pipeline, a simple pipeline that consists of transformers and estimators

Parameters can be supplied at fit/transform time or embedded with components:

1. Param: a strongly typed parameter key with self-contained doc
2. ParamMap: a param -> value map
3. Params: trait for components with parameters

For any component that implements `Params`, users can easily check the docs by calling `explainParams`:

~~~
> val lr = new LogisticRegression
> lr.explainParams
maxIter: max number of iterations (default: 100)
regParam: regularization constant (default: 0.1)
labelCol: label column name (default: label)
featuresCol: features column name (default: features)
~~~

or check an individual param:

~~~
> lr.maxIter
maxIter: max number of iterations (default: 100)
~~~

**Please start with the example code in the test suites and under `org.apache.spark.examples.ml`, where I put several examples:**

1. Run a simple logistic regression job:

~~~
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(1.0)
val model = lr.fit(dataset)
model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
  .select('label, 'score, 'prediction).collect()
  .foreach(println)
~~~

2. Run logistic regression with cross-validation and grid search, using areaUnderROC (the default) as the metric:

~~~
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 100.0))
  .addGrid(lr.maxIter, Array(0, 5))
  .build()
val eval = new BinaryClassificationEvaluator
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(lrParamMaps)
  .setEvaluator(eval)
  .setNumFolds(3)
val bestModel = cv.fit(dataset)
~~~

3. Run a pipeline that consists of a standard scaler and a logistic regression component:

~~~
val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
val lr = new LogisticRegression()
  .setFeaturesCol(scaler.getOutputCol)
val pipeline = new Pipeline()
  .setStages(Array(scaler, lr))
val model = pipeline.fit(dataset)
model.transform(dataset)
  .select('label, 'score, 'prediction)
  .collect()
  .foreach(println)
~~~

4. A simple text classification pipeline that recognizes "spark":

~~~
val training = sparkContext.parallelize(Seq(
  LabeledDocument(0L, "a b c d e spark", 1.0),
  LabeledDocument(1L, "b d", 0.0),
  LabeledDocument(2L, "spark f g h", 1.0),
  LabeledDocument(3L, "hadoop mapreduce", 0.0)))
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))
val model = pipeline.fit(training)
val test = sparkContext.parallelize(Seq(
  Document(4L, "spark i j k"),
  Document(5L, "l m"),
  Document(6L, "mapreduce spark"),
  Document(7L, "apache hadoop")))
model.transform(test)
  .select('id, 'text, 'prediction, 'score)
  .collect()
  .foreach(println)
~~~

Java examples are very similar. I put example code that creates a simple text classification pipeline in Scala and Java, where a simple tokenizer is defined as a transformer outside `org.apache.spark.ml`.

**What is missing now and will be added soon:**

1. ~~Runtime check of schemas, so before we touch the data we go through the schema and make sure column names and types match the input parameters.~~
2. ~~Java examples.~~
3. ~~Store training parameters in trained models.~~
4. (later) Serialization and Python API.

Author: Xiangrui Meng <[email protected]>

Closes #3099 from mengxr/SPARK-3530 and squashes the following commits:

2cc93fd [Xiangrui Meng] hide APIs as much as I can
34319ba [Xiangrui Meng] use local instead local[2] for unit tests
2524251 [Xiangrui Meng] rename PipelineStage.transform to transformSchema
c9daab4 [Xiangrui Meng] remove mockito version
1397ab5 [Xiangrui Meng] use sqlContext from LocalSparkContext instead of TestSQLContext
6ffc389 [Xiangrui Meng] try to fix unit test
a59d8b7 [Xiangrui Meng] doc updates
977fd9d [Xiangrui Meng] add scala ml package object
6d97fe6 [Xiangrui Meng] add AlphaComponent annotation
731f0e4 [Xiangrui Meng] update package doc
0435076 [Xiangrui Meng] remove ;this from setters
fa21d9b [Xiangrui Meng] update extends indentation
f1091b3 [Xiangrui Meng] typo
228a9f4 [Xiangrui Meng] do not persist before calling binary classification metrics
f51cd27 [Xiangrui Meng] rename default to defaultValue
b3be094 [Xiangrui Meng] refactor schema transform in lr
8791e8e [Xiangrui Meng] rename copyValues to inheritValues and make it do the right thing
51f1c06 [Xiangrui Meng] remove leftover code in Transformer
494b632 [Xiangrui Meng] compute score once
ad678e9 [Xiangrui Meng] more doc for Transformer
4306ed4 [Xiangrui Meng] org imports in text pipeline
6e7c1c7 [Xiangrui Meng] update pipeline
4f9e34f [Xiangrui Meng] more doc for pipeline
aa5dbd4 [Xiangrui Meng] fix typo
11be383 [Xiangrui Meng] fix unit tests
3df7952 [Xiangrui Meng] clean up
986593e [Xiangrui Meng] re-org java test suites
2b11211 [Xiangrui Meng] remove external data deps
9fd4933 [Xiangrui Meng] add unit test for pipeline
2a0df46 [Xiangrui Meng] update tests
2d52e4d [Xiangrui Meng] add @AlphaComponent to package-info
27582a4 [Xiangrui Meng] doc changes
73a000b [Xiangrui Meng] add schema transformation layer
6736e87 [Xiangrui Meng] more doc / remove HasMetricName trait
80a8b5e [Xiangrui Meng] rename SimpleTransformer to UnaryTransformer
62ca2bb [Xiangrui Meng] check param parent in set/get
1622349 [Xiangrui Meng] add getModel to PipelineModel
a0e0054 [Xiangrui Meng] update StandardScaler to use SimpleTransformer
d0faa04 [Xiangrui Meng] remove implicit mapping from ParamMap
c7f6921 [Xiangrui Meng] move ParamGridBuilder test to ParamGridBuilderSuite
e246f29 [Xiangrui Meng] re-org:
7772430 [Xiangrui Meng] remove modelParams add a simple text classification pipeline
b95c408 [Xiangrui Meng] remove implicits add unit tests to params
bab3e5b [Xiangrui Meng] update params
fe0ee92 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-3530
6e86d98 [Xiangrui Meng] some code clean-up
2d040b3 [Xiangrui Meng] implement setters inside each class, add Params.copyValues [ci skip]
fd751fc [Xiangrui Meng] add java-friendly versions of fit and transform
3f810cd [Xiangrui Meng] use multi-model training api in cv
5b8f413 [Xiangrui Meng] rename model to modelParams
9d2d35d [Xiangrui Meng] test varargs and chain model params
f46e927 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-3530
1ef26e0 [Xiangrui Meng] specialize methods/types for Java
df293ed [Xiangrui Meng] switch to setter/getter
376db0a [Xiangrui Meng] pipeline and parameters
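The Param/ParamMap APIs described above can also be exercised directly at fit time. Below is a minimal sketch rather than code taken from this commit; `dataset` stands for any SchemaRDD of training data, and the parameter values are illustrative only.

~~~
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

val lr = new LogisticRegression()
  .setMaxIter(10)  // embedded param, set via a setter
// Params supplied at fit time overwrite the embedded ones.
// `param -> value` builds a ParamPair, as in the threshold example above.
val paramMap = new ParamMap()
  .put(lr.maxIter -> 20)
  .put(lr.regParam -> 0.01)
val model = lr.fit(dataset, paramMap)  // `dataset` is a placeholder SchemaRDD
~~~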
1 parent 84324fb commit 4b736db

33 files changed (+2425 / -16 lines)
examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

import java.util.List;

import com.google.common.collect.Lists;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.Row;
import org.apache.spark.SparkConf;

/**
 * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java
 * bean classes {@link LabeledDocument} and {@link Document} defined in the Scala counterpart of
 * this example {@link SimpleTextClassificationPipeline}. Run with
 * <pre>
 *   bin/run-example ml.JavaSimpleTextClassificationPipeline
 * </pre>
 */
public class JavaSimpleTextClassificationPipeline {

  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    JavaSQLContext jsql = new JavaSQLContext(jsc);

    // Prepare training documents, which are labeled.
    List<LabeledDocument> localTraining = Lists.newArrayList(
      new LabeledDocument(0L, "a b c d e spark", 1.0),
      new LabeledDocument(1L, "b d", 0.0),
      new LabeledDocument(2L, "spark f g h", 1.0),
      new LabeledDocument(3L, "hadoop mapreduce", 0.0));
    JavaSchemaRDD training =
      jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    Tokenizer tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words");
    HashingTF hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol())
      .setOutputCol("features");
    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01);
    Pipeline pipeline = new Pipeline()
      .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

    // Fit the pipeline to training documents.
    PipelineModel model = pipeline.fit(training);

    // Prepare test documents, which are unlabeled.
    List<Document> localTest = Lists.newArrayList(
      new Document(4L, "spark i j k"),
      new Document(5L, "l m n"),
      new Document(6L, "mapreduce spark"),
      new Document(7L, "apache hadoop"));
    JavaSchemaRDD test =
      jsql.applySchema(jsc.parallelize(localTest), Document.class);

    // Make predictions on test documents.
    model.transform(test).registerAsTable("prediction");
    JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
    for (Row r: predictions.collect()) {
      System.out.println(r);
    }
  }
}
examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml

import scala.beans.BeanInfo

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.SQLContext

@BeanInfo
case class LabeledDocument(id: Long, text: String, label: Double)

@BeanInfo
case class Document(id: Long, text: String)

/**
 * A simple text classification pipeline that recognizes "spark" from input text. This is to show
 * how to create and configure an ML pipeline. Run with
 * {{{
 * bin/run-example ml.SimpleTextClassificationPipeline
 * }}}
 */
object SimpleTextClassificationPipeline {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext._

    // Prepare training documents, which are labeled.
    val training = sparkContext.parallelize(Seq(
      LabeledDocument(0L, "a b c d e spark", 1.0),
      LabeledDocument(1L, "b d", 0.0),
      LabeledDocument(2L, "spark f g h", 1.0),
      LabeledDocument(3L, "hadoop mapreduce", 0.0)))

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // Fit the pipeline to training documents.
    val model = pipeline.fit(training)

    // Prepare test documents, which are unlabeled.
    val test = sparkContext.parallelize(Seq(
      Document(4L, "spark i j k"),
      Document(5L, "l m n"),
      Document(6L, "mapreduce spark"),
      Document(7L, "apache hadoop")))

    // Make predictions on test documents.
    model.transform(test)
      .select('id, 'text, 'score, 'prediction)
      .collect()
      .foreach(println)
  }
}

mllib/pom.xml

Lines changed: 5 additions & 0 deletions
@@ -100,6 +100,11 @@
       <artifactId>junit-interface</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-streaming_${scala.binary.version}</artifactId>
mllib/src/main/scala/org/apache/spark/ml/Estimator.scala

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import scala.annotation.varargs
import scala.collection.JavaConverters._

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.api.java.JavaSchemaRDD

/**
 * :: AlphaComponent ::
 * Abstract class for estimators that fit models to data.
 */
@AlphaComponent
abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {

  /**
   * Fits a single model to the input data with optional parameters.
   *
   * @param dataset input dataset
   * @param paramPairs optional list of param pairs (overwrite embedded params)
   * @return fitted model
   */
  @varargs
  def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = {
    val map = new ParamMap().put(paramPairs: _*)
    fit(dataset, map)
  }

  /**
   * Fits a single model to the input data with provided parameter map.
   *
   * @param dataset input dataset
   * @param paramMap parameter map
   * @return fitted model
   */
  def fit(dataset: SchemaRDD, paramMap: ParamMap): M

  /**
   * Fits multiple models to the input data with multiple sets of parameters.
   * The default implementation uses a for loop on each parameter map.
   * Subclasses could overwrite this to optimize multi-model training.
   *
   * @param dataset input dataset
   * @param paramMaps an array of parameter maps
   * @return fitted models, matching the input parameter maps
   */
  def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
    paramMaps.map(fit(dataset, _))
  }

  // Java-friendly versions of fit.

  /**
   * Fits a single model to the input data with optional parameters.
   *
   * @param dataset input dataset
   * @param paramPairs optional list of param pairs (overwrite embedded params)
   * @return fitted model
   */
  @varargs
  def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = {
    fit(dataset.schemaRDD, paramPairs: _*)
  }

  /**
   * Fits a single model to the input data with provided parameter map.
   *
   * @param dataset input dataset
   * @param paramMap parameter map
   * @return fitted model
   */
  def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = {
    fit(dataset.schemaRDD, paramMap)
  }

  /**
   * Fits multiple models to the input data with multiple sets of parameters.
   *
   * @param dataset input dataset
   * @param paramMaps an array of parameter maps
   * @return fitted models, matching the input parameter maps
   */
  def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
    fit(dataset.schemaRDD, paramMaps).asJava
  }
}
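To make the multi-model `fit` overload above concrete, here is a short usage sketch that is not part of this commit; `dataset` is a placeholder SchemaRDD, the grid values are illustrative, and the `ml.tuning` import path is assumed from the new package layout.

~~~
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.tuning.ParamGridBuilder

val lr = new LogisticRegression()
val paramMaps = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1, 1.0))
  .build()
// Calls Estimator.fit(dataset, Array[ParamMap]); the default implementation
// simply loops over the maps, but an estimator may override it to train
// multiple models more efficiently (as CrossValidator relies on).
val models = lr.fit(dataset, paramMaps)
models.zip(paramMaps).foreach { case (m, map) => println(s"$map -> $m") }
~~~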
mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

/**
 * :: AlphaComponent ::
 * Abstract class for evaluators that compute metrics from predictions.
 */
@AlphaComponent
abstract class Evaluator extends Identifiable {

  /**
   * Evaluates the output.
   *
   * @param dataset a dataset that contains labels/observations and predictions.
   * @param paramMap parameter map that specifies the input columns and output metrics
   * @return metric
   */
  def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double
}
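As a usage sketch that is not taken from this commit, a concrete evaluator such as the BinaryClassificationEvaluator mentioned in the commit message can be called directly; `predictions` is a placeholder SchemaRDD assumed to carry the default "label" and "score" columns, and the `ml.evaluation` import path is assumed from the new package layout.

~~~
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.param.ParamMap

val eval = new BinaryClassificationEvaluator  // areaUnderROC by default
// An empty ParamMap keeps the evaluator's embedded defaults for the
// label/score column names and the metric.
val metric: Double = eval.evaluate(predictions, new ParamMap())
~~~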
mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import java.util.UUID

/**
 * Object with a unique id.
 */
private[ml] trait Identifiable extends Serializable {

  /**
   * A unique id for the object. The default implementation concatenates the class name, "-", and 8
   * random hex chars.
   */
  private[ml] val uid: String =
    this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
}
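A standalone illustration, not part of the commit, of the uid scheme above; since the trait is private[ml], this just reproduces the naming convention with a hard-coded class name.

~~~
import java.util.UUID

// Mimics Identifiable's default uid: simple class name + "-" + 8 random hex chars
// (the first 8 characters of a UUID string are hex digits).
val uid = "LogisticRegression" + "-" + UUID.randomUUID().toString.take(8)
// e.g. "LogisticRegression-3f2c9a1b"
~~~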
