Commit 3e01d41

bersprockets authored and dongjoon-hyun committed
[SPARK-50091][SQL][3.5] Handle case of aggregates in left-hand operand of IN-subquery
### What changes were proposed in this pull request?

This is a back-port of #48627.

This PR adds code to `RewritePredicateSubquery#apply` to explicitly handle the case where an `Aggregate` node contains an aggregate expression in the left-hand operand of an IN-subquery expression. The explicit handler moves the IN-subquery expressions out of the `Aggregate` and into a parent `Project` node. The `Aggregate` will continue to perform the aggregations that were used as an operand to the IN-subquery expression, but will not include the IN-subquery expression itself. After pulling up IN-subquery expressions into a `Project` node, `RewritePredicateSubquery#apply` is called again to handle the `Project` as a `UnaryNode`. The `Join` will now be inserted between the `Project` and the `Aggregate` node, and the join condition will use an attribute rather than an aggregate expression, e.g.:

```
Project [col1#32, exists#42 AS (sum(col2) IN (listquery()))#40]
+- Join ExistenceJoin(exists#42), (sum(col2)#41L = c2#39L)
   :- Aggregate [col1#32], [col1#32, sum(col2#33) AS sum(col2)#41L]
   :  +- LocalRelation [col1#32, col2#33]
   +- LocalRelation [c2#39L]
```

`sum(col2)#41L` in the above join condition, despite how it looks, is the name of the attribute, not an aggregate expression.

### Why are the changes needed?

The following query fails:

```
create or replace temp view v1(c1, c2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1);
create or replace temp view v2(col1, col2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1);
select col1, sum(col2) in (select c2 from v1) from v2 group by col1;
```

It fails with this error:

```
[INTERNAL_ERROR] Cannot generate code for expression: sum(input[1, int, false]) SQLSTATE: XX000
```

With SPARK_TESTING=1, it fails with this error:

```
[PLAN_VALIDATION_FAILED_RULE_IN_BATCH] Rule org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery in batch RewriteSubquery generated an invalid plan: Special expressions are placed in the wrong plan:
Aggregate [col1#11], [col1#11, first(exists#20, false) AS (sum(col2) IN (listquery()))#19]
+- Join ExistenceJoin(exists#20), (sum(col2#12) = c2#18L)
   :- LocalRelation [col1#11, col2#12]
   +- LocalRelation [c2#18L]
```

The issue is that `RewritePredicateSubquery` builds a `Join` operator where the join condition contains an aggregate expression.

The bug is in the handler for `UnaryNode` in `RewritePredicateSubquery#apply`, which adds a `Join` below the `Aggregate` and assumes that the left-hand operand of IN-subquery can be used in the join condition. This works fine for most cases, but not when the left-hand operand is an aggregate expression.

This PR moves the offending IN-subqueries to a `Project` node, with the aggregates replaced by attributes referring to the aggregate expressions. The resulting join condition now uses those attributes rather than the actual aggregate expressions.

### Does this PR introduce _any_ user-facing change?

No, other than allowing this type of query to succeed.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49663 from bersprockets/aggregate_in_set_issue_br35.

Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 7fd9ced commit 3e01d41
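
The pull-up described in the commit message can be illustrated with a toy model. This is only a sketch under stated assumptions: `Expr`, `Attr`, `Sum`, `InSubquery`, `Alias`, and `pullUp` are hypothetical stand-ins written for this example, not the Catalyst classes of the same names, and the `_aggN` naming is invented here. It shows the core idea: each aggregate in the left-hand operand is replaced by an attribute, and the aggregate itself is recorded as an alias for the aggregation stage to compute.

```scala
// Toy model only: these types are NOT the Catalyst classes.
object PullUpSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Sum(child: Expr) extends Expr // stands in for an aggregate expression
  case class InSubquery(value: Expr, subquery: String) extends Expr
  case class Alias(child: Expr, name: String)

  // Replace every Sum in `e` with an attribute and record the alias
  // that the aggregation stage has to compute for it.
  def pullUp(e: Expr, aliases: Vector[Alias]): (Expr, Vector[Alias]) = e match {
    case s: Sum =>
      val name = s"_agg${aliases.size}" // hypothetical naming scheme
      (Attr(name), aliases :+ Alias(s, name))
    case InSubquery(value, subquery) =>
      val (newValue, newAliases) = pullUp(value, aliases)
      (InSubquery(newValue, subquery), newAliases)
    case a: Attr => (a, aliases)
  }

  def main(args: Array[String]): Unit = {
    val original = InSubquery(Sum(Attr("col2")), "SELECT c2 FROM v1")
    val (projectExpr, aggregateExprs) = pullUp(original, Vector.empty)
    println(projectExpr)    // the expression the projection stage would hold
    println(aggregateExprs) // the aliases the aggregation stage would hold
  }
}
```

After the rewrite, the IN-subquery's left-hand operand is a plain attribute, so a join condition built from it no longer contains an aggregate.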

3 files changed: +136 -9 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 88 additions & 8 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -100,6 +101,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
+  def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = {
+    exprs.exists { expr =>
+      exprContainsAggregateInSubquery(expr)
+    }
+  }
+
+  def exprContainsAggregateInSubquery(expr: Expression): Boolean = {
+    expr.exists {
+      case InSubquery(values, _) =>
+        values.exists { v =>
+          v.exists {
+            case _: AggregateExpression => true
+            case _ => false
+          }
+        }
+      case _ => false;
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) {
     case Filter(condition, child)
@@ -162,15 +182,75 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           Project(p.output, Filter(newCond.get, inputPlan))
       }
 
+    // Handle the case where the left-hand side of an IN-subquery contains an aggregate.
+    //
+    // If an Aggregate node contains such an IN-subquery, this handler will pull up all
+    // expressions from the Aggregate node into a new Project node. The new Project node
+    // will then be handled by the Unary node handler.
+    //
+    // The Unary node handler uses the left-hand side of the IN-subquery in a
+    // join condition. Thus, without this pre-transformation, the join condition
+    // contains an aggregate, which is illegal. With this pre-transformation, the
+    // join condition contains an attribute from the left-hand side of the
+    // IN-subquery contained in the Project node.
+    //
+    // For example:
+    //
+    //   SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
+    //   FROM v2;
+    //
+    // The above query has this plan on entry to RewritePredicateSubquery#apply:
+    //
+    //   Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13]
+    //   :  +- LocalRelation [c3#28L]
+    //   +- LocalRelation [col2#18, col3#19]
+    //
+    // Note that the Aggregate node contains the IN-subquery and the left-hand
+    // side of the IN-subquery is an aggregate expression sum(col2#18)).
+    //
+    // This handler transforms the above plan into the following:
+    // scalastyle:off line.size.limit
+    //
+    //   Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13]
+    //   :  +- LocalRelation [c3#28L]
+    //   +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L]
+    //      +- LocalRelation [col2#18, col3#19]
+    //
+    // scalastyle:on
+    // Note that both the IN-subquery and the greater-than expressions have been
+    // pulled up into the Project node. These expressions use attributes
+    // (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations
+    // which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)).
+    case p @ PhysicalAggregation(
+        groupingExpressions, aggregateExpressions, resultExpressions, child)
+        if exprsContainsAggregateInSubquery(p.expressions) =>
+      val aggExprs = aggregateExpressions.map(
+        ae => Alias(ae, "_aggregateexpression")(ae.resultId))
+      val aggExprIds = aggExprs.map(_.exprId).toSet
+      val resExprs = resultExpressions.map(_.transform {
+        case a: AttributeReference if aggExprIds.contains(a.exprId) =>
+          a.withName("_aggregateexpression")
+      }.asInstanceOf[NamedExpression])
+      // Rewrite the projection and the aggregate separately and then piece them together.
+      val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
+      val newProj = Project(resExprs, newAgg)
+      handleUnaryNode(newProj)
+
     case u: UnaryNode if u.expressions.exists(
-        SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
-      var newChild = u.child
-      u.mapExpressions(expr => {
-        val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild)
-        newChild = p
-        // The newExpr can not be None
-        newExpr.get
-      }).withNewChildren(Seq(newChild))
+        SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u)
+  }
+
+  /**
+   * Handle the unary node case
+   */
+  private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
+    var newChild = u.child
+    u.mapExpressions(expr => {
+      val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild)
+      newChild = p
+      // The newExpr can not be None
+      newExpr.get
+    }).withNewChildren(Seq(newChild))
   }
 
   /**
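
The guard used by the new `PhysicalAggregation` case boils down to a nested tree search: find an IN-subquery anywhere in the expression, then look for an aggregate anywhere inside its left-hand values. A minimal sketch of that search on a toy expression tree follows. The `Expr` hierarchy and `containsAggregateInSubquery` are hypothetical names for this illustration only, and the toy `exists` is assumed to behave like the tree-wide "any node satisfies the predicate" check that the diff relies on.

```scala
// Toy expression tree; NOT the Catalyst Expression hierarchy.
object DetectSketch {
  sealed trait Expr {
    def children: Seq[Expr]
    // True if the predicate holds for this node or for any descendant.
    def exists(p: Expr => Boolean): Boolean =
      p(this) || children.exists(_.exists(p))
  }
  case class Attr(name: String) extends Expr {
    def children: Seq[Expr] = Nil
  }
  case class Sum(child: Expr) extends Expr { // stands in for an aggregate
    def children: Seq[Expr] = Seq(child)
  }
  case class Add(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
  }
  case class And(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
  }
  case class InSubquery(values: Seq[Expr], subquery: String) extends Expr {
    def children: Seq[Expr] = values
  }

  // Outer search finds an IN-subquery; inner search looks for an
  // aggregate anywhere within that IN-subquery's left-hand values.
  def containsAggregateInSubquery(e: Expr): Boolean = e.exists {
    case InSubquery(values, _) => values.exists(v => v.exists(_.isInstanceOf[Sum]))
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    val hit = And(InSubquery(Seq(Sum(Attr("col2"))), "SELECT c3 FROM v1"), Attr("flag"))
    val miss = InSubquery(Seq(Attr("col1")), "SELECT c3 FROM v1")
    println(containsAggregateInSubquery(hit))
    println(containsAggregateInSubquery(miss))
  }
}
```

An aggregate that appears outside any IN-subquery does not trigger the guard, so ordinary aggregations are left to the existing handlers.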

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala

Lines changed: 18 additions & 1 deletion
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not}
+import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
 import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.LongType
 
 
 class RewriteSubquerySuite extends PlanTest {
@@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest {
     Optimize.executeAndTrack(query.analyze, tracker)
     assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0)
   }
+
+  test("SPARK-50091: Don't put aggregate expression in join condition") {
+    val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
+    val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
+    val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
+    val optimized = Optimize.execute(plan.analyze)
+    val aggregate = relation2
+      .select($"col2")
+      .groupBy()(sum($"col2").as("_aggregateexpression"))
+    val correctAnswer = aggregate
+      .join(relation1.select(Cast($"c3", LongType).as("c3")),
+        ExistenceJoin($"exists".boolean.withNullability(false)),
+        Some($"_aggregateexpression" === $"c3"))
+      .select($"exists".as("(sum(col2) IN (listquery()))")).analyze
+    comparePlans(optimized, correctAnswer)
+  }
 }
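
The expected plan in this test aggregates first and then applies an existence join whose condition compares an ordinary column with the subquery column. A rough sketch of what that plan shape computes, using plain Scala collections and the `v1`/`v2` rows from the commit message, is given here. `existenceJoin` is a hypothetical helper written for this illustration, not a Spark API, and the sketch assumes no NULL values, so it ignores SQL's three-valued logic.

```scala
object ExistenceJoinSketch {
  // Keep every left row and append a flag: does any right row satisfy cond?
  def existenceJoin[L, R](left: Seq[L], right: Seq[R])(
      cond: (L, R) => Boolean): Seq[(L, Boolean)] =
    left.map(l => (l, right.exists(r => cond(l, r))))

  // Rows from the commit message: v1(c1, c2) and v2(col1, col2).
  val v1: Seq[(Int, Int)] = Seq((1, 2), (1, 3), (2, 2), (3, 7), (3, 1))
  val v2: Seq[(Int, Int)] = Seq((1, 2), (1, 3), (2, 2), (3, 7), (3, 1))

  // Aggregation stage: (col1, SUM(col2)), exposed as an ordinary column.
  val aggregated: Seq[(Int, Long)] =
    v2.groupBy(_._1).toSeq.sortBy(_._1).map {
      case (col1, rows) => (col1, rows.map(_._2.toLong).sum)
    }

  // Join stage: the condition compares the aggregated column with c2,
  // so no aggregate appears in the join condition itself.
  val result: Seq[(Int, Boolean)] =
    existenceJoin(aggregated, v1)((agg, row) => agg._2 == row._2.toLong)
      .map { case ((col1, _), exists) => (col1, exists) }

  def main(args: Array[String]): Unit = result.foreach(println)
}
```

Because the sum is computed before the join runs, the join only ever sees a finished column value, which is the property the test's `correctAnswer` plan encodes.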

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 30 additions & 0 deletions
@@ -2800,4 +2800,34 @@ class SubquerySuite extends QueryTest
       checkAnswer(df3, Row(7))
     }
   }
+
+  test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
+    withView("v1", "v2") {
+      Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
+        .toDF("c1", "c2", "c3")
+        .createOrReplaceTempView("v1")
+      Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))
+        .toDF("col1", "col2", "col3")
+        .createOrReplaceTempView("v2")
+
+      val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1")
+      checkAnswer(df1,
+        Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil)
+
+      val df2 = sql("""SELECT
+                      |  col1,
+                      |  SUM(col2) IN (SELECT c3 FROM v1) and SUM(col3) IN (SELECT c2 FROM v1) AS x
+                      |FROM v2 GROUP BY col1
+                      |ORDER BY col1""".stripMargin)
+      checkAnswer(df2,
+        Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
+
+      val df3 = sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2 FROM v1) AS x
+                      |FROM v2
+                      |GROUP BY col1
+                      |ORDER BY col1""".stripMargin)
+      checkAnswer(df3,
+        Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
+    }
+  }
 }
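
The expected rows in this test can be cross-checked by hand. The sketch below recomputes the three answers with plain Scala collections instead of Spark, using the same `v1` and `v2` rows as the test. The object and value names are hypothetical, and the data contains no NULLs, so ordinary set membership is assumed to match SQL `IN` here.

```scala
object ExpectedAnswers {
  // Same rows as the test: v1(c1, c2, c3) and v2(col1, col2, col3).
  val v1 = Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
  val v2 = Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))

  // (col1, SUM(col2), SUM(col3)) per group, ordered by col1.
  val sums: Seq[(Int, Long, Long)] =
    v2.groupBy(_._1).toSeq.sortBy(_._1).map {
      case (col1, rows) =>
        (col1, rows.map(_._2.toLong).sum, rows.map(_._3.toLong).sum)
    }

  val c2s: Set[Long] = v1.map(_._2.toLong).toSet
  val c3s: Set[Long] = v1.map(_._3.toLong).toSet
  // (c3, c2) pairs, matching the column order of the df3 subquery.
  val pairs: Set[(Long, Long)] = v1.map(r => (r._3.toLong, r._2.toLong)).toSet

  // df1: SUM(col2) IN (SELECT c3 FROM v1)
  val df1 = sums.map { case (col1, s2, _) => (col1, c3s(s2)) }
  // df2: SUM(col2) IN (SELECT c3 ...) AND SUM(col3) IN (SELECT c2 ...)
  val df2 = sums.map { case (col1, s2, s3) => (col1, c3s(s2) && c2s(s3)) }
  // df3: (SUM(col2), SUM(col3)) IN (SELECT c3, c2 FROM v1)
  val df3 = sums.map { case (col1, s2, s3) => (col1, pairs((s2, s3))) }

  def main(args: Array[String]): Unit = {
    println(sums)
    println(df1)
    println(df2)
    println(df3)
  }
}
```

The per-group sums are (5, 5), (2, 4), and (8, 8) for `col1` values 1, 2, and 3, which is why only group 3 satisfies the conjunctive and multi-column variants.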
