
Commit f1c5b2a

ulysses-you authored and attilapiros committed
[SPARK-47430][SQL] Rework group by map type to fix bind reference exception
### What changes were proposed in this pull request?

This PR reworks group by map type to fix two issues:
- A "cannot bind reference" exception at runtime, because the grouping attribute was wrapped in `MapSort` but the plan was not transformed with the new output.
- The rule that adds `MapSort` should be placed after `PullOutGroupingExpressions`, so that no complex expressions remain in the grouping keys.

### Why are the changes needed?

To fix the issues above. For example:

```
select map(1, id) from range(10) group by map(1, id);

[INTERNAL_ERROR] Couldn't find _groupingexpression#18 in [mapsort(_groupingexpression#18)#19] SQLSTATE: XX000
org.apache.spark.SparkException: [INTERNAL_ERROR] Couldn't find _groupingexpression#18 in [mapsort(_groupingexpression#18)#19] SQLSTATE: XX000
  at org.apache.spark.SparkException$.internalError(SparkException.scala:92)
  at org.apache.spark.SparkException$.internalError(SparkException.scala:96)
  at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:81)
  at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:74)
  at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:470)
```

### Does this PR introduce _any_ user-facing change?

No, the affected code is not yet released.

### How was this patch tested?

Improved the existing tests to cover more cases.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47545 from ulysses-you/maptype.

Authored-by: ulysses-you <[email protected]>
Signed-off-by: youxiduo <[email protected]>
1 parent 0bb556b commit f1c5b2a
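As a quick sanity check, the failing query from the description can be replayed in a spark-shell built with this patch (a hedged sketch; a stock `spark` session is assumed):

```scala
// Before this patch the query failed at runtime with the INTERNAL_ERROR shown
// above, because BindReferences could not resolve the grouping attribute once
// it was wrapped in MapSort. With the patch it returns ten distinct maps.
val df = spark.sql("select map(1, id) from range(10) group by map(1, id)")
df.collect().foreach(println) // expect 10 rows, one per distinct map(1, id)
```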

File tree

4 files changed: +123 −49 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala

Lines changed: 56 additions & 20 deletions

```diff
@@ -17,37 +17,73 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedLambdaVariable}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
+import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
 import org.apache.spark.util.ArrayImplicits.SparkArrayOps
 
 /**
- * Adds MapSort to group expressions containing map columns, as the key/value paris need to be
+ * Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
  * in the correct order before grouping:
- *   SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
- *   SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
+ *
+ *   SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
+ *   SELECT _groupingmapsort as map_column, COUNT(*) FROM (
+ *     SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
+ *   ) GROUP BY _groupingmapsort
  */
 object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
-    _.containsPattern(AGGREGATE), ruleId) {
-    case a @ Aggregate(groupingExpr, _, _) =>
-      val newGrouping = groupingExpr.map { expr =>
-        if (!expr.exists(_.isInstanceOf[MapSort])
-          && expr.dataType.existsRecursively(_.isInstanceOf[MapType])) {
-          insertMapSortRecursively(expr)
-        } else {
-          expr
+  private def shouldAddMapSort(expr: Expression): Boolean = {
+    expr.dataType.existsRecursively(_.isInstanceOf[MapType])
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!plan.containsPattern(TreePattern.AGGREGATE)) {
+      return plan
+    }
+    val shouldRewrite = plan.exists {
+      case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
+      case _ => false
+    }
+    if (!shouldRewrite) {
+      return plan
+    }
+
+    plan transformUpWithNewOutput {
+      case agg @ Aggregate(groupingExprs, aggregateExpressions, child)
+          if agg.groupingExpressions.exists(shouldAddMapSort) =>
+        val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
+        val newGroupingKeys = groupingExprs.map { expr =>
+          val inserted = insertMapSortRecursively(expr)
+          if (expr.ne(inserted)) {
+            exprToMapSort.getOrElseUpdate(
+              expr.canonicalized,
+              Alias(inserted, "_groupingmapsort")()
+            ).toAttribute
+          } else {
+            expr
+          }
         }
-      }
-      a.copy(groupingExpressions = newGrouping)
+        val newAggregateExprs = aggregateExpressions.map {
+          case named if exprToMapSort.contains(named.canonicalized) =>
+            // If we replace the top-level named expr, we should add back the original name
+            exprToMapSort(named.canonicalized).toAttribute.withName(named.name)
+          case other =>
+            other.transformUp {
+              case e => exprToMapSort.get(e.canonicalized).map(_.toAttribute).getOrElse(e)
+            }.asInstanceOf[NamedExpression]
+        }
+        val newChild = Project(child.output ++ exprToMapSort.values, child)
+        val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
+        newAgg -> agg.output.zip(newAgg.output)
+    }
   }
 
-  /*
-  Inserts MapSort recursively taking into account when
-  it is nested inside a struct or array.
+  /**
+   * Inserts MapSort recursively taking into account when it is nested inside a struct or array.
    */
   private def insertMapSortRecursively(e: Expression): Expression = {
     e.dataType match {
```
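The key move above is caching each rewritten grouping key by its `canonicalized` form, so semantically equal occurrences (in the GROUP BY list and in the select list) resolve to one shared `_groupingmapsort` alias. A minimal, Spark-free sketch of that dedup pattern, with hypothetical stand-ins for Catalyst's `Expression` and `Alias`:

```scala
import scala.collection.mutable

object DedupByCanonicalSketch {
  // Hypothetical stand-ins; Catalyst's Expression.canonicalized is far richer.
  final case class Expr(sql: String) {
    def canonicalized: String = sql.toLowerCase.replaceAll("\\s+", "")
  }
  final case class Alias(child: Expr, name: String)

  def main(args: Array[String]): Unit = {
    val exprToMapSort = new mutable.HashMap[String, Alias]
    // Each distinct canonical form gets exactly one map_sort alias.
    def rewrite(e: Expr): Alias =
      exprToMapSort.getOrElseUpdate(
        e.canonicalized, Alias(Expr(s"map_sort(${e.sql})"), "_groupingmapsort"))

    val groupingKey = Expr("map(1, id)")
    val selectItem = Expr("MAP(1,  id)") // differs only in case/whitespace
    assert(rewrite(groupingKey) eq rewrite(selectItem)) // same shared alias
    println(exprToMapSort.values.mkString(", "))
  }
}
```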

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 5 additions & 3 deletions

```diff
@@ -150,7 +150,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   }
 
   val batches = (
-    Batch("Finish Analysis", Once, FinishAnalysis) ::
+    Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
     // We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
     // may produce `With` expressions that need to be rewritten.
     Batch("Rewrite With expression", Once, RewriteWithExpression) ::
@@ -246,8 +246,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
       CollapseProject,
       RemoveRedundantAliases,
       RemoveNoopOperators) :+
-    Batch("InsertMapSortInGroupingExpressions", Once,
-      InsertMapSortInGroupingExpressions) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
     Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
     Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
@@ -297,6 +295,10 @@ abstract class Optimizer(catalogManager: CatalogManager)
       ReplaceExpressions,
       RewriteNonCorrelatedExists,
       PullOutGroupingExpressions,
+      // Put `InsertMapSortInGroupingExpressions` after `PullOutGroupingExpressions`,
+      // so the grouping keys can only be attributes and literals, which makes it
+      // easy for `InsertMapSortInGroupingExpressions` to insert `MapSort`.
+      InsertMapSortInGroupingExpressions,
       ComputeCurrentTime,
       ReplaceCurrentLike(catalogManager),
       SpecialDatetimeValues,
```
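The ordering constraint in the new comment is easier to see on a concrete plan. A hedged sketch for a spark-shell session with this patch applied (attribute ids and the exact plan text will differ, and later batches may collapse the adjacent Projects):

```scala
// PullOutGroupingExpressions first aliases the complex grouping key as
// _groupingexpression, so InsertMapSortInGroupingExpressions only ever has
// to wrap a plain attribute in map_sort.
spark.sql("select map(1, id) from range(10) group by map(1, id)").explain(true)
// Optimized plan, roughly:
//   Aggregate [_groupingmapsort#B], [_groupingmapsort#B AS map(1, id)#C]
//   +- Project [map_sort(_groupingexpression#A) AS _groupingmapsort#B]
//      +- Project [map(1, id#X) AS _groupingexpression#A]
//         +- Range (0, 10, ...)
```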

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala

Lines changed: 0 additions & 1 deletion

```diff
@@ -127,7 +127,6 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" ::
       "org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" ::
-      "org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" ::
       "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
```

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 62 additions & 25 deletions

```diff
@@ -2162,8 +2162,9 @@ class DataFrameAggregateSuite extends QueryTest
     )
   }
 
-  private def assertAggregateOnDataframe(df: DataFrame,
-      expected: Int, aggregateColumn: String): Unit = {
+  private def assertAggregateOnDataframe(
+      df: => DataFrame,
+      expected: Int): Unit = {
     val configurations = Seq(
       Seq.empty[(String, String)], // hash aggregate is used by default
       Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN",
@@ -2175,32 +2176,64 @@ class DataFrameAggregateSuite extends QueryTest
       Seq("spark.sql.test.forceApplySortAggregate" -> "true")
     )
 
-    for (conf <- configurations) {
-      withSQLConf(conf: _*) {
-        assert(createAggregate(df).count() == expected)
+    // Make tests faster
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
+      for (conf <- configurations) {
+        withSQLConf(conf: _*) {
+          assert(df.count() == expected, df.queryExecution.simpleString)
+        }
       }
     }
-
-    def createAggregate(df: DataFrame): DataFrame = df.groupBy(aggregateColumn).agg(count("*"))
   }
 
   test("SPARK-47430 Support GROUP BY MapType") {
-    val numRows = 50
-
-    val dfSameInt = (0 until numRows)
-      .map(_ => Tuple1(Map(1 -> 1)))
-      .toDF("m0")
-    assertAggregateOnDataframe(dfSameInt, 1, "m0")
-
-    val dfSameFloat = (0 until numRows)
-      .map(i => Tuple1(Map(if (i % 2 == 0) 1 -> 0.0 else 1 -> -0.0 )))
-      .toDF("m0")
-    assertAggregateOnDataframe(dfSameFloat, 1, "m0")
-
-    val dfDifferent = (0 until numRows)
-      .map(i => Tuple1(Map(i -> i)))
-      .toDF("m0")
-    assertAggregateOnDataframe(dfDifferent, numRows, "m0")
+    def genMapData(dataType: String): String = {
+      s"""
+         |case when id % 4 == 0 then map()
+         |when id % 4 == 1 then map(cast(0 as $dataType), cast(0 as $dataType))
+         |when id % 4 == 2 then map(cast(0 as $dataType), cast(0 as $dataType),
+         |  cast(1 as $dataType), cast(1 as $dataType))
+         |else map(cast(1 as $dataType), cast(1 as $dataType),
+         |  cast(0 as $dataType), cast(0 as $dataType))
+         |end
+         |""".stripMargin
+    }
+    Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt =>
+      withTempView("v") {
+        spark.range(20)
+          .selectExpr(
+            s"cast(1 as $dt) as c1",
+            s"${genMapData(dt)} as c2",
+            "map(c1, null) as c3",
+            s"cast(null as map<$dt, $dt>) as c4")
+          .createOrReplaceTempView("v")
+
+        assertAggregateOnDataframe(
+          spark.sql("SELECT count(*) FROM v GROUP BY c2"),
+          3)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c2, count(*) FROM v GROUP BY c2"),
+          3)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c1, c2, count(*) FROM v GROUP BY c1, c2"),
+          3)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT map(c1, c1) FROM v GROUP BY map(c1, c1)"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT map(c1, c1), count(*) FROM v GROUP BY map(c1, c1)"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c3, count(*) FROM v GROUP BY c3"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c4, count(*) FROM v GROUP BY c4"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c1, c2, c3, c4, count(*) FROM v GROUP BY c1, c2, c3, c4"),
+          3)
+      }
+    }
   }
 
   test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
@@ -2209,12 +2242,16 @@ class DataFrameAggregateSuite extends QueryTest
     val dfSame = (0 until numRows)
       .map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
       .toDF("c0")
-    assertAggregateOnDataframe(dfSame, 1, "c0")
+      .groupBy($"c0")
+      .count()
+    assertAggregateOnDataframe(dfSame, 1)
 
     val dfDifferent = (0 until numRows)
       .map(i => Tuple1(new CalendarInterval(i, i, i)))
       .toDF("c0")
-    assertAggregateOnDataframe(dfDifferent, numRows, "c0")
+      .groupBy($"c0")
+      .count()
+    assertAggregateOnDataframe(dfDifferent, numRows)
   }
 
   test("SPARK-46779: Group by subquery with a cached relation") {
```
