Skip to content

Commit 6e257b9

Browse files
committed
Review comments
1 parent 7c539cf commit 6e257b9

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
3232
* with 5 categories, an input value of 2.0 would map to an output vector of
3333
* (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
3434
* output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
35-
* of 0.0 would map to a vector of all zeros. Omitting the first category enables the vector
36-
* columns to be independent.
35+
* of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
36+
* linearly dependent because they sum up to one.
3737
*/
3838
@AlphaComponent
3939
class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
@@ -43,8 +43,8 @@ class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
4343
* Whether to include a component in the encoded vectors for the first category, defaults to true.
4444
* @group param
4545
*/
46-
final val includeFirst: Param[Boolean] =
47-
new Param[Boolean](this, "includeFirst", "include first category")
46+
final val includeFirst: BooleanParam =
47+
new BooleanParam(this, "includeFirst", "include first category")
4848
setDefault(includeFirst -> true)
4949

5050
/**
@@ -59,7 +59,7 @@ class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
5959
def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
6060

6161
/** @group setParam */
62-
def setLabelNames(value: Array[String]): this.type = set(labelNames, value)
62+
def setLabelNames(attr: NominalAttribute): this.type = set(labelNames, attr.values.get)
6363

6464
/** @group setParam */
6565
override def setInputCol(value: String): this.type = set(inputCol, value)

mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}
20+
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
2121
import org.apache.spark.mllib.linalg.Vector
2222
import org.apache.spark.mllib.util.MLlibTestSparkContext
2323

@@ -49,7 +49,7 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
4949
test("OneHotEncoder includeFirst = true") {
5050
val (transformed, attr) = stringIndexed()
5151
val encoder = new OneHotEncoder()
52-
.setLabelNames(attr.values.get)
52+
.setLabelNames(attr)
5353
.setInputCol("labelIndex")
5454
.setOutputCol("labelVec")
5555
val encoded = encoder.transform(transformed)
@@ -68,7 +68,7 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
6868
val (transformed, attr) = stringIndexed()
6969
val encoder = new OneHotEncoder()
7070
.setIncludeFirst(false)
71-
.setLabelNames(attr.values.get)
71+
.setLabelNames(attr)
7272
.setInputCol("labelIndex")
7373
.setOutputCol("labelVec")
7474
val encoded = encoder.transform(transformed)

0 commit comments

Comments
 (0)