This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 0221c7f

[SPARK-7582] [MLLIB] user guide for StringIndexer
This PR adds a Java unit test and user guide for `StringIndexer`. I put it before `OneHotEncoder` because they are closely related. jkbradley

Author: Xiangrui Meng <[email protected]>

Closes apache#6561 from mengxr/SPARK-7582 and squashes the following commits:

4bba4f1 [Xiangrui Meng] fix example
ba1cd1b [Xiangrui Meng] fix style
7fa18d1 [Xiangrui Meng] add user guide for StringIndexer
136cb93 [Xiangrui Meng] add a Java unit test for StringIndexer
1 parent b53a011 commit 0221c7f

File tree

2 files changed: +193 -0 lines changed


docs/ml-features.md

Lines changed: 116 additions & 0 deletions
@@ -456,6 +456,122 @@ for expanded in polyDF.select("polyFeatures").take(3):
</div>
</div>

## StringIndexer

`StringIndexer` encodes a string column of labels to a column of label indices.
The indices are in `[0, numLabels)`, ordered by label frequencies, so the most
frequent label gets index `0`. If the input column is numeric, we cast it to
string and index the string values.

**Examples**

Assume that we have the following DataFrame with columns `id` and `category`:

~~~~
 id | category
----|----------
 0  | a
 1  | b
 2  | c
 3  | a
 4  | a
 5  | c
~~~~

`category` is a string column with three labels: "a", "b", and "c".
Applying `StringIndexer` with `category` as the input column and `categoryIndex` as the output
column, we should get the following:

~~~~
 id | category | categoryIndex
----|----------|---------------
 0  | a        | 0.0
 1  | b        | 2.0
 2  | c        | 1.0
 3  | a        | 0.0
 4  | a        | 0.0
 5  | c        | 1.0
~~~~

"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
index `2`.
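The frequency-ordered index assignment described above can be sketched in plain Python. This is an illustrative sketch only, not Spark's implementation; in particular, the alphabetical tie-breaking used here is an assumption for determinism, not documented `StringIndexer` behavior.

```python
from collections import Counter

def fit_string_index(labels):
    """Sketch: map each label to a float index in [0, numLabels),
    most frequent label first (assumed alphabetical tie-break)."""
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda lab: (-counts[lab], lab))
    return {lab: float(i) for i, lab in enumerate(ordered)}

labels = ["a", "b", "c", "a", "a", "c"]
mapping = fit_string_index(labels)     # {'a': 0.0, 'c': 1.0, 'b': 2.0}
indexed = [mapping[lab] for lab in labels]  # [0.0, 2.0, 1.0, 0.0, 0.0, 1.0]
```

The "fit" step builds the label-to-index mapping from the data, which is why `StringIndexer` is an estimator (`fit` then `transform`) rather than a plain transformer.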
<div class="codetabs">

<div data-lang="scala" markdown="1">

[`StringIndexer`](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) takes an input
column name and an output column name.

{% highlight scala %}
import org.apache.spark.ml.feature.StringIndexer

val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
val indexed = indexer.fit(df).transform(df)
indexed.show()
{% endhighlight %}
</div>

<div data-lang="java" markdown="1">

[`StringIndexer`](api/java/org/apache/spark/ml/feature/StringIndexer.html) takes an input column
name and an output column name.

{% highlight java %}
import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;

JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
  RowFactory.create(0, "a"),
  RowFactory.create(1, "b"),
  RowFactory.create(2, "c"),
  RowFactory.create(3, "a"),
  RowFactory.create(4, "a"),
  RowFactory.create(5, "c")
));
StructType schema = new StructType(new StructField[] {
  createStructField("id", IntegerType, false),
  createStructField("category", StringType, false)
});
DataFrame df = sqlContext.createDataFrame(jrdd, schema);
StringIndexer indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex");
DataFrame indexed = indexer.fit(df).transform(df);
indexed.show();
{% endhighlight %}
</div>

<div data-lang="python" markdown="1">

[`StringIndexer`](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) takes an input
column name and an output column name.

{% highlight python %}
from pyspark.ml.feature import StringIndexer

df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
    ["id", "category"])
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
indexed = indexer.fit(df).transform(df)
indexed.show()
{% endhighlight %}
</div>
</div>
## OneHotEncoder

[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features
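The index-to-vector mapping can be sketched in plain Python. This is an illustrative `one_hot` helper, not part of Spark's API, and whether the last category is dropped (so the vector has at most a single one-value over `num_labels - 1` slots) is an assumption of this sketch:

```python
def one_hot(index, num_labels, drop_last=True):
    """Sketch: turn a label index into a binary vector.
    With drop_last=True, the last category maps to the all-zero vector."""
    size = num_labels - 1 if drop_last else num_labels
    vec = [0.0] * size
    if index < size:
        vec[int(index)] = 1.0
    return vec

# With 3 labels: index 0.0 -> [1.0, 0.0]; index 2.0 -> [0.0, 0.0]
```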
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature;
19+
20+
import java.util.Arrays;
21+
22+
import org.junit.After;
23+
import org.junit.Assert;
24+
import org.junit.Before;
25+
import org.junit.Test;
26+
27+
import org.apache.spark.api.java.JavaRDD;
28+
import org.apache.spark.api.java.JavaSparkContext;
29+
import org.apache.spark.sql.DataFrame;
30+
import org.apache.spark.sql.Row;
31+
import org.apache.spark.sql.RowFactory;
32+
import org.apache.spark.sql.SQLContext;
33+
import org.apache.spark.sql.types.StructField;
34+
import org.apache.spark.sql.types.StructType;
35+
import static org.apache.spark.sql.types.DataTypes.*;
36+
37+
public class JavaStringIndexerSuite {
38+
private transient JavaSparkContext jsc;
39+
private transient SQLContext sqlContext;
40+
41+
@Before
42+
public void setUp() {
43+
jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
44+
sqlContext = new SQLContext(jsc);
45+
}
46+
47+
@After
48+
public void tearDown() {
49+
jsc.stop();
50+
sqlContext = null;
51+
}
52+
53+
@Test
54+
public void testStringIndexer() {
55+
StructType schema = createStructType(new StructField[] {
56+
createStructField("id", IntegerType, false),
57+
createStructField("label", StringType, false)
58+
});
59+
JavaRDD<Row> rdd = jsc.parallelize(
60+
Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")));
61+
DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
62+
63+
StringIndexer indexer = new StringIndexer()
64+
.setInputCol("label")
65+
.setOutputCol("labelIndex");
66+
DataFrame output = indexer.fit(dataset).transform(dataset);
67+
68+
Assert.assertArrayEquals(
69+
new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) },
70+
output.orderBy("id").select("id", "labelIndex").collect());
71+
}
72+
73+
/** An alias for RowFactory.create. */
74+
private Row c(Object... values) {
75+
return RowFactory.create(values);
76+
}
77+
}
