Skip to content

Commit ef0a76e

Browse files
rednaxelafxcloud-fan
authored andcommitted
[SPARK-42851][SQL] Guard EquivalentExpressions.addExpr() with supportedExpression()
### What changes were proposed in this pull request? In `EquivalentExpressions.addExpr()`, add a guard `supportedExpression()` to make it consistent with `addExprTree()` and `getExprState()`. ### Why are the changes needed? This fixes a regression caused by #39010 which added the `supportedExpression()` to `addExprTree()` and `getExprState()` but not `addExpr()`. One example of a use case affected by the inconsistency is the `PhysicalAggregation` pattern in physical planning. There, it calls `addExpr()` to deduplicate the aggregate expressions, and then calls `getExprState()` to deduplicate the result expressions. Guarding inconsistently will cause the aggregate and result expressions go out of sync, eventually resulting in query execution error (or whole-stage codegen error). ### Does this PR introduce _any_ user-facing change? This fixes a regression affecting Spark 3.3.2+, where it may manifest as an error running aggregate operators with higher-order functions. Example running the SQL command: ```sql select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2) ``` example error message before the fix: ``` java.lang.IllegalStateException: Couldn't find max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))#4 in [max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))#3] ``` after the fix this error is gone. ### How was this patch tested? Added new test cases to `SubexpressionEliminationSuite` for the immediate issue, and to `DataFrameAggregateSuite` for an example of user-visible symptom. Closes #40473 from rednaxelafx/spark-42851. Authored-by: Kris Mok <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent c9a530e commit ef0a76e

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ class EquivalentExpressions {
4040
* Returns true if there was already a matching expression.
4141
*/
4242
def addExpr(expr: Expression): Boolean = {
43-
updateExprInMap(expr, equivalenceMap)
43+
if (supportedExpression(expr)) {
44+
updateExprInMap(expr, equivalenceMap)
45+
} else {
46+
false
47+
}
4448
}
4549

4650
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen._
2222
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
2323
import org.apache.spark.sql.internal.SQLConf
24-
import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType}
24+
import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType}
2525

2626
class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper {
2727
test("Semantic equals and hash") {
@@ -449,6 +449,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
449449
assert(e2.getCommonSubexpressions.size == 1)
450450
assert(e2.getCommonSubexpressions.head == add)
451451
}
452+
453+
test("SPARK-42851: Handle supportExpression consistently across add and get") {
454+
val expr = {
455+
val function = (lambda: Expression) => Add(lambda, Literal(1))
456+
val elementType = IntegerType
457+
val colClass = classOf[Array[Int]]
458+
val inputType = ObjectType(colClass)
459+
val inputObject = BoundReference(0, inputType, nullable = true)
460+
objects.MapObjects(function, inputObject, elementType, true, Option(colClass))
461+
}
462+
val equivalence = new EquivalentExpressions
463+
equivalence.addExpr(expr)
464+
val hasMatching = equivalence.addExpr(expr)
465+
val cseState = equivalence.getExprState(expr)
466+
assert(hasMatching == cseState.isDefined)
467+
}
452468
}
453469

454470
case class CodegenFallbackExpression(child: Expression)

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,6 +1538,13 @@ class DataFrameAggregateSuite extends QueryTest
15381538
)
15391539
checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil)
15401540
}
1541+
1542+
test("SPARK-42851: common subexpression should consistently handle aggregate and result exprs") {
1543+
val res = sql(
1544+
"select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)"
1545+
)
1546+
checkAnswer(res, Row(Array(1), Array(1)))
1547+
}
15411548
}
15421549

15431550
case class B(c: Option[Double])

0 commit comments

Comments
 (0)