[SPARK-3530][MLLIB] pipeline and parameters with examples #3099

Closed · wants to merge 55 commits

Changes from all commits · 55 commits

376db0a
pipeline and parameters
mengxr Nov 5, 2014
df293ed
switch to setter/getter
mengxr Nov 5, 2014
1ef26e0
specialize methods/types for Java
mengxr Nov 5, 2014
f46e927
Merge remote-tracking branch 'apache/master' into SPARK-3530
mengxr Nov 6, 2014
9d2d35d
test varargs and chain model params
mengxr Nov 6, 2014
5b8f413
rename model to modelParams
mengxr Nov 6, 2014
3f810cd
use multi-model training api in cv
mengxr Nov 6, 2014
fd751fc
add java-friendly versions of fit and transform
mengxr Nov 6, 2014
2d040b3
implement setters inside each class, add Params.copyValues [ci skip]
mengxr Nov 6, 2014
6e86d98
some code clean-up
mengxr Nov 6, 2014
fe0ee92
Merge remote-tracking branch 'apache/master' into SPARK-3530
mengxr Nov 7, 2014
bab3e5b
update params
mengxr Nov 7, 2014
b95c408
remove implicits
mengxr Nov 7, 2014
7772430
remove modelParams
mengxr Nov 7, 2014
e246f29
re-org:
mengxr Nov 9, 2014
c7f6921
move ParamGridBuilder test to ParamGridBuilderSuite
mengxr Nov 9, 2014
d0faa04
remove implicit mapping from ParamMap
mengxr Nov 9, 2014
a0e0054
update StandardScaler to use SimpleTransformer
mengxr Nov 9, 2014
1622349
add getModel to PipelineModel
mengxr Nov 9, 2014
62ca2bb
check param parent in set/get
mengxr Nov 9, 2014
80a8b5e
rename SimpleTransformer to UnaryTransformer
mengxr Nov 10, 2014
6736e87
more doc / remove HasMetricName trait
mengxr Nov 10, 2014
73a000b
add schema transformation layer
mengxr Nov 10, 2014
27582a4
doc changes
mengxr Nov 10, 2014
2d52e4d
add @AlphaComponent to package-info
mengxr Nov 10, 2014
2a0df46
update tests
mengxr Nov 10, 2014
9fd4933
add unit test for pipeline
mengxr Nov 10, 2014
2b11211
remove external data deps
mengxr Nov 10, 2014
986593e
re-org java test suites
mengxr Nov 10, 2014
3df7952
clean up
mengxr Nov 10, 2014
11be383
fix unit tests
mengxr Nov 10, 2014
aa5dbd4
fix typo
mengxr Nov 10, 2014
4f9e34f
more doc for pipeline
mengxr Nov 10, 2014
6e7c1c7
update pipeline
mengxr Nov 11, 2014
4306ed4
org imports in text pipeline
mengxr Nov 11, 2014
ad678e9
more doc for Transformer
mengxr Nov 11, 2014
494b632
compute score once
mengxr Nov 11, 2014
51f1c06
remove leftover code in Transformer
mengxr Nov 11, 2014
8791e8e
rename copyValues to inheritValues and make it do the right thing
mengxr Nov 11, 2014
b3be094
refactor schema transform in lr
mengxr Nov 11, 2014
f51cd27
rename default to defaultValue
mengxr Nov 11, 2014
228a9f4
do not persist before calling binary classification metrics
mengxr Nov 11, 2014
f1091b3
typo
mengxr Nov 11, 2014
fa21d9b
update extends indentation
mengxr Nov 11, 2014
0435076
remove ;this from setters
mengxr Nov 11, 2014
731f0e4
update package doc
mengxr Nov 11, 2014
6d97fe6
add AlphaComponent annotation
mengxr Nov 11, 2014
977fd9d
add scala ml package object
mengxr Nov 11, 2014
a59d8b7
doc updates
mengxr Nov 11, 2014
6ffc389
try to fix unit test
mengxr Nov 11, 2014
1397ab5
use sqlContext from LocalSparkContext instead of TestSQLContext
mengxr Nov 11, 2014
c9daab4
remove mockito version
mengxr Nov 11, 2014
2524251
rename PipelineStage.transform to transformSchema
mengxr Nov 11, 2014
34319ba
use local instead of local[2] for unit tests
mengxr Nov 12, 2014
2cc93fd
hide APIs as much as I can
mengxr Nov 12, 2014
93 changes: 93 additions & 0 deletions examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml;

import java.util.List;

import com.google.common.collect.Lists;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.Row;
import org.apache.spark.SparkConf;

/**
* A simple text classification pipeline that recognizes "spark" from input text. It uses the Java
* bean classes {@link LabeledDocument} and {@link Document} defined in the Scala counterpart of
* this example {@link SimpleTextClassificationPipeline}. Run with
* <pre>
* bin/run-example ml.JavaSimpleTextClassificationPipeline
* </pre>
*/
public class JavaSimpleTextClassificationPipeline {

  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    JavaSQLContext jsql = new JavaSQLContext(jsc);

    // Prepare training documents, which are labeled.
    List<LabeledDocument> localTraining = Lists.newArrayList(
      new LabeledDocument(0L, "a b c d e spark", 1.0),
      new LabeledDocument(1L, "b d", 0.0),
      new LabeledDocument(2L, "spark f g h", 1.0),
      new LabeledDocument(3L, "hadoop mapreduce", 0.0));
    JavaSchemaRDD training =
      jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    Tokenizer tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words");
    HashingTF hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol())
      .setOutputCol("features");
    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01);
    Pipeline pipeline = new Pipeline()
      .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

    // Fit the pipeline to training documents.
    PipelineModel model = pipeline.fit(training);

    // Prepare test documents, which are unlabeled.
    List<Document> localTest = Lists.newArrayList(
      new Document(4L, "spark i j k"),
      new Document(5L, "l m n"),
      new Document(6L, "mapreduce spark"),
      new Document(7L, "apache hadoop"));
    JavaSchemaRDD test =
      jsql.applySchema(jsc.parallelize(localTest), Document.class);

    // Make predictions on test documents.
    model.transform(test).registerAsTable("prediction");
    JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
    for (Row r : predictions.collect()) {
      System.out.println(r);
    }
  }
}
86 changes: 86 additions & 0 deletions examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml

import scala.beans.BeanInfo

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.SQLContext

@BeanInfo
case class LabeledDocument(id: Long, text: String, label: Double)

@BeanInfo
case class Document(id: Long, text: String)

/**
* A simple text classification pipeline that recognizes "spark" from input text. This is to show
* how to create and configure an ML pipeline. Run with
* {{{
* bin/run-example ml.SimpleTextClassificationPipeline
* }}}
*/
object SimpleTextClassificationPipeline {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext._

    // Prepare training documents, which are labeled.
    val training = sparkContext.parallelize(Seq(
      LabeledDocument(0L, "a b c d e spark", 1.0),
      LabeledDocument(1L, "b d", 0.0),
      LabeledDocument(2L, "spark f g h", 1.0),
      LabeledDocument(3L, "hadoop mapreduce", 0.0)))

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // Fit the pipeline to training documents.
    val model = pipeline.fit(training)

    // Prepare test documents, which are unlabeled.
    val test = sparkContext.parallelize(Seq(
      Document(4L, "spark i j k"),
      Document(5L, "l m n"),
      Document(6L, "mapreduce spark"),
      Document(7L, "apache hadoop")))

    // Make predictions on test documents.
    model.transform(test)
      .select('id, 'text, 'score, 'prediction)
      .collect()
      .foreach(println)
  }
}
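
The Estimator API added in this PR also lets callers override a stage's embedded params for a single fit call. A minimal sketch, building on the pipeline, lr, and training values above; it assumes ParamMap.put(param, value) chains the same way the varargs overload in Estimator does:

import org.apache.spark.ml.param.ParamMap

// Sketch (assumed API): override lr's embedded maxIter and regParam for this
// call only; the values set on lr itself (10 and 0.01) are untouched.
val overrides = new ParamMap()
  .put(lr.maxIter, 30)
  .put(lr.regParam, 0.1)
val tunedModel = pipeline.fit(training, overrides)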
5 changes: 5 additions & 0 deletions mllib/pom.xml
@@ -100,6 +100,11 @@
      <artifactId>junit-interface</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.mockito</groupId>
      <artifactId>mockito-all</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
105 changes: 105 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml

import scala.annotation.varargs
import scala.collection.JavaConverters._

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.api.java.JavaSchemaRDD

/**
* :: AlphaComponent ::
* Abstract class for estimators that fit models to data.
*/
@AlphaComponent
abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {

  /**
   * Fits a single model to the input data with optional parameters.
   *
   * @param dataset input dataset
   * @param paramPairs optional list of param pairs (these override embedded params)
   * @return fitted model
   */
  @varargs
  def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = {
    val map = new ParamMap().put(paramPairs: _*)
    fit(dataset, map)
  }

  /**
   * Fits a single model to the input data with the provided parameter map.
   *
   * @param dataset input dataset
   * @param paramMap parameter map
   * @return fitted model
   */
  def fit(dataset: SchemaRDD, paramMap: ParamMap): M

  /**
   * Fits multiple models to the input data with multiple sets of parameters.
   * The default implementation uses a for loop on each parameter map.
   * Subclasses could override this to optimize multi-model training.
   *
   * @param dataset input dataset
   * @param paramMaps an array of parameter maps
   * @return fitted models, matching the input parameter maps
   */
  def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
    paramMaps.map(fit(dataset, _))
[Review comment from a Contributor on the line above] .toArray ?
  }

  // Java-friendly versions of fit.

  /**
   * Fits a single model to the input data with optional parameters.
   *
   * @param dataset input dataset
   * @param paramPairs optional list of param pairs (these override embedded params)
   * @return fitted model
   */
  @varargs
  def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = {
    fit(dataset.schemaRDD, paramPairs: _*)
  }

  /**
   * Fits a single model to the input data with the provided parameter map.
   *
   * @param dataset input dataset
   * @param paramMap parameter map
   * @return fitted model
   */
  def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = {
    fit(dataset.schemaRDD, paramMap)
  }

  /**
   * Fits multiple models to the input data with multiple sets of parameters.
   *
   * @param dataset input dataset
   * @param paramMaps an array of parameter maps
   * @return fitted models, matching the input parameter maps
   */
  def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
    fit(dataset.schemaRDD, paramMaps).asJava
  }
}
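
The doc comment on the multi-model fit above invites subclasses to override the default one-model-per-map loop. A hypothetical sketch, using only the API shown in this diff: cache the input once so every parameter setting reuses the materialized dataset.

import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

// Hypothetical subclass, not part of this PR: share one cached copy of the
// input across all parameter maps instead of recomputing it per model.
abstract class CachingEstimator[M <: Model[M]] extends Estimator[M] {
  override def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
    dataset.cache()  // SchemaRDD is an RDD[Row]; materialize it once
    try {
      paramMaps.map(fit(dataset, _))  // same loop as the default implementation
    } finally {
      dataset.unpersist()
    }
  }
}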
39 changes: 39 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

/**
* :: AlphaComponent ::
* Abstract class for evaluators that compute metrics from predictions.
*/
@AlphaComponent
abstract class Evaluator extends Identifiable {

  /**
   * Evaluates the output.
   *
   * @param dataset a dataset that contains labels/observations and predictions
   * @param paramMap parameter map that specifies the input columns and output metrics
   * @return metric
   */
  def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double
}
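
To make the contract concrete, here is a hedged sketch of a trivial accuracy evaluator, hypothetical and not part of this diff. It hardwires the label and prediction column names for brevity; a real implementation would read them from paramMap-backed params.

import org.apache.spark.ml.Evaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

// Hypothetical sketch: the fraction of rows whose prediction equals the label.
class AccuracyEvaluator extends Evaluator {
  override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
    val names = dataset.schema.fields.map(_.name)
    val labelIdx = names.indexOf("label")
    val predictionIdx = names.indexOf("prediction")
    val correct = dataset.map { row =>
      if (row.getDouble(labelIdx) == row.getDouble(predictionIdx)) 1L else 0L
    }.reduce(_ + _)  // assumes a non-empty dataset
    correct.toDouble / dataset.count()
  }
}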
33 changes: 33 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml

import java.util.UUID

/**
* Object with a unique id.
*/
private[ml] trait Identifiable extends Serializable {

  /**
   * A unique id for the object. The default implementation concatenates the class name, "-", and
   * 8 random hex chars.
   */
  private[ml] val uid: String =
    this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
}
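
Purely illustrative: the uid is what ties a Param to the component that owns it (see commit 62ca2bb, "check param parent in set/get"). Since Identifiable is private[ml], the construction can only be reproduced, not reused, outside the package:

import java.util.UUID
import org.apache.spark.ml.classification.LogisticRegression

// Mirrors the uid scheme above: simple class name, "-", 8 random hex chars.
val uid = classOf[LogisticRegression].getSimpleName + "-" +
  UUID.randomUUID().toString.take(8)
// e.g. "LogisticRegression-3fa85f64"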