Skip to content

Commit d36e673

Browse files
mengxr authored and jkbradley committed
[SPARK-6965] [MLLIB] StringIndexer handles numeric input.
Cast numeric types to String for indexing. Boolean type is not handled in this PR. jkbradley

Author: Xiangrui Meng <[email protected]>

Closes apache#5753 from mengxr/SPARK-6965 and squashes the following commits:

2e34f3c [Xiangrui Meng] add actual type in the error message
ad938bf [Xiangrui Meng] StringIndexer handles numeric input.
1 parent 555213e commit d36e673

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model}
2323
import org.apache.spark.ml.attribute.NominalAttribute
2424
import org.apache.spark.ml.param._
2525
import org.apache.spark.ml.param.shared._
26-
import org.apache.spark.ml.util.SchemaUtils
2726
import org.apache.spark.sql.DataFrame
2827
import org.apache.spark.sql.functions._
29-
import org.apache.spark.sql.types.{StringType, StructType}
28+
import org.apache.spark.sql.types.{NumericType, StringType, StructType}
3029
import org.apache.spark.util.collection.OpenHashMap
3130

3231
/**
@@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
3736
/** Validates and transforms the input schema. */
3837
protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
3938
val map = extractParamMap(paramMap)
40-
SchemaUtils.checkColumnType(schema, map(inputCol), StringType)
39+
val inputColName = map(inputCol)
40+
val inputDataType = schema(inputColName).dataType
41+
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
42+
s"The input column $inputColName must be either string type or numeric type, " +
43+
s"but got $inputDataType.")
4144
val inputFields = schema.fields
4245
val outputColName = map(outputCol)
4346
require(inputFields.forall(_.name != outputColName),
@@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
5154
/**
5255
* :: AlphaComponent ::
5356
* A label indexer that maps a string column of labels to an ML column of label indices.
57+
* If the input column is numeric, we cast it to string and index the string values.
5458
* The indices are in [0, numLabels), ordered by label frequencies.
5559
* So the most frequent label gets index 0.
5660
*/
@@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
6771

6872
override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
6973
val map = extractParamMap(paramMap)
70-
val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue()
74+
val counts = dataset.select(col(map(inputCol)).cast(StringType))
75+
.map(_.getString(0))
76+
.countByValue()
7177
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
7278
val model = new StringIndexerModel(this, map, labels)
7379
Params.inheritValues(map, this, model)
@@ -119,7 +125,8 @@ class StringIndexerModel private[ml] (
119125
val outputColName = map(outputCol)
120126
val metadata = NominalAttribute.defaultAttr
121127
.withName(outputColName).withValues(labels).toMetadata()
122-
dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
128+
dataset.select(col("*"),
129+
indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata))
123130
}
124131

125132
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
4949
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
5050
assert(output === expected)
5151
}
52+
53+
test("StringIndexer with a numeric input column") {
54+
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
55+
val df = sqlContext.createDataFrame(data).toDF("id", "label")
56+
val indexer = new StringIndexer()
57+
.setInputCol("label")
58+
.setOutputCol("labelIndex")
59+
.fit(df)
60+
val transformed = indexer.transform(df)
61+
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
62+
.asInstanceOf[NominalAttribute]
63+
assert(attr.values.get === Array("100", "300", "200"))
64+
val output = transformed.select("id", "labelIndex").map { r =>
65+
(r.getInt(0), r.getDouble(1))
66+
}.collect().toSet
67+
// 100 -> 0, 200 -> 2, 300 -> 1
68+
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
69+
assert(output === expected)
70+
}
5271
}

0 commit comments

Comments
 (0)