Skip to content

Commit a6a7c1d

Browse files
committed
Merge pull request #79 from markhamstra/csd-1.4_empty_sum_null
Sum of NULL values should return NULL
2 parents b197c0f + 86c0228 commit a6a7c1d

File tree

7 files changed: 33 additions (+33) and 7 deletions (-7).

7 files changed: 33 additions (+33) and 7 deletions (-7).

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -619,7 +619,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
619619
private val sum = MutableLiteral(null, calcType)
620620

621621
private val addFunction =
622-
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
622+
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
623623

624624
override def update(input: Row): Unit = {
625625
sum.update(addFunction, input)

sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ case class GeneratedAggregate(
107107
Add(
108108
Coalesce(currentSum :: zero :: Nil),
109109
Cast(expr, calcType)
110-
) :: currentSum :: zero :: Nil)
110+
) :: currentSum :: Nil)
111111
val result =
112112
expr.dataType match {
113113
case DecimalType.Fixed(_, _) =>

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException
2626
import org.apache.spark.sql.execution.GeneratedAggregate
2727
import org.apache.spark.sql.functions._
2828
import org.apache.spark.sql.TestData._
29-
import org.apache.spark.sql.test.TestSQLContext
29+
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
3030
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
3131

3232
import org.apache.spark.sql.types._
@@ -253,7 +253,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
253253
// Aggregate with Code generation handling all null values
254254
testCodeGen(
255255
"SELECT sum('a'), avg('a'), count(null) FROM testData",
256-
Row(0, null, 0) :: Nil)
256+
Row(null, null, 0) :: Nil)
257257

258258
dropTempTable("testData3x")
259259
setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)

sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -814,7 +814,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
814814
"udaf_covar_pop",
815815
"udaf_covar_samp",
816816
"udaf_histogram_numeric",
817-
"udaf_number_format",
818817
"udf2",
819818
"udf5",
820819
"udf6",

sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964

Whitespace-only changes.

sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16

Lines changed: 0 additions & 1 deletion
This file was deleted.

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala

Lines changed: 29 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.hive.execution
1919

20+
import org.apache.spark.sql.TestData.NullInts
2021
import org.apache.spark.sql.catalyst.DefaultParserDialect
2122
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
2223
import org.apache.spark.sql.catalyst.errors.DialectException
@@ -27,6 +28,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._
2728
import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation}
2829
import org.apache.spark.sql.parquet.ParquetRelation2
2930
import org.apache.spark.sql.sources.LogicalRelation
31+
import org.apache.spark.sql.test.SQLTestUtils
3032
import org.apache.spark.sql.types._
3133

3234
case class Nested1(f1: Nested2)
@@ -59,7 +61,9 @@ class MyDialect extends DefaultParserDialect
5961
* Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is
6062
* valid, but Hive currently cannot execute it.
6163
*/
62-
class SQLQuerySuite extends QueryTest {
64+
class SQLQuerySuite extends QueryTest with SQLTestUtils {
65+
override val sqlContext: SQLContext = TestHive
66+
6367
test("SPARK-6835: udtf in lateral view") {
6468
val df = Seq((1, 1)).toDF("c1", "c2")
6569
df.registerTempTable("table1")
@@ -946,4 +950,28 @@ class SQLQuerySuite extends QueryTest {
946950

947951
checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
948952
}
953+
954+
test("SPARK-8828 sum should return null if all input values are null") {
955+
val allNulls =
956+
TestHive.sparkContext.parallelize(
957+
NullInts(null) ::
958+
NullInts(null) ::
959+
NullInts(null) ::
960+
NullInts(null) :: Nil).toDF()
961+
allNulls.registerTempTable("allNulls")
962+
963+
withSQLConf(SQLConf.CODEGEN_ENABLED -> "true") {
964+
checkAnswer(
965+
sql("select sum(a), avg(a) from allNulls"),
966+
Seq(Row(null, null))
967+
)
968+
}
969+
withSQLConf(SQLConf.CODEGEN_ENABLED -> "false") {
970+
checkAnswer(
971+
sql("select sum(a), avg(a) from allNulls"),
972+
Seq(Row(null, null))
973+
)
974+
}
975+
}
976+
949977
}

0 commit comments

Comments
 (0)