Skip to content

[SPARK-8803] handle special characters in elements in crosstab #7201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
*/
private def fillCol[T](col: StructField, replacement: T): Column = {
coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
// Quote the column name in backticks so names containing `.` or reserved
// keywords resolve as a literal column name rather than a nested field path.
coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* The first column of each row will be the distinct values of `col1` and the column names will
* be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts
* will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
* Null elements will be replaced by "null", and backticks will be dropped from elements if they
* exist.
*
*
* @param col1 The name of the first column. Distinct items will make the first item of
* each row.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,12 @@ private[sql] object StatFunctions extends Logging {
logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
"the pairs. Please try reducing the amount of distinct items in your columns.")
}
// Render a crosstab element as its display string; a null element becomes
// the literal string "null" so it can still serve as a row label / column key.
def cleanElement(element: Any): String = {
  Option(element).fold("null")(_.toString)
}
// get the distinct values of column 2, so that we can make them the column names
val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap
val distinctCol2: Map[Any, Int] =
counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap
val columnSize = distinctCol2.size
require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
s"exceed 1e4. Currently $columnSize")
Expand All @@ -121,15 +125,23 @@ private[sql] object StatFunctions extends Logging {
// row.get(0) is column 1
// row.get(1) is column 2
// row.get(2) is the frequency
countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
val columnIndex = distinctCol2.get(cleanElement(row.get(1))).get
countsRow.setLong(columnIndex + 1, row.getLong(2))
}
// the value of col1 is the first value, the rest are the counts
countsRow.update(0, UTF8String.fromString(col1Item.toString))
countsRow.update(0, UTF8String.fromString(cleanElement(col1Item)))
countsRow
}.toSeq
// DataFrame column names cannot contain backticks, so strip them out here.
// Callers then wrap the resulting names in backticks so that special
// keywords and `.` are still accepted as literal column names.
def cleanColumnName(name: String): String = {
  name.filterNot(_ == '`')
}
// In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
// SPARK-8681. We need to explicitly sort by the column index and assign the column names.
val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType))
val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r =>
StructField(cleanColumnName(r._1.toString), LongType)
}
val schema = StructType(StructField(tableName, StringType) +: headerNames)

new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,36 @@ class DataFrameStatSuite extends SparkFunSuite {
}
}

test("special crosstab elements (., '', null, ``)") {
// Rows exercise the problem cases: a null in column 1, a dotted name
// ("a.b"), an empty string, a backtick-wrapped element ("`ha`"), a null in
// column 3, and NaN / +Infinity / -Infinity doubles in column 2.
val data = Seq(
("a", Double.NaN, "ho"),
(null, 2.0, "ho"),
("a.b", Double.NegativeInfinity, ""),
("b", Double.PositiveInfinity, "`ha`"),
("a", 1.0, null)
)
val df = data.toDF("1", "2", "3")
val ct1 = df.stat.crosstab("1", "2")
// column fields should be 1 + distinct elements of second column
assert(ct1.schema.fields.length === 6)
// 4 distinct values in column 1 once null is rendered as "null": a, null, a.b, b
assert(ct1.collect().length === 4)
val ct2 = df.stat.crosstab("1", "3")
assert(ct2.schema.fields.length === 5)
// backticks are stripped from the "`ha`" element when it becomes a column name
assert(ct2.schema.fieldNames.contains("ha"))
assert(ct2.collect().length === 4)
val ct3 = df.stat.crosstab("3", "2")
assert(ct3.schema.fields.length === 6)
// special double values survive as column names via their toString forms
assert(ct3.schema.fieldNames.contains("NaN"))
assert(ct3.schema.fieldNames.contains("Infinity"))
assert(ct3.schema.fieldNames.contains("-Infinity"))
assert(ct3.collect().length === 4)
val ct4 = df.stat.crosstab("3", "1")
assert(ct4.schema.fields.length === 5)
// a null element is replaced by the literal column name "null"
assert(ct4.schema.fieldNames.contains("null"))
// a dotted element is kept verbatim as a column name
assert(ct4.schema.fieldNames.contains("a.b"))
assert(ct4.collect().length === 4)
}

test("Frequent Items") {
val rows = Seq.tabulate(1000) { i =>
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
Expand Down