Skip to content

Commit 60a908e

Browse files
maropu authored and dongjoon-hyun committed
[SPARK-29708][SQL][2.4] Correct aggregated values when grouping sets are duplicated
### What changes were proposed in this pull request?

This PR intends to fix wrong aggregated values in `GROUPING SETS` when there are duplicated grouping sets in a query (e.g., `GROUPING SETS ((k1),(k1))`). For example:

```
scala> spark.table("t").show()
+---+---+---+
| k1| k2|  v|
+---+---+---+
|  0|  0|  3|
+---+---+---+

scala> sql("""select grouping_id(), k1, k2, sum(v) from t group by grouping sets ((k1),(k1,k2),(k2,k1),(k1,k2))""").show()
+-------------+---+----+------+
|grouping_id()| k1|  k2|sum(v)|
+-------------+---+----+------+
|            0|  0|   0|     9| <---- wrong aggregate value and the correct answer is `3`
|            1|  0|null|     3|
+-------------+---+----+------+

// PostgreSQL case
postgres=# select k1, k2, sum(v) from t group by grouping sets ((k1),(k1,k2),(k2,k1),(k1,k2));
 k1 |  k2  | sum
----+------+-----
  0 |    0 |   3
  0 |    0 |   3
  0 |    0 |   3
  0 | NULL |   3
(4 rows)

// Hive case
hive> select GROUPING__ID, k1, k2, sum(v) from t group by k1, k2 grouping sets ((k1),(k1,k2),(k2,k1),(k1,k2));
1 0 NULL 3
0 0 0    3
```

[MS SQL Server has the same behaviour as PostgreSQL](#26961 (comment)). This PR follows the behaviour of PostgreSQL/SQL Server; it adds one more virtual attribute in `Expand` for avoiding wrongly grouping rows with the same grouping ID.

This is the #26961 backport for `branch-2.4`.

### Why are the changes needed?

To fix bugs.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

The existing tests.

Closes #27229 from maropu/SPARK-29708-BRANCHC2.4.

Authored-by: Takeshi Yamamuro <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent d6261a1 commit 60a908e

File tree

3 files changed

+70
-6
lines changed

3 files changed

+70
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -670,11 +670,14 @@ object Expand {
670670
child: LogicalPlan): Expand = {
671671
val attrMap = groupByAttrs.zipWithIndex.toMap
672672

673+
val hasDuplicateGroupingSets = groupingSetsAttrs.size !=
674+
groupingSetsAttrs.map(_.map(_.exprId).toSet).distinct.size
675+
673676
// Create an array of Projections for the child projection, and replace the projections'
674677
// expressions which equal GroupBy expressions with Literal(null), if those expressions
675678
// are not set for this grouping set.
676-
val projections = groupingSetsAttrs.map { groupingSetAttrs =>
677-
child.output ++ groupByAttrs.map { attr =>
679+
val projections = groupingSetsAttrs.zipWithIndex.map { case (groupingSetAttrs, i) =>
680+
val projAttrs = child.output ++ groupByAttrs.map { attr =>
678681
if (!groupingSetAttrs.contains(attr)) {
679682
// if the input attribute in the Invalid Grouping Expression set of for this group
680683
// replace it with constant null
@@ -684,11 +687,25 @@ object Expand {
684687
}
685688
// groupingId is the last output, here we use the bit mask as the concrete value for it.
686689
} :+ Literal.create(buildBitmask(groupingSetAttrs, attrMap), IntegerType)
690+
691+
if (hasDuplicateGroupingSets) {
692+
// If `groupingSetsAttrs` has duplicate entries (e.g., GROUPING SETS ((key), (key))),
693+
// we add one more virtual grouping attribute (`_gen_grouping_pos`) to avoid
694+
// wrongly grouping rows with the same grouping ID.
695+
projAttrs :+ Literal.create(i, IntegerType)
696+
} else {
697+
projAttrs
698+
}
687699
}
688700

689701
// the `groupByAttrs` has different meaning in `Expand.output`, it could be the original
690702
// grouping expression or null, so here we create new instance of it.
691-
val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid
703+
val output = if (hasDuplicateGroupingSets) {
704+
val gpos = AttributeReference("_gen_grouping_pos", IntegerType, false)()
705+
child.output ++ groupByAttrs.map(_.newInstance) :+ gid :+ gpos
706+
} else {
707+
child.output ++ groupByAttrs.map(_.newInstance) :+ gid
708+
}
692709
Expand(projections, output, Project(child.output ++ groupByAliases, child))
693710
}
694711
}

sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,9 @@ SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE;
5151

5252
SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (());
5353

54+
-- duplicate entries in grouping sets
55+
SELECT k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1));
56+
57+
SELECT grouping__id, k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1));
58+
59+
SELECT grouping(k1), k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1));

sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 15
2+
-- Number of queries: 18
33

44

55
-- !query 0
@@ -110,8 +110,10 @@ SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUP
110110
-- !query 10 schema
111111
struct<(a + b):int,b:int,sum(c):bigint>
112112
-- !query 10 output
113-
2 NULL 2
114-
4 NULL 4
113+
2 NULL 1
114+
2 NULL 1
115+
4 NULL 2
116+
4 NULL 2
115117
NULL 1 1
116118
NULL 2 2
117119

@@ -164,3 +166,42 @@ struct<>
164166
-- !query 14 output
165167
org.apache.spark.sql.AnalysisException
166168
expression '`c1`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
169+
170+
171+
-- !query 15
172+
SELECT k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1))
173+
-- !query 15 schema
174+
struct<k1:int,k2:int,avg(v):double>
175+
-- !query 15 output
176+
1 1 1.0
177+
1 1 1.0
178+
1 NULL 1.0
179+
2 2 2.0
180+
2 2 2.0
181+
2 NULL 2.0
182+
183+
184+
-- !query 16
185+
SELECT grouping__id, k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1))
186+
-- !query 16 schema
187+
struct<grouping__id:int,k1:int,k2:int,avg(v):double>
188+
-- !query 16 output
189+
0 1 1 1.0
190+
0 1 1 1.0
191+
0 2 2 2.0
192+
0 2 2 2.0
193+
1 1 NULL 1.0
194+
1 2 NULL 2.0
195+
196+
197+
-- !query 17
198+
SELECT grouping(k1), k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1))
199+
-- !query 17 schema
200+
struct<grouping(k1):tinyint,k1:int,k2:int,avg(v):double>
201+
-- !query 17 output
202+
0 1 1 1.0
203+
0 1 1 1.0
204+
0 1 NULL 1.0
205+
0 2 2 2.0
206+
0 2 2 2.0
207+
0 2 NULL 2.0

0 commit comments

Comments
 (0)