@@ -32,8 +32,8 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
32
32
* with 5 categories, an input value of 2.0 would map to an output vector of
33
33
* (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
34
34
* output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
35
- * of 0.0 would map to a vector of all zeros. Omitting the first category enables the vector
36
- * columns to be independent .
35
+ * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
36
+ * linearly dependent because they sum up to one.
37
37
*/
38
38
@ AlphaComponent
39
39
class OneHotEncoder extends UnaryTransformer [Double , Vector , OneHotEncoder ]
@@ -43,8 +43,8 @@ class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
43
43
* Whether to include a component in the encoded vectors for the first category (default: true).
44
44
* @group param
45
45
*/
46
- final val includeFirst : Param [ Boolean ] =
47
- new Param [ Boolean ] (this , " includeFirst" , " include first category" )
46
+ final val includeFirst : BooleanParam =
47
+ new BooleanParam (this , " includeFirst" , " include first category" )
48
48
setDefault(includeFirst -> true )
49
49
50
50
/**
@@ -59,7 +59,7 @@ class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
59
59
def setIncludeFirst(value: Boolean): this.type = {
  // Stores `value` under the `includeFirst` param and returns whatever `set` returns.
  set(includeFirst, value)
}
60
60
61
61
/** @group setParam */
62
- def setLabelNames (value : Array [ String ] ): this .type = set(labelNames, value )
62
+ def setLabelNames(attr: NominalAttribute): this.type = {
+   // `attr.values` is optional (the original unwrapped it with a bare `.get`). Keep the same
+   // exception type that `Option.get` raises, but replace the opaque "None.get" message with
+   // one that tells the caller what is actually missing.
+   val names = attr.values.getOrElse(
+     throw new NoSuchElementException(
+       "Cannot set label names: the given NominalAttribute does not define any values."))
+   set(labelNames, names)
+ }
63
63
64
64
/** @group setParam */
65
65
override def setInputCol(value: String): this.type = {
  // Stores `value` under the `inputCol` param and returns whatever `set` returns.
  set(inputCol, value)
}
0 commit comments