Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 6396cc0

Browse files
committed
[SPARK-7982][SQL] DataFrame.stat.crosstab should use 0 instead of null for pairs that don't appear
Author: Reynold Xin <[email protected]> Closes apache#6566 from rxin/crosstab and squashes the following commits: e0ace1c [Reynold Xin] [SPARK-7982][SQL] DataFrame.stat.crosstab should use 0 instead of null for pairs that don't appear
1 parent 6b44278 commit 6396cc0

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.execution.stat
1919

2020
import org.apache.spark.Logging
21-
import org.apache.spark.sql.{Column, DataFrame}
21+
import org.apache.spark.sql.{Row, Column, DataFrame}
2222
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
2323
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
2424
import org.apache.spark.sql.functions._
@@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging {
116116
s"exceed 1e4. Currently $columnSize")
117117
val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
118118
val countsRow = new GenericMutableRow(columnSize + 1)
119-
rows.foreach { row =>
119+
rows.foreach { (row: Row) =>
120+
// row.get(0) is column 1
121+
// row.get(1) is column 2
122+
// row.get(2) is the frequency
120123
countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
121124
}
122125
// the value of col1 is the first value, the rest are the counts
@@ -126,6 +129,6 @@ private[sql] object StatFunctions extends Logging {
126129
val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq
127130
val schema = StructType(StructField(tableName, StringType) +: headerNames)
128131

129-
new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
132+
new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
130133
}
131134
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ class DataFrameStatSuite extends SparkFunSuite {
7474
val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
7575
assert(rows(0).get(0).toString === "0")
7676
assert(rows(0).getLong(1) === 2L)
77-
assert(rows(0).get(2) === null)
77+
assert(rows(0).get(2) === 0L)
7878
assert(rows(1).get(0).toString === "1")
7979
assert(rows(1).getLong(1) === 1L)
80-
assert(rows(1).get(2) === null)
80+
assert(rows(1).get(2) === 0L)
8181
assert(rows(2).get(0).toString === "2")
8282
assert(rows(2).getLong(1) === 2L)
8383
assert(rows(2).getLong(2) === 1L)

0 commit comments

Comments
 (0)