
Commit 9b23e92

brkyvz authored and rxin committed
[SPARK-8803] handle special characters in elements in crosstab
cc rxin

Having back ticks or null as elements causes problems: since elements become column names, back ticks have to be dropped from the elements, as they are special characters, and null elements throw exceptions, so they are replaced with the string "null". Handling back ticks should be improved for 1.5.

Author: Burak Yavuz <[email protected]>

Closes #7201 from brkyvz/weird-ct-elements and squashes the following commits:

e06b840 [Burak Yavuz] fix scalastyle
93a0d3f [Burak Yavuz] added tests for NaN and Infinity
9dba6ce [Burak Yavuz] address cr1
db71dbd [Burak Yavuz] handle special characters in elements in crosstab
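For context, a minimal before/after sketch of the inputs this patch targets. This is illustrative only: it assumes a 1.4-era Spark shell with `sqlContext.implicits._` imported, and the fixture `df` is my own, not from the patch.

    import sqlContext.implicits._

    // Elements of the two columns become crosstab column names, so a null or a
    // back-ticked element used to produce an invalid or unresolvable name.
    val df = Seq(("a", "x"), (null, "x"), ("a.b", "`y`")).toDF("c1", "c2")
    df.stat.crosstab("c1", "c2").show()
    // With this patch: header "c1_c2", count columns "x" and "y" (back ticks
    // dropped), and rows for "a", "null", and "a.b".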
1 parent: f743c79

File tree: 4 files changed (+50, -5 lines)

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 1 addition & 1 deletion
@@ -391,7 +391,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
    */
   private def fillCol[T](col: StructField, replacement: T): Column = {
-    coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
+    coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
   }
 
   /**
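Why the back ticks in `fillCol` matter: `df.col(...)` parses its argument as an expression, so a dot in a plain column name is read as struct-field access rather than as part of the name. A minimal sketch, assuming a DataFrame `df` whose schema literally contains a column named "a.b":

    df.col("a.b")    // parsed as field "b" nested under a column "a"; resolution fails
    df.col("`a.b`")  // back ticks quote the whole name, so the lookup is literal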

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 3 additions & 0 deletions
@@ -78,6 +78,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * The first column of each row will be the distinct values of `col1` and the column names will
    * be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts
    * will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
+   * Null elements will be replaced by "null", and back ticks will be dropped from elements if they
+   * exist.
+   *
    *
    * @param col1 The name of the first column. Distinct items will make the first item of
    *             each row.
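A quick sketch of the contract this doc addition spells out, on throwaway data (again assuming shell implicits):

    import sqlContext.implicits._

    val pairs = Seq(("a", "`ha`"), ("b", null)).toDF("k", "v")
    // Per the new doc: back ticks are dropped from elements and null elements
    // become the string "null", so the header should read: k_v, ha, null
    println(pairs.stat.crosstab("k", "v").schema.fieldNames.mkString(", "))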

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 16 additions & 4 deletions
@@ -110,8 +110,12 @@ private[sql] object StatFunctions extends Logging {
       logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
         "the pairs. Please try reducing the amount of distinct items in your columns.")
     }
+    def cleanElement(element: Any): String = {
+      if (element == null) "null" else element.toString
+    }
     // get the distinct values of column 2, so that we can make them the column names
-    val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap
+    val distinctCol2: Map[Any, Int] =
+      counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap
     val columnSize = distinctCol2.size
     require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
       s"exceed 1e4. Currently $columnSize")
@@ -121,15 +125,23 @@
         // row.get(0) is column 1
         // row.get(1) is column 2
         // row.get(2) is the frequency
-        countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
+        val columnIndex = distinctCol2.get(cleanElement(row.get(1))).get
+        countsRow.setLong(columnIndex + 1, row.getLong(2))
       }
       // the value of col1 is the first value, the rest are the counts
-      countsRow.update(0, UTF8String.fromString(col1Item.toString))
+      countsRow.update(0, UTF8String.fromString(cleanElement(col1Item)))
       countsRow
     }.toSeq
+    // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept
+    // special keywords and `.`, wrap the column names in ``.
+    def cleanColumnName(name: String): String = {
+      name.replace("`", "")
+    }
     // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
     // SPARK-8681. We need to explicitly sort by the column index and assign the column names.
-    val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType))
+    val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r =>
+      StructField(cleanColumnName(r._1.toString), LongType)
+    }
     val schema = StructType(StructField(tableName, StringType) +: headerNames)
 
     new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
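The two helpers above are self-contained, so their combined effect can be pinned down in isolation; the expected values below mirror the test suite that follows:

    def cleanElement(element: Any): String =
      if (element == null) "null" else element.toString

    def cleanColumnName(name: String): String =
      name.replace("`", "")

    assert(cleanElement(null) == "null")                   // null becomes an addressable name
    assert(cleanElement(Double.NaN) == "NaN")              // NaN keeps its toString form
    assert(cleanColumnName(cleanElement("`ha`")) == "ha")  // back ticks are dropped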

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 30 additions & 0 deletions
@@ -85,6 +85,36 @@ class DataFrameStatSuite extends SparkFunSuite {
     }
   }
 
+  test("special crosstab elements (., '', null, ``)") {
+    val data = Seq(
+      ("a", Double.NaN, "ho"),
+      (null, 2.0, "ho"),
+      ("a.b", Double.NegativeInfinity, ""),
+      ("b", Double.PositiveInfinity, "`ha`"),
+      ("a", 1.0, null)
+    )
+    val df = data.toDF("1", "2", "3")
+    val ct1 = df.stat.crosstab("1", "2")
+    // column fields should be 1 + distinct elements of second column
+    assert(ct1.schema.fields.length === 6)
+    assert(ct1.collect().length === 4)
+    val ct2 = df.stat.crosstab("1", "3")
+    assert(ct2.schema.fields.length === 5)
+    assert(ct2.schema.fieldNames.contains("ha"))
+    assert(ct2.collect().length === 4)
+    val ct3 = df.stat.crosstab("3", "2")
+    assert(ct3.schema.fields.length === 6)
+    assert(ct3.schema.fieldNames.contains("NaN"))
+    assert(ct3.schema.fieldNames.contains("Infinity"))
+    assert(ct3.schema.fieldNames.contains("-Infinity"))
+    assert(ct3.collect().length === 4)
+    val ct4 = df.stat.crosstab("3", "1")
+    assert(ct4.schema.fields.length === 5)
+    assert(ct4.schema.fieldNames.contains("null"))
+    assert(ct4.schema.fieldNames.contains("a.b"))
+    assert(ct4.collect().length === 4)
+  }
+
   test("Frequent Items") {
     val rows = Seq.tabulate(1000) { i =>
       if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
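Why those assertion numbers hold, as a worked count over the five-row fixture:

- ct1 = crosstab("1", "2"): column "2" has 5 distinct values (NaN, 2.0, -Infinity, Infinity, 1.0), so 1 + 5 = 6 schema fields; column "1" has 4 distinct cleaned values ("a", "null", "a.b", "b"), so 4 rows.
- ct2 = crosstab("1", "3"): column "3" has 4 distinct cleaned values ("ho", "", "ha", "null"), so 5 fields, with the same 4 rows from column "1".
- ct3 and ct4 swap the column roles, which is why "NaN", "Infinity", "-Infinity", "null", and "a.b" surface as schema field names.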
