Commit cfc0495

peter-toth authored and cloud-fan committed
[SPARK-34581][SQL] Don't optimize out grouping expressions from aggregate expressions without aggregate function
### What changes were proposed in this pull request?

This PR adds a new rule `PullOutGroupingExpressions` to pull out complex grouping expressions to a `Project` node under an `Aggregate`. These expressions are then referenced in both grouping expressions and aggregate expressions without aggregate functions to ensure that optimization rules don't change the aggregate expressions to invalid ones that no longer refer to any grouping expressions.

### Why are the changes needed?

If aggregate expressions (without aggregate functions) in an `Aggregate` node are complex, the `Optimizer` can optimize out grouping expressions from them, making the aggregate expressions invalid.

Here is a simple example:
```
SELECT not(t.id IS NULL) , count(*)
FROM t
GROUP BY t.id IS NULL
```
In this case the `BooleanSimplification` rule does this:
```
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.BooleanSimplification ===
!Aggregate [isnull(id#222)], [NOT isnull(id#222) AS (NOT (id IS NULL))#226, count(1) AS c#224L]   Aggregate [isnull(id#222)], [isnotnull(id#222) AS (NOT (id IS NULL))#226, count(1) AS c#224L]
 +- Project [value#219 AS id#222]                                                                 +- Project [value#219 AS id#222]
    +- LocalRelation [value#219]                                                                     +- LocalRelation [value#219]
```
where `NOT isnull(id#222)` is optimized to `isnotnull(id#222)` and so it no longer refers to any grouping expression.

Before this PR:
```
== Optimized Logical Plan ==
Aggregate [isnull(id#222)], [isnotnull(id#222) AS (NOT (id IS NULL))#234, count(1) AS c#232L]
+- Project [value#219 AS id#222]
   +- LocalRelation [value#219]
```
and running the query throws an error:
```
Couldn't find id#222 in [isnull(id#222)#230,count(1)#226L]
java.lang.IllegalStateException: Couldn't find id#222 in [isnull(id#222)#230,count(1)#226L]
```
After this PR:
```
== Optimized Logical Plan ==
Aggregate [_groupingexpression#233], [NOT _groupingexpression#233 AS (NOT (id IS NULL))#230, count(1) AS c#228L]
+- Project [isnull(value#219) AS _groupingexpression#233]
   +- LocalRelation [value#219]
```
and the query works.

### Does this PR introduce _any_ user-facing change?

Yes, the query works.

### How was this patch tested?

Added new UT.

Closes #32396 from peter-toth/SPARK-34581-keep-grouping-expressions-2.

Authored-by: Peter Toth <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 6ce1b16 commit cfc0495
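
The failure mode described in the commit message can be illustrated with a small toy model. This is only a sketch: the `Expr` types, `simplify`, and `computableFrom` below are hypothetical stand-ins written for illustration and are not Catalyst classes or APIs. It assumes that an output expression of an aggregate is usable only when every attribute it reads sits inside a subtree equal to one of the grouping expressions.

```scala
object GroupingBindingSketch {
  // Toy expression tree; illustrative stand-ins, not Catalyst classes.
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class IsNull(child: Expr) extends Expr
  final case class IsNotNull(child: Expr) extends Expr
  final case class Not(child: Expr) extends Expr

  // Stand-in for a boolean simplification: NOT isnull(x) becomes isnotnull(x).
  def simplify(e: Expr): Expr = e match {
    case Not(IsNull(c)) => IsNotNull(simplify(c))
    case Not(c)         => Not(simplify(c))
    case IsNull(c)      => IsNull(simplify(c))
    case IsNotNull(c)   => IsNotNull(simplify(c))
    case a: Attr        => a
  }

  // An output expression is computable after grouping only if every attribute
  // it reads is covered by a subtree equal to one of the grouping expressions.
  def computableFrom(grouping: Seq[Expr])(e: Expr): Boolean = e match {
    case _ if grouping.contains(e) => true
    case Attr(_)                   => false
    case IsNull(c)                 => computableFrom(grouping)(c)
    case IsNotNull(c)              => computableFrom(grouping)(c)
    case Not(c)                    => computableFrom(grouping)(c)
  }

  // GROUP BY id IS NULL, SELECT not(id IS NULL)
  val grouping: Seq[Expr] = Seq(IsNull(Attr("id")))
  val before: Expr = Not(IsNull(Attr("id")))
  val after: Expr = simplify(before) // IsNotNull(Attr("id"))

  // Same query once the grouping expression is replaced by a plain attribute.
  val pulledGrouping: Seq[Expr] = Seq(Attr("_groupingexpression"))
  val pulledOutput: Expr = simplify(Not(Attr("_groupingexpression")))

  def main(args: Array[String]): Unit = {
    println(computableFrom(grouping)(before))
    println(computableFrom(grouping)(after))
    println(computableFrom(pulledGrouping)(pulledOutput))
  }
}
```

In this model the simplified expression loses its `IsNull(id)` subtree and can no longer be matched against the grouping key, while the attribute-based form has nothing left for the simplification to rewrite.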

24 files changed, +239 -139 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 8 additions & 0 deletions
@@ -80,6 +80,14 @@ object AggregateExpression {
       filter,
       NamedExpression.newExprId)
   }
+
+  def containsAggregate(expr: Expression): Boolean = {
+    expr.find(isAggregate).isDefined
+  }
+
+  def isAggregate(expr: Expression): Boolean = {
+    expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
+  }
 }

 /**
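
The two helpers added above differ in scope: `isAggregate` inspects a single node, while `containsAggregate` searches the whole expression tree. A minimal sketch of that distinction follows, using a hypothetical toy tree (`Expr`, `Add`, `Count`, and the local `find` are illustrative assumptions, not Catalyst types; the Python UDF case is left out).

```scala
object AggregateCheckSketch {
  // Toy expression tree; illustrative stand-ins, not Catalyst classes.
  sealed trait Expr { def children: Seq[Expr] }
  final case class Attr(name: String) extends Expr {
    val children: Seq[Expr] = Nil
  }
  final case class Add(left: Expr, right: Expr) extends Expr {
    val children: Seq[Expr] = Seq(left, right)
  }
  // Stand-in for an aggregate function node.
  final case class Count(child: Expr) extends Expr {
    val children: Seq[Expr] = Seq(child)
  }

  // Pre-order search that returns the first node satisfying the predicate.
  def find(e: Expr)(p: Expr => Boolean): Option[Expr] =
    if (p(e)) Some(e)
    else e.children.iterator.map(c => find(c)(p)).collectFirst { case Some(x) => x }

  // Node-level check: is this node itself an aggregate?
  def isAggregate(e: Expr): Boolean = e.isInstanceOf[Count]

  // Tree-level check: does any node in the tree satisfy isAggregate?
  def containsAggregate(e: Expr): Boolean = find(e)(isAggregate).isDefined

  val mixed: Expr = Add(Attr("a"), Count(Attr("b")))
}
```

Here `isAggregate(mixed)` is false because the root is an `Add`, whereas `containsAggregate(mixed)` is true because a `Count` appears below it.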

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala

Lines changed: 1 addition & 10 deletions
@@ -18,23 +18,14 @@
 package org.apache.spark.sql.catalyst.optimizer

 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule

 /**
  * Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions.
  */
 object SimplifyExtractValueOps extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // One place where this optimization is invalid is an aggregation where the select
-    // list expression is a function of a grouping expression:
-    //
-    // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
-    //
-    // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
-    // optimization for Aggregates (although this misses some cases where the optimization
-    // can be made).
-    case a: Aggregate => a
     case p => p.transformExpressionsUp {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 6 additions & 8 deletions
@@ -148,6 +148,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateView,
       ReplaceExpressions,
       RewriteNonCorrelatedExists,
+      PullOutGroupingExpressions,
       ComputeCurrentTime,
       GetCurrentDatabaseAndCatalog(catalogManager)) ::
     //////////////////////////////////////////////////////////////////////////////////////////
@@ -267,7 +268,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteCorrelatedScalarSubquery.ruleName ::
       RewritePredicateSubquery.ruleName ::
       NormalizeFloatingNumbers.ruleName ::
-      ReplaceUpdateFieldsExpression.ruleName :: Nil
+      ReplaceUpdateFieldsExpression.ruleName ::
+      PullOutGroupingExpressions.ruleName :: Nil

   /**
    * Optimize all the subqueries inside expression.
@@ -524,23 +526,19 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
   }

   private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
-    val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate)
+    val upperHasNoAggregateExpressions =
+      !upper.aggregateExpressions.exists(AggregateExpression.containsAggregate)

     lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet(
       lower
         .aggregateExpressions
         .filter(_.deterministic)
-        .filter(!isAggregate(_))
+        .filterNot(AggregateExpression.containsAggregate)
         .map(_.toAttribute)
     ))

     upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
   }
-
-  private def isAggregate(expr: Expression): Boolean = {
-    expr.find(e => e.isInstanceOf[AggregateExpression] ||
-      PythonUDF.isGroupedAggPandasUDF(e)).isDefined
-  }
 }

 /**

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * This rule ensures that [[Aggregate]] nodes doesn't contain complex grouping expressions in the
+ * optimization phase.
+ *
+ * Complex grouping expressions are pulled out to a [[Project]] node under [[Aggregate]] and are
+ * referenced in both grouping expressions and aggregate expressions without aggregate functions.
+ * These references ensure that optimization rules don't change the aggregate expressions to invalid
+ * ones that no longer refer to any grouping expressions and also simplify the expression
+ * transformations on the node (need to transform the expression only once).
+ *
+ * For example, in the following query Spark shouldn't optimize the aggregate expression
+ * `Not(IsNull(c))` to `IsNotNull(c)` as the grouping expression is `IsNull(c)`:
+ *   SELECT not(c IS NULL)
+ *   FROM t
+ *   GROUP BY c IS NULL
+ * Instead, the aggregate expression references a `_groupingexpression` attribute:
+ *   Aggregate [_groupingexpression#233], [NOT _groupingexpression#233 AS (NOT (c IS NULL))#230]
+ *   +- Project [isnull(c#219) AS _groupingexpression#233]
+ *      +- LocalRelation [c#219]
+ */
+object PullOutGroupingExpressions extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan transform {
+      case a: Aggregate if a.resolved =>
+        val complexGroupingExpressionMap = mutable.LinkedHashMap.empty[Expression, NamedExpression]
+        val newGroupingExpressions = a.groupingExpressions.map {
+          case e if !e.foldable && e.children.nonEmpty =>
+            complexGroupingExpressionMap
+              .getOrElseUpdate(e.canonicalized, Alias(e, s"_groupingexpression")())
+              .toAttribute
+          case o => o
+        }
+        if (complexGroupingExpressionMap.nonEmpty) {
+          def replaceComplexGroupingExpressions(e: Expression): Expression = {
+            e match {
+              case _ if AggregateExpression.isAggregate(e) => e
+              case _ if e.foldable => e
+              case _ if complexGroupingExpressionMap.contains(e.canonicalized) =>
+                complexGroupingExpressionMap.get(e.canonicalized).map(_.toAttribute).getOrElse(e)
+              case _ => e.mapChildren(replaceComplexGroupingExpressions)
+            }
+          }
+
+          val newAggregateExpressions = a.aggregateExpressions
+            .map(replaceComplexGroupingExpressions(_).asInstanceOf[NamedExpression])
+          val newChild = Project(a.child.output ++ complexGroupingExpressionMap.values, a.child)
+          Aggregate(newGroupingExpressions, newAggregateExpressions, newChild)
+        } else {
+          a
+        }
+    }
+  }
+}
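
The steps of the rule above (alias each complex grouping expression once, reference the alias from the grouping list and from the non-aggregate parts of the output list, and compute the alias below the aggregate) can be sketched on a toy model. Everything here is a hypothetical simplification for illustration: `Expr`, `Agg`, `isComplex`, and `pullOut` are not Catalyst APIs, structural equality stands in for `canonicalized`, and an index suffix stands in for the expression IDs that distinguish the real `_groupingexpression` aliases.

```scala
import scala.collection.mutable

object PullOutSketch {
  // Toy expression tree; illustrative stand-ins, not Catalyst classes.
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Lit(value: Int) extends Expr
  final case class IsNull(child: Expr) extends Expr
  final case class Not(child: Expr) extends Expr
  final case class Count(child: Expr) extends Expr // stand-in aggregate function

  // Toy aggregate node: `projected` lists the aliases computed beneath it.
  final case class Agg(
      grouping: Seq[Expr],
      outputs: Seq[Expr],
      projected: Seq[(String, Expr)])

  // Assumption: anything that is not a bare attribute or literal is "complex".
  private def isComplex(e: Expr): Boolean = e match {
    case Attr(_) | Lit(_) => false
    case _                => true
  }

  def pullOut(agg: Agg): Agg = {
    // Insertion-ordered map from a complex grouping expression to its alias.
    val aliases = mutable.LinkedHashMap.empty[Expr, Attr]
    val newGrouping: Seq[Expr] = agg.grouping.map {
      case e if isComplex(e) =>
        aliases.getOrElseUpdate(e, Attr(s"_groupingexpression${aliases.size}"))
      case other => other
    }

    // Swap grouping-expression subtrees for their alias, top-down.
    def replace(e: Expr): Expr = e match {
      case c: Count                  => c // leave aggregate functions untouched
      case _ if aliases.contains(e)  => aliases(e)
      case Not(c)                    => Not(replace(c))
      case IsNull(c)                 => IsNull(replace(c))
      case leaf                      => leaf
    }

    if (aliases.isEmpty) agg
    else Agg(
      newGrouping,
      agg.outputs.map(replace),
      agg.projected ++ aliases.toSeq.map { case (expr, attr) => attr.name -> expr })
  }
}
```

For `GROUP BY id IS NULL` with outputs `not(id IS NULL)` and `count(1)`, this sketch yields a grouping list and a `Not(...)` output that both refer to the single alias, with `IsNull(id)` moved into `projected`.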

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 3 additions & 5 deletions
@@ -297,11 +297,9 @@ object PhysicalAggregation {
       val aggregateExpressions = resultExpressions.flatMap { expr =>
         expr.collect {
           // addExpr() always returns false for non-deterministic expressions and do not add them.
-          case agg: AggregateExpression
-            if !equivalentAggregateExpressions.addExpr(agg) => agg
-          case udf: PythonUDF
-            if PythonUDF.isGroupedAggPandasUDF(udf) &&
-              !equivalentAggregateExpressions.addExpr(udf) => udf
+          case a
+            if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
+            a
         }
       }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala

Lines changed: 3 additions & 9 deletions
@@ -36,6 +36,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {

   object Optimizer extends RuleExecutor[LogicalPlan] {
     val batches =
+      Batch("Finish Analysis", Once,
+        PullOutGroupingExpressions) ::
       Batch("collapse projections", FixedPoint(10),
         CollapseProject) ::
       Batch("Constant Folding", FixedPoint(10),
@@ -57,7 +59,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
   private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
     val optimized = Optimizer.execute(originalQuery.analyze)
     assert(optimized.resolved, "optimized plans must be still resolvable")
-    comparePlans(optimized, correctAnswer.analyze)
+    comparePlans(optimized, PullOutGroupingExpressions(correctAnswer.analyze))
   }

   test("explicit get from namedStruct") {
@@ -405,14 +407,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     val arrayAggRel = relation.groupBy(
       CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
     checkRule(arrayAggRel, arrayAggRel)
-
-    // This could be done if we had a more complex rule that checks that
-    // the CreateMap does not come from key.
-    val originalQuery = relation
-      .groupBy('id)(
-        GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
-      )
-    checkRule(originalQuery, originalQuery)
   }

   test("SPARK-23500: namedStruct and getField in the same Project #1") {

sql/core/src/test/resources/sql-tests/inputs/group-by.sql

Lines changed: 10 additions & 0 deletions
@@ -179,3 +179,13 @@ SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(

 -- Aggregate with multiple distinct decimal columns
 SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col);
+
+-- SPARK-34581: Don't optimize out grouping expressions from aggregate expressions without aggregate function
+SELECT not(a IS NULL), count(*) AS c
+FROM testData
+GROUP BY a IS NULL;
+
+SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
+FROM testData
+GROUP BY a IS NULL;
+


sql/core/src/test/resources/sql-tests/results/group-by.sql.out

Lines changed: 23 additions & 1 deletion
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 62
+-- Number of queries: 64


 -- !query
@@ -642,3 +642,25 @@ SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1
 struct<avg(DISTINCT decimal_col):decimal(13,4),sum(DISTINCT decimal_col):decimal(19,0)>
 -- !query output
 1.0000 1
+
+
+-- !query
+SELECT not(a IS NULL), count(*) AS c
+FROM testData
+GROUP BY a IS NULL
+-- !query schema
+struct<(NOT (a IS NULL)):boolean,c:bigint>
+-- !query output
+false 2
+true 7
+
+
+-- !query
+SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
+FROM testData
+GROUP BY a IS NULL
+-- !query schema
+struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint>
+-- !query output
+0.7604953758285915 7
+1.0 2

sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt

Lines changed: 12 additions & 12 deletions
@@ -199,19 +199,19 @@ Right keys [1]: [i_item_sk#16]
 Join condition: None

 (23) Project [codegen id : 8]
-Output [3]: [d_date#12, i_item_sk#16, i_item_desc#17]
+Output [3]: [d_date#12, i_item_sk#16, substr(i_item_desc#17, 1, 30) AS _groupingexpression#19]
 Input [4]: [ss_item_sk#8, d_date#12, i_item_sk#16, i_item_desc#17]

 (24) HashAggregate [codegen id : 8]
-Input [3]: [d_date#12, i_item_sk#16, i_item_desc#17]
-Keys [3]: [substr(i_item_desc#17, 1, 30) AS substr(i_item_desc#17, 1, 30)#19, i_item_sk#16, d_date#12]
+Input [3]: [d_date#12, i_item_sk#16, _groupingexpression#19]
+Keys [3]: [_groupingexpression#19, i_item_sk#16, d_date#12]
 Functions [1]: [partial_count(1)]
 Aggregate Attributes [1]: [count#20]
-Results [4]: [substr(i_item_desc#17, 1, 30)#19, i_item_sk#16, d_date#12, count#21]
+Results [4]: [_groupingexpression#19, i_item_sk#16, d_date#12, count#21]

 (25) HashAggregate [codegen id : 8]
-Input [4]: [substr(i_item_desc#17, 1, 30)#19, i_item_sk#16, d_date#12, count#21]
-Keys [3]: [substr(i_item_desc#17, 1, 30)#19, i_item_sk#16, d_date#12]
+Input [4]: [_groupingexpression#19, i_item_sk#16, d_date#12, count#21]
+Keys [3]: [_groupingexpression#19, i_item_sk#16, d_date#12]
 Functions [1]: [count(1)]
 Aggregate Attributes [1]: [count(1)#22]
 Results [2]: [i_item_sk#16 AS item_sk#23, count(1)#22 AS count(1)#24]
@@ -406,19 +406,19 @@ Right keys [1]: [i_item_sk#56]
 Join condition: None

 (69) Project [codegen id : 25]
-Output [3]: [d_date#55, i_item_sk#56, i_item_desc#57]
+Output [3]: [d_date#55, i_item_sk#56, substr(i_item_desc#57, 1, 30) AS _groupingexpression#58]
 Input [4]: [ss_item_sk#54, d_date#55, i_item_sk#56, i_item_desc#57]

 (70) HashAggregate [codegen id : 25]
-Input [3]: [d_date#55, i_item_sk#56, i_item_desc#57]
-Keys [3]: [substr(i_item_desc#57, 1, 30) AS substr(i_item_desc#57, 1, 30)#58, i_item_sk#56, d_date#55]
+Input [3]: [d_date#55, i_item_sk#56, _groupingexpression#58]
+Keys [3]: [_groupingexpression#58, i_item_sk#56, d_date#55]
 Functions [1]: [partial_count(1)]
 Aggregate Attributes [1]: [count#59]
-Results [4]: [substr(i_item_desc#57, 1, 30)#58, i_item_sk#56, d_date#55, count#60]
+Results [4]: [_groupingexpression#58, i_item_sk#56, d_date#55, count#60]

 (71) HashAggregate [codegen id : 25]
-Input [4]: [substr(i_item_desc#57, 1, 30)#58, i_item_sk#56, d_date#55, count#60]
-Keys [3]: [substr(i_item_desc#57, 1, 30)#58, i_item_sk#56, d_date#55]
+Input [4]: [_groupingexpression#58, i_item_sk#56, d_date#55, count#60]
+Keys [3]: [_groupingexpression#58, i_item_sk#56, d_date#55]
 Functions [1]: [count(1)]
 Aggregate Attributes [1]: [count(1)#61]
 Results [2]: [i_item_sk#56 AS item_sk#23, count(1)#61 AS count(1)#62]

sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/simplified.txt

Lines changed: 4 additions & 4 deletions
@@ -34,8 +34,8 @@ WholeStageCodegen (36)
 Sort [item_sk]
   Project [item_sk]
     Filter [count(1)]
-      HashAggregate [substr(i_item_desc, 1, 30),i_item_sk,d_date,count] [count(1),item_sk,count(1),count]
-        HashAggregate [i_item_desc,i_item_sk,d_date] [count,substr(i_item_desc, 1, 30),count]
+      HashAggregate [_groupingexpression,i_item_sk,d_date,count] [count(1),item_sk,count(1),count]
+        HashAggregate [_groupingexpression,i_item_sk,d_date] [count,count]
           Project [d_date,i_item_sk,i_item_desc]
             SortMergeJoin [ss_item_sk,i_item_sk]
               InputAdapter
@@ -177,8 +177,8 @@ WholeStageCodegen (36)
 Sort [item_sk]
   Project [item_sk]
     Filter [count(1)]
-      HashAggregate [substr(i_item_desc, 1, 30),i_item_sk,d_date,count] [count(1),item_sk,count(1),count]
-        HashAggregate [i_item_desc,i_item_sk,d_date] [count,substr(i_item_desc, 1, 30),count]
+      HashAggregate [_groupingexpression,i_item_sk,d_date,count] [count(1),item_sk,count(1),count]
+        HashAggregate [_groupingexpression,i_item_sk,d_date] [count,count]
           Project [d_date,i_item_sk,i_item_desc]
             SortMergeJoin [ss_item_sk,i_item_sk]
               InputAdapter
