@@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model}
23
23
import org .apache .spark .ml .attribute .NominalAttribute
24
24
import org .apache .spark .ml .param ._
25
25
import org .apache .spark .ml .param .shared ._
26
- import org .apache .spark .ml .util .SchemaUtils
27
26
import org .apache .spark .sql .DataFrame
28
27
import org .apache .spark .sql .functions ._
29
- import org .apache .spark .sql .types .{StringType , StructType }
28
+ import org .apache .spark .sql .types .{NumericType , StringType , StructType }
30
29
import org .apache .spark .util .collection .OpenHashMap
31
30
32
31
/**
@@ -37,7 +36,10 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
37
36
/** Validates and transforms the input schema. */
38
37
protected def validateAndTransformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
39
38
val map = extractParamMap(paramMap)
40
- SchemaUtils .checkColumnType(schema, map(inputCol), StringType )
39
+ val inputColName = map(inputCol)
40
+ val inputDataType = schema(inputColName).dataType
41
+ require(inputDataType == StringType || inputDataType.isInstanceOf [NumericType ],
42
+ s " The input column $inputColName must be either string type or numeric type. " )
41
43
val inputFields = schema.fields
42
44
val outputColName = map(outputCol)
43
45
require(inputFields.forall(_.name != outputColName),
@@ -51,6 +53,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
51
53
/**
52
54
* :: AlphaComponent ::
53
55
* A label indexer that maps a string column of labels to an ML column of label indices.
56
+ * If the input column is numeric, its values are cast to strings and the resulting strings are indexed.
54
57
* The indices are in [0, numLabels), ordered by label frequencies.
55
58
* So the most frequent label gets index 0.
56
59
*/
@@ -67,7 +70,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
67
70
68
71
override def fit (dataset : DataFrame , paramMap : ParamMap ): StringIndexerModel = {
69
72
val map = extractParamMap(paramMap)
70
- val counts = dataset.select(map(inputCol)).map(_.getString(0 )).countByValue()
73
+ val counts = dataset.select(col(map(inputCol)).cast(StringType ))
74
+ .map(_.getString(0 ))
75
+ .countByValue()
71
76
val labels = counts.toSeq.sortBy(- _._2).map(_._1).toArray
72
77
val model = new StringIndexerModel (this , map, labels)
73
78
Params .inheritValues(map, this , model)
@@ -119,7 +124,8 @@ class StringIndexerModel private[ml] (
119
124
val outputColName = map(outputCol)
120
125
val metadata = NominalAttribute .defaultAttr
121
126
.withName(outputColName).withValues(labels).toMetadata()
122
- dataset.select(col(" *" ), indexer(dataset(map(inputCol))).as(outputColName, metadata))
127
+ dataset.select(col(" *" ),
128
+ indexer(dataset(map(inputCol)).cast(StringType )).as(outputColName, metadata))
123
129
}
124
130
125
131
override def transformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
0 commit comments