
[SPARK-47430][SQL] Rework group by map type to fix bind reference exception #47545

Closed · wants to merge 2 commits · changes shown from all commits
InsertMapSortInGroupingExpressions.scala
@@ -17,37 +17,73 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
import org.apache.spark.util.ArrayImplicits.SparkArrayOps

/**
* Adds MapSort to group expressions containing map columns, as the key/value pairs need to be
* Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
* in the correct order before grouping:
* SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
* SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
*
* SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
* SELECT _groupingmapsort as map_column, COUNT(*) FROM (
* SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
* ) GROUP BY _groupingmapsort
*/
object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
case a @ Aggregate(groupingExpr, _, _) =>
val newGrouping = groupingExpr.map { expr =>
if (!expr.exists(_.isInstanceOf[MapSort])
&& expr.dataType.existsRecursively(_.isInstanceOf[MapType])) {
insertMapSortRecursively(expr)
} else {
expr
private def shouldAddMapSort(expr: Expression): Boolean = {
expr.dataType.existsRecursively(_.isInstanceOf[MapType])
}

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!plan.containsPattern(TreePattern.AGGREGATE)) {
return plan
}
val shouldRewrite = plan.exists {
case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
case _ => false
}
if (!shouldRewrite) {
return plan
}

plan transformUpWithNewOutput {
case agg @ Aggregate(groupingExprs, aggregateExpressions, child)
if agg.groupingExpressions.exists(shouldAddMapSort) =>
val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
val newGroupingKeys = groupingExprs.map { expr =>
val inserted = insertMapSortRecursively(expr)
if (expr.ne(inserted)) {
exprToMapSort.getOrElseUpdate(
expr.canonicalized,
Alias(inserted, "_groupingmapsort")()
).toAttribute
} else {
expr
}
}
}
a.copy(groupingExpressions = newGrouping)
val newAggregateExprs = aggregateExpressions.map {
case named if exprToMapSort.contains(named.canonicalized) =>
// If we replace the top-level named expr, we should add back the original name
exprToMapSort(named.canonicalized).toAttribute.withName(named.name)
case other =>
other.transformUp {
case e => exprToMapSort.get(e.canonicalized).map(_.toAttribute).getOrElse(e)
}.asInstanceOf[NamedExpression]
}
val newChild = Project(child.output ++ exprToMapSort.values, child)
val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
newAgg -> agg.output.zip(newAgg.output)
}
}

/*
Inserts MapSort recursively taking into account when
it is nested inside a struct or array.
/**
* Inserts MapSort recursively taking into account when it is nested inside a struct or array.
*/
private def insertMapSortRecursively(e: Expression): Expression = {
e.dataType match {
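The rule's scaladoc above describes the intended rewrite. A minimal sketch of exercising it end to end (not part of this PR; the session setup, view name and column name are invented, and the plan shape in the comment is approximate):

import org.apache.spark.sql.SparkSession

object GroupByMapSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("group-by-map-sketch").getOrCreate()
    import spark.implicits._

    // Two rows holding the same map, differing only in entry order.
    Seq(Map(1 -> 1, 2 -> 2), Map(2 -> 2, 1 -> 1)).toDF("m").createOrReplaceTempView("t")

    // With this change the optimizer should rewrite the plan roughly as
    //   Aggregate [_groupingmapsort], [_groupingmapsort AS m, count(1)]
    //   +- Project [m, map_sort(m) AS _groupingmapsort]
    //      +- LocalRelation [m]
    // so the query returns a single group with count 2.
    spark.sql("SELECT m, count(*) FROM t GROUP BY m").show()

    spark.stop()
  }
}
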
Optimizer.scala
@@ -150,7 +150,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
}

val batches = (
Batch("Finish Analysis", Once, FinishAnalysis) ::
Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
// We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
// may produce `With` expressions that need to be rewritten.
Batch("Rewrite With expression", Once, RewriteWithExpression) ::
@@ -246,8 +246,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
CollapseProject,
RemoveRedundantAliases,
RemoveNoopOperators) :+
Batch("InsertMapSortInGroupingExpressions", Once,
InsertMapSortInGroupingExpressions) :+
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
@@ -297,6 +295,10 @@ abstract class Optimizer(catalogManager: CatalogManager)
ReplaceExpressions,
RewriteNonCorrelatedExists,
PullOutGroupingExpressions,
// Put `InsertMapSortInGroupingExpressions` after `PullOutGroupingExpressions`,
// so the grouping keys can only be attributes and literals, which makes it
// easy for `InsertMapSortInGroupingExpressions` to insert `MapSort`.
InsertMapSortInGroupingExpressions,
Contributor review comment: let's add some comments to explain the rule order reasoning.

ComputeCurrentTime,
ReplaceCurrentLike(catalogManager),
SpecialDatetimeValues,
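The ordering comment introduced above leans on `PullOutGroupingExpressions` having already replaced complex grouping expressions with attributes computed in a child Project. A small sketch for inspecting that interaction (not from the PR; names and the local session are invented, and exact plan text varies by version):

import org.apache.spark.sql.SparkSession

object RuleOrderSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rule-order-sketch").getOrCreate()
    import spark.implicits._

    Seq((1, "a"), (1, "a")).toDF("k", "v").createOrReplaceTempView("t")

    // PullOutGroupingExpressions should turn map(k, v) into a projected attribute,
    // and InsertMapSortInGroupingExpressions should then wrap that attribute in map_sort.
    val df = spark.sql("SELECT map(k, v) AS m, count(*) FROM t GROUP BY map(k, v)")
    println(df.queryExecution.optimizedPlan.treeString)

    spark.stop()
  }
}
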
RuleIdCollection.scala
@@ -127,7 +127,6 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" ::
"org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" ::
"org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" ::
"org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
DataFrameAggregateSuite.scala
@@ -2162,8 +2162,9 @@ class DataFrameAggregateSuite extends QueryTest
)
}

private def assertAggregateOnDataframe(df: DataFrame,
expected: Int, aggregateColumn: String): Unit = {
private def assertAggregateOnDataframe(
df: => DataFrame,
expected: Int): Unit = {
val configurations = Seq(
Seq.empty[(String, String)], // hash aggregate is used by default
Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN",
@@ -2175,32 +2176,64 @@
Seq("spark.sql.test.forceApplySortAggregate" -> "true")
)

for (conf <- configurations) {
withSQLConf(conf: _*) {
assert(createAggregate(df).count() == expected)
// Make tests faster
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
for (conf <- configurations) {
withSQLConf(conf: _*) {
assert(df.count() == expected, df.queryExecution.simpleString)
}
}
}

def createAggregate(df: DataFrame): DataFrame = df.groupBy(aggregateColumn).agg(count("*"))
}

test("SPARK-47430 Support GROUP BY MapType") {
val numRows = 50

val dfSameInt = (0 until numRows)
.map(_ => Tuple1(Map(1 -> 1)))
.toDF("m0")
assertAggregateOnDataframe(dfSameInt, 1, "m0")

val dfSameFloat = (0 until numRows)
.map(i => Tuple1(Map(if (i % 2 == 0) 1 -> 0.0 else 1 -> -0.0 )))
.toDF("m0")
assertAggregateOnDataframe(dfSameFloat, 1, "m0")

val dfDifferent = (0 until numRows)
.map(i => Tuple1(Map(i -> i)))
.toDF("m0")
assertAggregateOnDataframe(dfDifferent, numRows, "m0")
def genMapData(dataType: String): String = {
s"""
|case when id % 4 == 0 then map()
|when id % 4 == 1 then map(cast(0 as $dataType), cast(0 as $dataType))
|when id % 4 == 2 then map(cast(0 as $dataType), cast(0 as $dataType),
| cast(1 as $dataType), cast(1 as $dataType))
|else map(cast(1 as $dataType), cast(1 as $dataType),
| cast(0 as $dataType), cast(0 as $dataType))
|end
|""".stripMargin
}
Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt =>
withTempView("v") {
spark.range(20)
.selectExpr(
s"cast(1 as $dt) as c1",
s"${genMapData(dt)} as c2",
"map(c1, null) as c3",
s"cast(null as map<$dt, $dt>) as c4")
.createOrReplaceTempView("v")

assertAggregateOnDataframe(
spark.sql("SELECT count(*) FROM v GROUP BY c2"),
3)
assertAggregateOnDataframe(
spark.sql("SELECT c2, count(*) FROM v GROUP BY c2"),
3)
assertAggregateOnDataframe(
spark.sql("SELECT c1, c2, count(*) FROM v GROUP BY c1, c2"),
3)
assertAggregateOnDataframe(
spark.sql("SELECT map(c1, c1) FROM v GROUP BY map(c1, c1)"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT map(c1, c1), count(*) FROM v GROUP BY map(c1, c1)"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT c3, count(*) FROM v GROUP BY c3"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT c4, count(*) FROM v GROUP BY c4"),
1)
assertAggregateOnDataframe(
spark.sql("SELECT c1, c2, c3, c4, count(*) FROM v GROUP BY c1, c2, c3, c4"),
3)
}
}
}
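For the `c2` assertions above, the expected group count of 3 follows from `genMapData`: `id % 4` yields an empty map, `map(0, 0)`, `map(0, 0, 1, 1)` and `map(1, 1, 0, 0)`, and the last two collapse into one group once the keys are sorted. A standalone way to see that collapse (a sketch, assuming a Spark build that includes this change; the inline data is invented):

import org.apache.spark.sql.SparkSession

object MapGroupCollapseSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("map-group-collapse").getOrCreate()

    // The two maps differ only in entry order, so grouping should produce one row with count 2.
    spark.sql(
      """SELECT m, count(*) AS cnt FROM (
        |  SELECT map(0, 0, 1, 1) AS m
        |  UNION ALL
        |  SELECT map(1, 1, 0, 0) AS m
        |) t
        |GROUP BY m""".stripMargin).show()

    spark.stop()
  }
}
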

test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
@@ -2209,12 +2242,16 @@
val dfSame = (0 until numRows)
.map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
.toDF("c0")
assertAggregateOnDataframe(dfSame, 1, "c0")
.groupBy($"c0")
.count()
assertAggregateOnDataframe(dfSame, 1)

val dfDifferent = (0 until numRows)
.map(i => Tuple1(new CalendarInterval(i, i, i)))
.toDF("c0")
assertAggregateOnDataframe(dfDifferent, numRows, "c0")
.groupBy($"c0")
.count()
assertAggregateOnDataframe(dfDifferent, numRows)
}

test("SPARK-46779: Group by subquery with a cached relation") {