
Commit e246f29

re-org:
1. move SimpleTokenizer to examples
2. move LR to classification, HashingTF and StandardScaler to feature, CV and ParamGridBuilder to tuning
3. define SimpleTransformer
1 parent 7772430 commit e246f29

File tree

16 files changed (+362, -222 lines)

New file (Java example): org.apache.spark.examples.ml.JavaSimpleTextClassificationPipeline

Lines changed: 92 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

import java.util.List;

import com.google.common.collect.Lists;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.Row;
import org.apache.spark.SparkConf;

/**
 * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java
 * bean classes [[LabeledDocument]] and [[Document]], and the tokenizer [[SimpleTokenizer]] defined
 * in the Scala counterpart of this example [[SimpleTextClassificationPipeline]]. Run with
 * <pre>
 *   bin/run-example ml.JavaSimpleTextClassificationPipeline
 * </pre>
 */
public class JavaSimpleTextClassificationPipeline {

  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    JavaSQLContext jsql = new JavaSQLContext(jsc);

    // Prepare training documents, which are labeled.
    List<LabeledDocument> localTraining = Lists.newArrayList(
      new LabeledDocument(0L, "a b c d e spark", 1.0),
      new LabeledDocument(1L, "b d", 0.0),
      new LabeledDocument(2L, "spark f g h", 1.0),
      new LabeledDocument(3L, "hadoop mapreduce", 0.0));
    JavaSchemaRDD training =
      jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    SimpleTokenizer tokenizer = new SimpleTokenizer()
      .setInputCol("text")
      .setOutputCol("words");
    HashingTF hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol())
      .setOutputCol("features");
    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01);
    Pipeline pipeline = new Pipeline()
      .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

    // Fit the pipeline to training documents.
    PipelineModel model = pipeline.fit(training);

    // Prepare test documents, which are unlabeled.
    List<Document> localTest = Lists.newArrayList(
      new Document(4L, "spark i j k"),
      new Document(5L, "l m n"),
      new Document(6L, "mapreduce spark"),
      new Document(7L, "apache hadoop"));
    JavaSchemaRDD test =
      jsql.applySchema(jsc.parallelize(localTest), Document.class);

    // Make predictions on test documents.
    model.transform(test).registerAsTable("prediction");
    JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
    for (Row r: predictions.collect()) {
      System.out.println(r);
    }
  }
}

New file (Scala example): org.apache.spark.examples.ml.SimpleTextClassificationPipeline

Lines changed: 94 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml

import scala.beans.BeanInfo

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.{Pipeline, SimpleTransformer}
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

@BeanInfo
case class LabeledDocument(id: Long, text: String, label: Double)

@BeanInfo
case class Document(id: Long, text: String)

/**
 * A tokenizer that converts the input string to lowercase and then splits it by whitespace.
 */
class SimpleTokenizer extends SimpleTransformer[String, Seq[String], SimpleTokenizer]
  with Serializable {
  override def createTransformFunc: String => Seq[String] = _.toLowerCase.split("\\s")
}

/**
 * A simple text classification pipeline that recognizes "spark" from input text. This is to show
 * how to define a simple tokenizer and then use it as part of an ML pipeline. Run with
 * {{{
 *   bin/run-example ml.SimpleTextClassificationPipeline
 * }}}
 */
object SimpleTextClassificationPipeline {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext._

    // Prepare training documents, which are labeled.
    val training = sparkContext.parallelize(Seq(
      LabeledDocument(0L, "a b c d e spark", 1.0),
      LabeledDocument(1L, "b d", 0.0),
      LabeledDocument(2L, "spark f g h", 1.0),
      LabeledDocument(3L, "hadoop mapreduce", 0.0)))

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    val tokenizer = new SimpleTokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // Fit the pipeline to training documents.
    val model = pipeline.fit(training)

    // Prepare test documents, which are unlabeled.
    val test = sparkContext.parallelize(Seq(
      Document(4L, "spark i j k"),
      Document(5L, "l m n"),
      Document(6L, "mapreduce spark"),
      Document(7L, "apache hadoop")))

    // Make predictions on test documents.
    model.transform(test)
      .select('id, 'text, 'score, 'prediction)
      .collect()
      .foreach(println)
  }
}

mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala

Lines changed: 2 additions & 10 deletions
@@ -20,20 +20,12 @@ package org.apache.spark.ml
 import java.util.UUID
 
 /**
- * Something with a unique id.
+ * Object with a unique id.
  */
 trait Identifiable extends Serializable {
 
   /**
    * A unique id for the object.
    */
-  val uid: String = this.getClass.getSimpleName + "-" + Identifiable.randomUid
-}
-
-object Identifiable {
-
-  /**
-   * Returns a random uid, drawn uniformly from 4+ billion candidates.
-   */
-  private def randomUid: String = UUID.randomUUID().toString.take(8)
+  val uid: String = this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
 }
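
For reference, the inlined initializer above produces ids of the form "<simple class name>-<first 8 UUID characters>". A standalone sketch of the same construction, with an illustrative class name and value:

import java.util.UUID

// Same shape as the new uid initializer: simple class name, a dash, and the first
// 8 characters of a random UUID. The name and resulting value here are illustrative.
val uid = "LogisticRegression" + "-" + UUID.randomUUID().toString.take(8)
// e.g. "LogisticRegression-3f2a9b1c"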

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ import org.apache.spark.sql.SchemaRDD
 /**
  * A stage in a pipeline, either an Estimator or a Transformer.
  */
-trait PipelineStage extends Identifiable
+abstract class PipelineStage
 
 /**
  * A simple pipeline, which acts as an estimator.

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 24 additions & 1 deletion
@@ -18,10 +18,13 @@
 package org.apache.spark.ml
 
 import scala.annotation.varargs
+import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
+import org.apache.spark.ml.param._
 import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.sql.api.java.JavaSchemaRDD
+import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.dsl._
 
 /**
  * Abstract class for transformers that transform one dataset into another.
@@ -60,3 +63,23 @@ abstract class Transformer extends PipelineStage with Params {
     transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD
   }
 }
+
+/**
+ * Abstract class for transformers that take one input column, apply a transformation, and output
+ * the result as a new column.
+ */
+abstract class SimpleTransformer[IN, OUT: TypeTag, SELF <: SimpleTransformer[IN, OUT, SELF]]
+  extends Transformer with HasInputCol with HasOutputCol {
+
+  def setInputCol(value: String): SELF = { set(inputCol, value); this.asInstanceOf[SELF] }
+  def setOutputCol(value: String): SELF = { set(outputCol, value); this.asInstanceOf[SELF] }
+
+  def createTransformFunc: IN => OUT
+
+  override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+    import dataset.sqlContext._
+    val map = this.paramMap ++ paramMap
+    val udf: IN => OUT = this.createTransformFunc
+    dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
+  }
+}
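
The SimpleTransformer base class added above lets a single-column transformer be written by supplying only createTransformFunc; the SimpleTokenizer in the new Scala example is one instance. A minimal sketch of another such transformer, with an illustrative class name and column names (not part of this commit), might look like this:

import org.apache.spark.ml.SimpleTransformer

// Sketch only: a one-column transformer in the style of the SimpleTokenizer example.
// "LowerCaser" and the column names below are illustrative assumptions.
class LowerCaser extends SimpleTransformer[String, String, LowerCaser] with Serializable {
  // The only required member: the per-value transform function.
  override def createTransformFunc: String => String = _.toLowerCase
}

// Configure the input/output columns and apply it to a SchemaRDD:
// val lowered = new LowerCaser().setInputCol("text").setOutputCol("textLower").transform(dataset)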

mllib/src/main/scala/org/apache/spark/ml/example/LogisticRegression.scala renamed to mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.ml.example
+package org.apache.spark.ml.classification
 
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._

mllib/src/main/scala/org/apache/spark/ml/example/BinaryClassificationEvaluator.scala renamed to mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.ml.example
+package org.apache.spark.ml.evaluation
 
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._

mllib/src/main/scala/org/apache/spark/ml/example/Tokenizer.scala

Lines changed: 0 additions & 42 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/ml/example/HashingTF.scala renamed to mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-package org.apache.spark.ml.example
+package org.apache.spark.ml.feature
 
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, IntParam, ParamMap}

mllib/src/main/scala/org/apache/spark/ml/example/StandardScaler.scala renamed to mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 8 additions & 1 deletion
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.ml.example
+package org.apache.spark.ml.feature
 
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
@@ -31,6 +31,10 @@ import org.apache.spark.sql.catalyst.expressions.Row
  */
 trait StandardScalerParams extends Params with HasInputCol with HasOutputCol
 
+/**
+ * Standardizes features by removing the mean and scaling to unit variance using column summary
+ * statistics on the samples in the training set.
+ */
 class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
 
   def setInputCol(value: String): this.type = { set(inputCol, value); this }
@@ -50,6 +54,9 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
   }
 }
 
+/**
+ * Model fitted by [[StandardScaler]].
+ */
 class StandardScalerModel private[ml] (
     override val parent: StandardScaler,
     override val fittingParamMap: ParamMap,
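
The new class comment above describes what StandardScaler computes. As a small, self-contained illustration of that arithmetic in plain Scala on a toy column (not the StandardScaler API itself; population variance is used here for simplicity, whereas the library may divide by n - 1):

// Standardization as described in the StandardScaler doc: subtract the column mean,
// then divide by the column's standard deviation, giving mean ~0 and unit variance.
val column = Seq(1.0, 2.0, 3.0, 4.0)
val mean = column.sum / column.size
val stddev = math.sqrt(column.map(x => math.pow(x - mean, 2)).sum / column.size)
val standardized = column.map(x => (x - mean) / stddev)
// standardized is approximately Seq(-1.342, -0.447, 0.447, 1.342)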
