Commit 1a0791d

bersprockets authored and dongjoon-hyun committed
[SPARK-49261][SQL] Don't replace literals in aggregate expressions with group-by expressions
### What changes were proposed in this pull request?

Before this PR, `RewriteDistinctAggregates` could potentially replace literals in the aggregate expressions with output attributes from the `Expand` operator. This can occur when a group-by expression is a literal that happens by chance to match a literal used in an aggregate expression. E.g.:

```
create or replace temp view v1(a, b, c) as values
(1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4);

cache table v1;

select
  round(sum(b), 6) as sum1,
  count(distinct a) as count1,
  count(distinct c) as count2
from (
  select
    6 as gb,
    *
  from v1
)
group by a, gb;
```

In the optimized plan, you can see that the literal 6 in the `round` function invocation has been patched with an output attribute (6#163) from the `Expand` operator:

```
== Optimized Logical Plan ==
'Aggregate [a#123, 6#163], [round(first(sum(__auto_generated_subquery_name.b)#167, true) FILTER (WHERE (gid#162 = 0)), 6#163) AS sum1#114, count(__auto_generated_subquery_name.a#164) FILTER (WHERE (gid#162 = 1)) AS count1#115L, count(__auto_generated_subquery_name.c#165) FILTER (WHERE (gid#162 = 2)) AS count2#116L]
+- Aggregate [a#123, 6#163, __auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165, gid#162], [a#123, 6#163, __auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165, gid#162, sum(__auto_generated_subquery_name.b#166) AS sum(__auto_generated_subquery_name.b)#167]
   +- Expand [[a#123, 6, null, null, 0, b#124], [a#123, 6, a#123, null, 1, null], [a#123, 6, null, c#125, 2, null]], [a#123, 6#163, __auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165, gid#162, __auto_generated_subquery_name.b#166]
      +- InMemoryRelation [a#123, b#124, c#125], StorageLevel(disk, memory, deserialized, 1 replicas)
         +- LocalTableScan [a#6, b#7, c#8]
```

This is because the literal 6 was used in the group-by expressions (referred to as gb in the query, and renamed 6#163 in the `Expand` operator's output attributes).

After this PR, foldable expressions in the aggregate expressions are kept as-is.

### Why are the changes needed?

Some expressions require a foldable argument. In the above example, the `round` function requires a foldable expression as the scale argument. Because the scale argument is patched with an attribute, `RoundBase#checkInputDataTypes` returns an error, which leaves the `Aggregate` operator unresolved:

```
[INTERNAL_ERROR] Invalid call to dataType on unresolved object SQLSTATE: XX000
org.apache.spark.sql.catalyst.analysis.UnresolvedException: [INTERNAL_ERROR] Invalid call to dataType on unresolved object SQLSTATE: XX000
	at org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute.dataType(unresolved.scala:255)
	at org.apache.spark.sql.catalyst.types.DataTypeUtils$.$anonfun$fromAttributes$1(DataTypeUtils.scala:241)
	at scala.collection.immutable.List.map(List.scala:247)
	at scala.collection.immutable.List.map(List.scala:79)
	at org.apache.spark.sql.catalyst.types.DataTypeUtils$.fromAttributes(DataTypeUtils.scala:241)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.schema$lzycompute(QueryPlan.scala:428)
	at org.apache.spark.sql.catalyst.plans.QueryPlan.schema(QueryPlan.scala:428)
	at org.apache.spark.sql.execution.SparkPlan.executeCollectPublic(SparkPlan.scala:474)
...
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47876 from bersprockets/group_by_lit_issue.

Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent bc54eac commit 1a0791d

3 files changed: 40 additions, 2 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala

Lines changed: 2 additions & 1 deletion
@@ -400,13 +400,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         (distinctAggOperatorMap.flatMap(_._2) ++
           regularAggOperatorMap.map(e => (e._1, e._3))).toMap
 
+      val groupByMapNonFoldable = groupByMap.filter(!_._1.foldable)
       val patchedAggExpressions = a.aggregateExpressions.map { e =>
         e.transformDown {
           case e: Expression =>
             // The same GROUP BY clauses can have different forms (different names for instance) in
             // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
             // tricky. So we do a linear search for a semantically equal group by expression.
-            groupByMap
+            groupByMapNonFoldable
               .find(ge => e.semanticEquals(ge._1))
               .map(_._2)
               .getOrElse(transformations.getOrElse(e, e))
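
The effect of filtering out foldable group-by keys can be illustrated with a small standalone sketch. This is a toy model, not Catalyst code: `Expr`, `Lit`, `Attr`, `Sum`, `Round`, and `PatchSketch` are hypothetical stand-ins, case-class equality stands in for `semanticEquals`, and `scaleIsFoldable` only approximates the check that `round` performs on its scale argument.

```scala
// Toy stand-ins for Catalyst expressions (hypothetical, for illustration only).
sealed trait Expr { def foldable: Boolean }
case class Lit(value: Int) extends Expr { val foldable = true }
case class Attr(name: String) extends Expr { val foldable = false }
case class Sum(child: Expr) extends Expr { val foldable = false }
case class Round(child: Expr, scale: Expr) extends Expr {
  val foldable = false
  // Approximates the requirement that the scale argument be foldable.
  def scaleIsFoldable: Boolean = scale.foldable
}

object PatchSketch {
  // Top-down rewrite: any node equal to a group-by key is replaced by the
  // attribute it maps to; otherwise the rewrite recurses into the children.
  def patch(e: Expr, groupByMap: Seq[(Expr, Attr)]): Expr =
    groupByMap.find(_._1 == e).map(_._2).getOrElse(e match {
      case Round(c, s) => Round(patch(c, groupByMap), patch(s, groupByMap))
      case Sum(c)      => Sum(patch(c, groupByMap))
      case other       => other
    })

  // Group-by keys `a` and the literal 6, mapped to Expand-style output attributes.
  val groupByMap: Seq[(Expr, Attr)] =
    Seq(Attr("a") -> Attr("a#123"), Lit(6) -> Attr("6#163"))

  // round(sum(b), 6): the scale literal happens to equal the group-by literal.
  val agg: Expr = Round(Sum(Attr("b")), Lit(6))

  // Without the filter, the scale literal is replaced by an attribute.
  val before: Expr = patch(agg, groupByMap)

  // With the filter, foldable keys are skipped and the literal is kept.
  val after: Expr = patch(agg, groupByMap.filter(!_._1.foldable))

  def main(args: Array[String]): Unit = {
    println(before)
    println(after)
  }
}
```

In this model the unfiltered map turns the scale into a non-foldable attribute, while the filtered map leaves `Lit(6)` untouched. The non-foldable key `a` is still eligible for replacement in both cases.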

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala

Lines changed: 17 additions & 1 deletion
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Literal, Round}
 import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
@@ -109,4 +109,20 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
       case _ => fail(s"Plan is not rewritten:\n$rewrite")
     }
   }
+
+  test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") {
+    val relation = testRelation2
+      .select(Literal(6).as("gb"), $"a", $"b", $"c", $"d")
+    val input = relation
+      .groupBy($"a", $"gb")(
+        countDistinct($"b").as("agg1"),
+        countDistinct($"d").as("agg2"),
+        Round(sum($"c").as("sum1"), 6)).analyze
+    val rewriteFold = FoldablePropagation(input)
+    // without the fix, the below produces an unresolved plan
+    val rewrite = RewriteDistinctAggregates(rewriteFold)
+    if (!rewrite.resolved) {
+      fail(s"Plan is not as expected:\n$rewrite")
+    }
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 21 additions & 0 deletions
@@ -2490,6 +2490,27 @@ class DataFrameAggregateSuite extends QueryTest
       })
     }
   }
+
+  test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") {
+    val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a", "b", "c")
+    withTempView("v1") {
+      data.createOrReplaceTempView("v1")
+      val df =
+        sql("""SELECT
+              | ROUND(SUM(b), 6) AS sum1,
+              | COUNT(DISTINCT a) AS count1,
+              | COUNT(DISTINCT c) AS count2
+              |FROM (
+              | SELECT
+              | 6 AS gb,
+              | *
+              | FROM v1
+              |)
+              |GROUP BY a, gb
+              |""".stripMargin)
+      checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil)
+    }
+  }
 }
 
 case class B(c: Option[Double])
