@@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model}
23
23
import org .apache .spark .ml .attribute .NominalAttribute
24
24
import org .apache .spark .ml .param ._
25
25
import org .apache .spark .ml .param .shared ._
26
- import org .apache .spark .ml .util .SchemaUtils
27
26
import org .apache .spark .sql .DataFrame
28
27
import org .apache .spark .sql .functions ._
29
- import org .apache .spark .sql .types .{StringType , StructType }
28
+ import org .apache .spark .sql .types .{NumericType , StringType , StructType }
30
29
import org .apache .spark .util .collection .OpenHashMap
31
30
32
31
/**
@@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
37
36
/** Validates and transforms the input schema. */
38
37
protected def validateAndTransformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
39
38
val map = extractParamMap(paramMap)
40
- SchemaUtils .checkColumnType(schema, map(inputCol), StringType )
39
+ val inputColName = map(inputCol)
40
+ val inputDataType = schema(inputColName).dataType
41
+ require(inputDataType == StringType || inputDataType.isInstanceOf [NumericType ],
42
+ s " The input column $inputColName must be either string type or numeric type, " +
43
+ s " but got $inputDataType. " )
41
44
val inputFields = schema.fields
42
45
val outputColName = map(outputCol)
43
46
require(inputFields.forall(_.name != outputColName),
@@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
51
54
/**
52
55
* :: AlphaComponent ::
53
56
* A label indexer that maps a string column of labels to an ML column of label indices.
57
+ * If the input column is numeric, we cast it to string and index the string values.
54
58
* The indices are in [0, numLabels), ordered by label frequencies.
55
59
* So the most frequent label gets index 0.
56
60
*/
@@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
67
71
68
72
override def fit (dataset : DataFrame , paramMap : ParamMap ): StringIndexerModel = {
69
73
val map = extractParamMap(paramMap)
70
- val counts = dataset.select(map(inputCol)).map(_.getString(0 )).countByValue()
74
+ val counts = dataset.select(col(map(inputCol)).cast(StringType ))
75
+ .map(_.getString(0 ))
76
+ .countByValue()
71
77
val labels = counts.toSeq.sortBy(- _._2).map(_._1).toArray
72
78
val model = new StringIndexerModel (this , map, labels)
73
79
Params .inheritValues(map, this , model)
@@ -119,7 +125,8 @@ class StringIndexerModel private[ml] (
119
125
val outputColName = map(outputCol)
120
126
val metadata = NominalAttribute .defaultAttr
121
127
.withName(outputColName).withValues(labels).toMetadata()
122
- dataset.select(col(" *" ), indexer(dataset(map(inputCol))).as(outputColName, metadata))
128
+ dataset.select(col(" *" ),
129
+ indexer(dataset(map(inputCol)).cast(StringType )).as(outputColName, metadata))
123
130
}
124
131
125
132
override def transformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
0 commit comments