|
18 | 18 | package org.apache.spark.ml.feature
|
19 | 19 |
|
20 | 20 | import org.apache.spark.annotation.AlphaComponent
|
21 |
| -import org.apache.spark.ml.Transformer |
| 21 | +import org.apache.spark.ml.UnaryTransformer |
22 | 22 | import org.apache.spark.ml.attribute.NominalAttribute
|
| 23 | +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} |
23 | 24 | import org.apache.spark.ml.param._
|
24 |
| -import org.apache.spark.sql.DataFrame |
25 |
| -import org.apache.spark.sql.functions._ |
26 |
| -import org.apache.spark.sql.types.{StringType, StructType} |
| 25 | +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} |
| 26 | +import org.apache.spark.ml.util.SchemaUtils |
| 27 | +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} |
27 | 28 |
|
/**
 * A one-hot encoder that maps a column of label indices to a column of binary vectors, with
 * at most a single one-value. By default, the binary vector has an element for each category, so
 * with 5 categories, an input value of 2.0 would map to an output vector of
 * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
 * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
 * of 0.0 would map to a vector of all zeros. Omitting the first category enables the vector
 * columns to be independent.
 *
 * Input values must be integer-valued category indices in the range
 * [0, number of label names); any other value is rejected when a row is encoded.
 */
@AlphaComponent
class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
  with HasInputCol with HasOutputCol {

  /**
   * Whether to include a component in the encoded vectors for the first category, defaults to true.
   * @group param
   */
  final val includeFirst: Param[Boolean] =
    new Param[Boolean](this, "includeFirst", "include first category")
  setDefault(includeFirst -> true)

  /**
   * The names of the categories. Used to identify them in the attributes of the output column.
   * This is a required parameter.
   * @group param
   */
  final val labelNames: Param[Array[String]] =
    new Param[Array[String]](this, "labelNames", "categorical label names")

  /** @group setParam */
  def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)

  /** @group setParam */
  def setLabelNames(value: Array[String]): this.type = set(labelNames, value)

  /** @group setParam */
  override def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  override def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Validates the input schema and appends the output column to it.
   *
   * The input column must be of DoubleType, the output column must not already exist, and
   * labelNames must be set. The appended field is a nominal attribute whose values are the
   * label names, without the first one when includeFirst is false.
   *
   * @param schema schema of the dataset to be transformed
   * @param paramMap extra parameters, combined with this instance's own via extractParamMap
   * @return the input fields followed by the field describing the output column
   */
  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val map = extractParamMap(paramMap)
    SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType)
    val inputFields = schema.fields
    val outputColName = map(outputCol)
    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")
    require(map.contains(labelNames), "OneHotEncoder missing category names")
    val categories = map(labelNames)
    // The attribute values line up with the vector components, so the first name is dropped
    // exactly when the first category has no component of its own.
    val attrValues = if (map(includeFirst)) categories else categories.drop(1)
    val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
    val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
  }

  /**
   * Builds the function that encodes a single label index as a sparse vector.
   *
   * The vector has one component per label name, or one fewer when includeFirst is false. In
   * the latter case label 0.0 maps to the all-zeros vector and every other label is shifted
   * down by one position.
   *
   * The returned function throws an IllegalArgumentException for a label that is not an
   * integer-valued index in [0, number of label names). Without that check a fractional,
   * negative, NaN or too-large label would be truncated into a neighbouring category or
   * produce an index outside the vector.
   *
   * @param paramMap extra parameters, combined with this instance's own via extractParamMap
   * @return function from a label index to its one-hot encoded vector
   */
  protected def createTransformFunc(paramMap: ParamMap): (Double) => Vector = {
    val map = extractParamMap(paramMap)
    val first = map(includeFirst)
    val numCategories = map(labelNames).length
    val vecLen = if (first) numCategories else numCategories - 1
    val oneValue = Array(1.0)
    val emptyValues = Array[Double]()
    val emptyIndices = Array[Int]()
    (label: Double) => {
      // NaN fails the first comparison, so it is rejected here as well.
      require(label >= 0.0 && label < numCategories && label == label.toInt,
        s"OneHotEncoder expects an integer category index in [0, $numCategories), " +
          s"but got $label.")
      val index = label.toInt
      if (first) {
        Vectors.sparse(vecLen, Array(index), oneValue)
      } else if (index != 0) {
        Vectors.sparse(vecLen, Array(index - 1), oneValue)
      } else {
        // The omitted first category is encoded as the all-zeros vector.
        Vectors.sparse(vecLen, emptyIndices, emptyValues)
      }
    }
  }

  /**
   * Returns the data type of the output column.
   */
  protected def outputDataType: DataType = new VectorUDT
}
|
0 commit comments