Skip to content

Commit 7c539cf

Browse files
committed
Vector transformers
1 parent 1c182dd commit 7c539cf

File tree

2 files changed

+108
-38
lines changed

2 files changed

+108
-38
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala

Lines changed: 74 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,49 +18,92 @@
1818
package org.apache.spark.ml.feature
1919

2020
import org.apache.spark.annotation.AlphaComponent
21-
import org.apache.spark.ml.Transformer
21+
import org.apache.spark.ml.UnaryTransformer
2222
import org.apache.spark.ml.attribute.NominalAttribute
23+
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
2324
import org.apache.spark.ml.param._
24-
import org.apache.spark.sql.DataFrame
25-
import org.apache.spark.sql.functions._
26-
import org.apache.spark.sql.types.{StringType, StructType}
25+
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
26+
import org.apache.spark.ml.util.SchemaUtils
27+
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
2728

29+
/**
30+
* A one-hot encoder that maps a column of label indices to a column of binary vectors, with
31+
* at most a single one-value. By default, the binary vector has an element for each category, so
32+
* with 5 categories, an input value of 2.0 would map to an output vector of
33+
* (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
34+
* output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
35+
* of 0.0 would map to a vector of all zeros. Omitting the first category enables the vector
36+
* columns to be linearly independent (they no longer always sum to one).
37+
*/
2838
@AlphaComponent
29-
class OneHotEncoder(labelNames: Seq[String], includeFirst: Boolean = true) extends Transformer
30-
with HasInputCol {
39+
class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
40+
with HasInputCol with HasOutputCol {
3141

32-
/** @group setParam */
33-
def setInputCol(value: String): this.type = set(inputCol, value)
42+
/**
43+
* Whether to include a component in the encoded vectors for the first category. Defaults to true.
44+
* @group param
45+
*/
46+
final val includeFirst: Param[Boolean] =
47+
new Param[Boolean](this, "includeFirst", "include first category")
48+
setDefault(includeFirst -> true)
3449

35-
private def outputColName(index: Int): String = {
36-
s"${get(inputCol)}_${labelNames(index)}"
37-
}
50+
/**
51+
* The names of the categories. Used to identify them in the attributes of the output column.
52+
* This is a required parameter.
53+
* @group param
54+
*/
55+
final val labelNames: Param[Array[String]] =
56+
new Param[Array[String]](this, "labelNames", "categorical label names")
3857

39-
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
40-
val map = this.paramMap ++ paramMap
58+
/** @group setParam */
59+
def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
4160

42-
val startIndex = if (includeFirst) 0 else 1
43-
val cols = (startIndex until labelNames.length).map { index =>
44-
val colEncoder = udf { label: Double => if (index == label) 1.0 else 0.0 }
45-
colEncoder(dataset(map(inputCol))).as(outputColName(index))
46-
}
61+
/** @group setParam */
62+
def setLabelNames(value: Array[String]): this.type = set(labelNames, value)
4763

48-
dataset.select(Array(col("*")) ++ cols: _*)
49-
}
64+
/** @group setParam */
65+
override def setInputCol(value: String): this.type = set(inputCol, value)
66+
67+
/** @group setParam */
68+
override def setOutputCol(value: String): this.type = set(outputCol, value)
5069

5170
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
52-
val map = this.paramMap ++ paramMap
53-
checkInputColumn(schema, map(inputCol), StringType)
71+
val map = extractParamMap(paramMap)
72+
SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType)
5473
val inputFields = schema.fields
55-
val startIndex = if (includeFirst) 0 else 1
56-
val fields = (startIndex until labelNames.length).map { index =>
57-
val colName = outputColName(index)
58-
require(inputFields.forall(_.name != colName),
59-
s"Output column $colName already exists.")
60-
NominalAttribute.defaultAttr.withName(colName).toStructField()
61-
}
62-
63-
val outputFields = inputFields ++ fields
74+
val outputColName = map(outputCol)
75+
require(inputFields.forall(_.name != outputColName),
76+
s"Output column $outputColName already exists.")
77+
require(map.contains(labelNames), "OneHotEncoder missing category names")
78+
val categories = map(labelNames)
79+
val attrValues = (if (map(includeFirst)) categories else categories.drop(1)).toArray
80+
val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
81+
val outputFields = inputFields :+ attr.toStructField()
6482
StructType(outputFields)
6583
}
84+
85+
protected def createTransformFunc(paramMap: ParamMap): (Double) => Vector = {
86+
val map = extractParamMap(paramMap)
87+
val first = map(includeFirst)
88+
val vecLen = if (first) map(labelNames).length else map(labelNames).length - 1
89+
val oneValue = Array(1.0)
90+
val emptyValues = Array[Double]()
91+
val emptyIndices = Array[Int]()
92+
label: Double => {
93+
val values = if (first || label != 0.0) oneValue else emptyValues
94+
val indices = if (first) {
95+
Array(label.toInt)
96+
} else if (label != 0.0) {
97+
Array(label.toInt - 1)
98+
} else {
99+
emptyIndices
100+
}
101+
Vectors.sparse(vecLen, indices, values)
102+
}
103+
}
104+
105+
/**
106+
* Returns the data type of the output column.
107+
*/
108+
protected def outputDataType: DataType = new VectorUDT
66109
}

mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
package org.apache.spark.ml.feature
1919

20+
import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}
21+
import org.apache.spark.mllib.linalg.Vector
2022
import org.apache.spark.mllib.util.MLlibTestSparkContext
2123

24+
import org.apache.spark.sql.{DataFrame, SQLContext}
25+
2226
import org.scalatest.FunSuite
23-
import org.apache.spark.sql.SQLContext
24-
import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}
2527

2628
class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
2729
private var sqlContext: SQLContext = _
@@ -31,7 +33,7 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
3133
sqlContext = new SQLContext(sc)
3234
}
3335

34-
test("OneHotEncoder") {
36+
def stringIndexed(): (DataFrame, NominalAttribute) = {
3537
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
3638
val df = sqlContext.createDataFrame(data).toDF("id", "label")
3739
val indexer = new StringIndexer()
@@ -41,19 +43,44 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
4143
val transformed = indexer.transform(df)
4244
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
4345
.asInstanceOf[NominalAttribute]
44-
assert(attr.values.get === Array("a", "c", "b"))
46+
(transformed, attr)
47+
}
4548

46-
val encoder = new OneHotEncoder(attr.values.get)
49+
test("OneHotEncoder includeFirst = true") {
50+
val (transformed, attr) = stringIndexed()
51+
val encoder = new OneHotEncoder()
52+
.setLabelNames(attr.values.get)
4753
.setInputCol("labelIndex")
54+
.setOutputCol("labelVec")
4855
val encoded = encoder.transform(transformed)
4956

50-
val output = encoded.select("id", "labelIndex_a", "labelIndex_c", "labelIndex_b").map { r =>
51-
(r.getInt(0), r.getDouble(1), r.getDouble(2), r.getDouble(3))
57+
val output = encoded.select("id", "labelVec").map { r =>
58+
val vec = r.get(1).asInstanceOf[Vector]
59+
(r.getInt(0), vec(0), vec(1), vec(2))
5260
}.collect().toSet
5361
// a -> 0, b -> 2, c -> 1
5462
val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
5563
(3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
5664
assert(output === expected)
5765
}
5866

67+
test("OneHotEncoder includeFirst = false") {
68+
val (transformed, attr) = stringIndexed()
69+
val encoder = new OneHotEncoder()
70+
.setIncludeFirst(false)
71+
.setLabelNames(attr.values.get)
72+
.setInputCol("labelIndex")
73+
.setOutputCol("labelVec")
74+
val encoded = encoder.transform(transformed)
75+
76+
val output = encoded.select("id", "labelVec").map { r =>
77+
val vec = r.get(1).asInstanceOf[Vector]
78+
(r.getInt(0), vec(0), vec(1))
79+
}.collect().toSet
80+
// a -> 0, b -> 2, c -> 1
81+
val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
82+
(3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
83+
assert(output === expected)
84+
}
85+
5986
}

0 commit comments

Comments
 (0)