
Commit 27582a4

doc changes
1 parent 73a000b commit 27582a4

13 files changed: +98 -77 lines


examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala

Lines changed: 10 additions & 4 deletions
@@ -17,14 +17,15 @@
 
 package org.apache.spark.examples.ml
 
+import org.apache.spark.sql.{StringType, DataType, SQLContext}
+
 import scala.beans.BeanInfo
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.ml.{Pipeline, UnaryTransformer}
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.feature.HashingTF
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.SQLContext
 
 @BeanInfo
 case class LabeledDocument(id: Long, text: String, label: Double)
@@ -35,10 +36,15 @@ case class Document(id: Long, text: String)
 /**
  * A tokenizer that converts the input string to lowercase and then splits it by white spaces.
  */
-class MyTokenizer extends UnaryTransformer[String, Seq[String], MyTokenizer]
-    with Serializable {
-  override def createTransformFunc(paramMap: ParamMap): String => Seq[String] =
+class MyTokenizer extends UnaryTransformer[String, Seq[String], MyTokenizer] {
+
+  override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
     _.toLowerCase.split("\\s")
+  }
+
+  override protected def validateInputType(inputType: DataType): Unit = {
+    require(inputType == StringType, s"Input type must be string type but got $inputType.")
+  }
 }
 
 /**
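For context, the tokenizer above is meant to be chained with the built-in stages. A condensed sketch of how it can be wired into a pipeline (assuming the spark.ml API at this commit; column names and parameter values are illustrative, not taken verbatim from the example):

    // Custom tokenizer feeding HashingTF and LogisticRegression. Because Pipeline.fit
    // now checks stage schemas up front, a non-string "text" column fails fast in
    // MyTokenizer.validateInputType instead of deep inside a job.
    val tokenizer = new MyTokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))
    val model = pipeline.fit(training)  // training: SchemaRDD with "id", "text", "label" columns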

mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@ import org.apache.spark.sql.SchemaRDD
 abstract class Evaluator extends Identifiable {
 
   /**
-   * Evaluate the output
+   * Evaluates the output.
+   *
    * @param dataset a dataset that contains labels/observations and predictions.
    * @param paramMap parameter map that specifies the input columns and output metrics
    * @return metric

mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,8 @@ import java.util.UUID
 trait Identifiable extends Serializable {
 
   /**
-   * A unique id for the object.
+   * A unique id for the object. The default implementation concatenates the class name, "-", and 8
+   * random hex chars.
    */
   val uid: String = this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
 }
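As a concrete illustration of the documented format (the suffix is random per instance; the value shown here is made up):

    // Every Identifiable stage gets an id like "<SimpleClassName>-<8 random hex chars>".
    val scaler = new StandardScaler()
    scaler.uid  // e.g. "StandardScaler-1f6c9a0b"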

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 20 additions & 3 deletions
@@ -19,18 +19,33 @@ package org.apache.spark.ml
 
 import scala.collection.mutable.ListBuffer
 
+import org.apache.spark.Logging
 import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.sql.{StructType, SchemaRDD}
+import org.apache.spark.sql.{SchemaRDD, StructType}
 
 /**
  * A stage in a pipeline, either an Estimator or an Transformer.
  */
-abstract class PipelineStage {
+abstract class PipelineStage extends Serializable with Logging {
 
   /**
    * Derives the output schema from the input schema and parameters.
    */
   def transform(schema: StructType, paramMap: ParamMap): StructType
+
+  /**
+   * Derives the output schema from the input schema and parameters, optionally with logging.
+   */
+  protected def transform(schema: StructType, paramMap: ParamMap, logging: Boolean): StructType = {
+    if (logging) {
+      logDebug(s"Input schema: ${schema.json}")
+    }
+    val outputSchema = transform(schema, paramMap)
+    if (logging) {
+      logDebug(s"Expected output schema: ${outputSchema.json}")
+    }
+    outputSchema
+  }
 }
 
 /**
@@ -43,6 +58,7 @@ class Pipeline extends Estimator[PipelineModel] {
   def getStages: Array[PipelineStage] = get(stages)
 
   override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
+    transform(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
     val theStages = map(stages)
     // Search for the last estimator.
@@ -89,7 +105,7 @@ class Pipeline extends Estimator[PipelineModel] {
 class PipelineModel(
     override val parent: Pipeline,
     override val fittingParamMap: ParamMap,
-    val transformers: Array[Transformer]) extends Model {
+    val transformers: Array[Transformer]) extends Model with Logging {
 
   /**
    * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input
@@ -110,6 +126,7 @@ class PipelineModel(
   }
 
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+    transform(dataset.schema, paramMap, logging = true)
     transformers.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
   }
 
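The protected transform(schema, paramMap, logging) overload gives every stage one place to validate its input schema and log the derived output schema before any data is touched. A minimal sketch of the intended call pattern in a custom stage (hypothetical class, not compiled against this revision; it assumes the API added in this commit):

    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.sql.{SchemaRDD, StructType}

    // Hypothetical transformer: check and log the schema first, then do the real work.
    class NoOpTransformer extends Transformer {

      override def transform(schema: StructType, paramMap: ParamMap): StructType = {
        schema  // derive and validate the output schema from the input schema here
      }

      override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
        transform(dataset.schema, paramMap, logging = true)  // fail fast + debug-log both schemas
        dataset  // a real transformer would append or modify columns here
      }
    }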

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 10 additions & 5 deletions
@@ -17,17 +17,17 @@
 
 package org.apache.spark.ml
 
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.types.{StructField, StructType}
-
 import scala.annotation.varargs
 import scala.reflect.runtime.universe.TypeTag
 
+import org.apache.spark.Logging
 import org.apache.spark.ml.param._
-import org.apache.spark.sql.{DataType, SchemaRDD}
+import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.sql.api.java.JavaSchemaRDD
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.types._
 
 /**
  * Abstract class for transformers that transform one dataset into another.
@@ -72,7 +72,7 @@ abstract class Transformer extends PipelineStage with Params {
  * result as a new column.
  */
 abstract class UnaryTransformer[IN, OUT: TypeTag, SELF <: UnaryTransformer[IN, OUT, SELF]]
-  extends Transformer with HasInputCol with HasOutputCol {
+  extends Transformer with HasInputCol with HasOutputCol with Logging {
 
   def setInputCol(value: String): SELF = { set(inputCol, value); this.asInstanceOf[SELF] }
   def setOutputCol(value: String): SELF = { set(outputCol, value); this.asInstanceOf[SELF] }
@@ -103,6 +103,11 @@ abstract class UnaryTransformer[IN, OUT: TypeTag, SELF <: UnaryTransformer[IN, O
   }
 
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+    transform(dataset.schema, paramMap, logging = true)
+    val inputSchema = dataset.schema
+    logDebug(s"Input schema: ${inputSchema.json}")
+    val outputSchema = transform(dataset.schema, paramMap)
+    logDebug(s"Expected output schema: ${outputSchema.json}")
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
     val udf = this.createTransformFunc(map)

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 20 additions & 32 deletions
@@ -20,20 +20,19 @@ package org.apache.spark.ml.classification
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector}
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.dsl._
-import org.apache.spark.sql.catalyst.expressions.{Cast, Row}
 import org.apache.spark.storage.StorageLevel
 
 /**
  * Params for logistic regression.
  */
-trait LogisticRegressionParams extends Params with HasRegParam with HasMaxIter with HasLabelCol
-  with HasThreshold with HasFeaturesCol with HasScoreCol with HasPredictionCol
+private[classification] trait LogisticRegressionParams extends Params
+  with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol
+  with HasScoreCol with HasPredictionCol
 
 /**
  * Logistic regression.
@@ -53,9 +52,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
   def setPredictionCol(value: String): this.type = { set(predictionCol, value); this }
 
   override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
+    transform(dataset.schema, paramMap, logging = true)
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
-    val instances = dataset.select(Cast(map(labelCol).attr, DoubleType), map(featuresCol).attr)
+    val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
       .map { case Row(label: Double, features: Vector) =>
         LabeledPoint(label, features)
       }.persist(StorageLevel.MEMORY_AND_DISK)
@@ -74,23 +74,15 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
     val map = this.paramMap ++ paramMap
     val featuresType = schema(map(featuresCol)).dataType
     // TODO: Support casting Array[Double] and Array[Float] to Vector.
-    if (!featuresType.isInstanceOf[VectorUDT]) {
-      throw new IllegalArgumentException(
-        s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
-    }
-    val validLabelTypes = Set[DataType](FloatType, DoubleType, IntegerType, BooleanType, LongType)
+    require(featuresType.isInstanceOf[VectorUDT],
+      s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
     val labelType = schema(map(labelCol)).dataType
-    if (!validLabelTypes.contains(labelType)) {
-      throw new IllegalArgumentException(
-        s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
-    }
+    require(labelType == DoubleType,
+      s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
     val fieldNames = schema.fieldNames
-    if (fieldNames.contains(map(scoreCol))) {
-      throw new IllegalArgumentException(s"Score column ${map(scoreCol)} already exists.")
-    }
-    if (fieldNames.contains(map(predictionCol))) {
-      throw new IllegalArgumentException(s"Prediction column ${map(predictionCol)} already exists.")
-    }
+    require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
+    require(!fieldNames.contains(map(predictionCol)),
+      s"Prediction column ${map(predictionCol)} already exists.")
     val outputFields = schema.fields ++ Seq(
       StructField(map(scoreCol), DoubleType, false),
       StructField(map(predictionCol), DoubleType, false))
@@ -115,24 +107,20 @@ class LogisticRegressionModel private[ml] (
     val map = this.paramMap ++ paramMap
     val featuresType = schema(map(featuresCol)).dataType
     // TODO: Support casting Array[Double] and Array[Float] to Vector.
-    if (!featuresType.isInstanceOf[VectorUDT]) {
-      throw new IllegalArgumentException(
-        s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
-    }
+    require(featuresType.isInstanceOf[VectorUDT],
+      s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
     val fieldNames = schema.fieldNames
-    if (fieldNames.contains(map(scoreCol))) {
-      throw new IllegalArgumentException(s"Score column ${map(scoreCol)} already exists.")
-    }
-    if (fieldNames.contains(map(predictionCol))) {
-      throw new IllegalArgumentException(s"Prediction column ${map(predictionCol)} already exists.")
-    }
+    require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
+    require(!fieldNames.contains(map(predictionCol)),
+      s"Prediction column ${map(predictionCol)} already exists.")
     val outputFields = schema.fields ++ Seq(
       StructField(map(scoreCol), DoubleType, false),
       StructField(map(predictionCol), DoubleType, false))
     StructType(outputFields)
   }
 
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+    transform(dataset.schema, paramMap, logging = true)
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
     val score: Vector => Double = (v) => {
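Two things change in the validation above: the if/throw blocks collapse into require calls (Scala's require(cond, msg) throws an IllegalArgumentException carrying the message when the condition is false, so the exception type is unchanged), and the label check is tightened from a set of convertible types to DoubleType only. A small plain-Scala illustration of the require form, independent of Spark:

    // require(cond, msg) is shorthand for the removed if (!cond) throw new IllegalArgumentException(...) pattern.
    val fieldNames = Seq("label", "features", "score")
    val scoreCol = "score"
    require(!fieldNames.contains(scoreCol), s"Score column $scoreCol already exists.")
    // throws java.lang.IllegalArgumentException: requirement failed: Score column score already exists.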

mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala

Lines changed: 11 additions & 3 deletions
@@ -20,8 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.{DoubleType, Row, SchemaRDD}
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -40,8 +39,17 @@ class BinaryClassificationEvaluator extends Evaluator with Params
   def setLabelCol(value: String): this.type = { set(labelCol, value); this }
 
   override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
-    import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
+
+    val schema = dataset.schema
+    val scoreType = schema(map(scoreCol)).dataType
+    require(scoreType == DoubleType,
+      s"Score column ${map(scoreCol)} must be double type but found $scoreType")
+    val labelType = schema(map(labelCol)).dataType
+    require(labelType == DoubleType,
+      s"Label column ${map(labelCol)} must be double type but found $labelType")
+
+    import dataset.sqlContext._
     val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr)
       .map { case Row(score: Double, label: Double) =>
         (score, label)
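The up-front type checks matter because the pattern match in the select above only handles Double: without them, a score or label column of, say, IntegerType would only fail once the RDD is actually computed, surfacing as a scala.MatchError deep inside the job, whereas the new require calls reject the dataset immediately with a message naming the offending column.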

mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 5 additions & 5 deletions
@@ -20,17 +20,15 @@ package org.apache.spark.ml.feature
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
-import org.apache.spark.sql.catalyst.types.StructField
-import org.apache.spark.sql.{StructType, SchemaRDD}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.dsl._
-import org.apache.spark.sql.catalyst.expressions.Row
 
 /**
  * Params for [[StandardScaler]] and [[StandardScalerModel]].
  */
-trait StandardScalerParams extends Params with HasInputCol with HasOutputCol
+private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol
 
 /**
  * Standardizes features by removing the mean and scaling to unit variance using column summary
@@ -42,6 +40,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
   def setOutputCol(value: String): this.type = { set(outputCol, value); this }
 
   override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = {
+    transform(dataset.schema, paramMap, logging = true)
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
     val input = dataset.select(map(inputCol).attr)
@@ -78,6 +77,7 @@ class StandardScalerModel private[ml] (
   def setOutputCol(value: String): this.type = { set(outputCol, value); this }
 
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+    transform(dataset.schema, paramMap, logging = true)
    import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
     val scale: (Vector) => Vector = (v) => {

mllib/src/main/scala/org/apache/spark/ml/package.scala

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,8 @@
 package org.apache.spark
 
 /**
- * "org.apache.spark.ml" is an ALPHA component that adapts the new set of machine learning APIs.
+ * :: AlphaComponent ::
+ * This is an ALPHA component that adapts the new set of machine learning APIs.
  */
 package object ml {
 }

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 1 addition & 3 deletions
@@ -42,7 +42,7 @@ class Param[T] (
   /**
    * Creates a param pair with the given value (for Java).
    */
-  def w(value: T): ParamPair[T] = ParamPair(this, value)
+  def w(value: T): ParamPair[T] = this -> value
 
   /**
    * Creates a param pair with the given value (for Scala).
@@ -303,5 +303,3 @@ object ParamMap {
     new ParamMap().put(paramPairs: _*)
   }
 }
-
-
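With this change both spellings build the same ParamPair: w(value) stays as the Java-friendly entry point and now simply delegates to the -> method documented just below it for Scala callers. A small sketch of the equivalence (assuming an estimator lr exposing maxIter and regParam params, as elsewhere in this API):

    // Both forms produce a ParamPair; the arrow form reads naturally inside ParamMap(...).
    val fromJavaStyle = lr.maxIter.w(20)
    val fromScalaStyle = lr.maxIter -> 20
    val extraParams = ParamMap(lr.maxIter -> 20, lr.regParam -> 0.1)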
