Skip to content

Commit 15f7ff0

Browse files
benhurdelheymzhang
authored andcommitted
[SPARK-53311][SQL][PYTHON][CORE] Make PullOutNonDeterministic use canonicalized expressions
### What changes were proposed in this pull request? Make PullOutNonDeterministic use canonicalized expressions to dedup group and aggregate expressions. This affects pyspark udfs in particular. Example: ``` from pyspark.sql.functions import col, avg, udf pythonUDF = udf(lambda x: x).asNondeterministic() spark.range(10)\ .selectExpr("id", "id % 3 as value")\ .groupBy(pythonUDF(col("value")))\ .agg(avg("id"), pythonUDF(col("value")))\ .explain(extended=True) ``` Currently results in a plan like this: ``` Aggregate [_nondeterministic#15](apache#15), [_nondeterministic#15 AS dummyNondeterministicUDF(value)apache#12, avg(id#0L) AS avg(id)apache#13, dummyNondeterministicUDF(value#6L)apache#8 AS dummyNondeterministicUDF(value)apache#14](apache#15%20AS%20dummyNondeterministicUDF(value)apache#12,%20avg(id#0L)%20AS%20avg(id)apache#13,%20dummyNondeterministicUDF(value#6L)apache#8%20AS%20dummyNondeterministicUDF(value)apache#14) +- Project [id#0L, value#6L, dummyNondeterministicUDF(value#6L)apache#7 AS _nondeterministic#15](#0L,%20value#6L,%20dummyNondeterministicUDF(value#6L)apache#7%20AS%20_nondeterministic#15) +- Project [id#0L, (id#0L % cast(3 as bigint)) AS value#6L](#0L,%20(id#0L%20%%20cast(3%20as%20bigint))%20AS%20value#6L) +- Range (0, 10, step=1, splits=Some(2)) ``` and then it throws: ``` [[MISSING_AGGREGATION] The non-aggregating expression "value" is based on columns which are not participating in the GROUP BY clause. Add the columns or the expression to the GROUP BY, aggregate the expression, or use "any_value(value)" if you do not care which of the values within a group is returned. SQLSTATE: 42803 ``` - how canonicalized fixes this: - nondeterministic PythonUDF expressions always have distinct resultIds per udf - The fix is to canonicalize the expressions when matching. Canonicalized means that we're setting the resultIds to -1, allowing us to dedup the PythonUDF expressions. - for deterministic UDFs, this rule does not apply and "Post Analysis" batch extracts and deduplicates the expressions, as expected ### Why are the changes needed? - the output of the query with the fix applied still makes sense - the nondeterministic UDF is invoked only once, in the project. ### Does this PR introduce _any_ user-facing change? Yes, it's additive, it enables queries to run that previously threw errors. ### How was this patch tested? - added unit test ### Was this patch authored or co-authored using generative AI tooling? No Closes apache#52061 from benrobby/adhoc-fix-pull-out-nondeterministic. Authored-by: Ben Hurdelhey <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent db97311 commit 15f7ff0

File tree

4 files changed

+24
-6
lines changed

4 files changed

+24
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NondeterministicExpressionCollection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ object NondeterministicExpressionCollection {
3838
case namedExpression: NamedExpression => namedExpression
3939
case _ => Alias(nondeterministicExpr, "_nondeterministic")()
4040
}
41-
nonDeterministicToAttributes.put(nondeterministicExpr, namedExpression)
41+
nonDeterministicToAttributes.put(nondeterministicExpr.canonicalized, namedExpression)
4242
}
4343
}
4444
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ object PullOutNondeterministic extends Rule[LogicalPlan] {
4242
NondeterministicExpressionCollection.getNondeterministicToAttributes(a.groupingExpressions)
4343
val newChild = Project(a.child.output ++ nondeterToAttr.values.asScala.toSeq, a.child)
4444
val deterministicAggregate = a.transformExpressions { case e =>
45-
Option(nondeterToAttr.get(e)).map(_.toAttribute).getOrElse(e)
45+
Option(nondeterToAttr.get(e.canonicalized)).map(_.toAttribute).getOrElse(e)
4646
}.copy(child = newChild)
4747

4848
deterministicAggregate.groupingExpressions.foreach(expr => if (!expr.deterministic) {
@@ -69,7 +69,7 @@ object PullOutNondeterministic extends Rule[LogicalPlan] {
6969
val nondeterToAttr =
7070
NondeterministicExpressionCollection.getNondeterministicToAttributes(p.expressions)
7171
val newPlan = p.transformExpressions { case e =>
72-
Option(nondeterToAttr.get(e)).map(_.toAttribute).getOrElse(e)
72+
Option(nondeterToAttr.get(e.canonicalized)).map(_.toAttribute).getOrElse(e)
7373
}
7474
val newChild = Project(p.child.output ++ nondeterToAttr.values.asScala.toSeq, p.child)
7575
Project(p.output, newPlan.withNewChildren(newChild :: Nil))

sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,8 @@ object IntegratedUDFTestUtils extends SQLHelper {
474474
* casted_col.cast(df.schema["col"].dataType)
475475
* }}}
476476
*/
477-
case class TestPythonUDF(name: String, returnType: Option[DataType] = None) extends TestUDF {
477+
case class TestPythonUDF(name: String, returnType: Option[DataType] = None,
478+
deterministic: Boolean = true) extends TestUDF {
478479
private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction(
479480
name = name,
480481
func = SimplePythonFunction(
@@ -487,7 +488,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
487488
accumulator = null),
488489
dataType = StringType,
489490
pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
490-
udfDeterministic = true) {
491+
udfDeterministic = deterministic) {
491492

492493
override def builder(e: Seq[Expression]): Expression = {
493494
assert(e.length == 1, "Defined UDF only has one column")

sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.execution.python
1919

2020
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
21-
import org.apache.spark.sql.functions.{array, col, count, transform}
21+
import org.apache.spark.sql.functions.{array, avg, col, count, transform}
2222
import org.apache.spark.sql.test.SharedSparkSession
2323
import org.apache.spark.sql.types.LongType
2424

@@ -139,4 +139,21 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
139139
checkAnswer(df, Row(0, 1, 1, 0, 1, 1))
140140
}
141141
}
142+
143+
test("SPARK-53311: Nondeterministic Python UDF pull out in aggregate with grouping") {
144+
assume(shouldTestPythonUDFs)
145+
146+
// nondeterministic UDF
147+
val pythonUDF = TestPythonUDF(name = "foo", Some(LongType), deterministic = false)
148+
149+
// This query should work without throwing an analysis exception
150+
// The UDF foo(value) appears in both grouping expressions and aggregate expressions
151+
// The fix ensures that both instances are properly mapped to the same attribute
152+
val df = spark.range(1)
153+
.selectExpr("id", "id % 3 as value")
154+
.groupBy(pythonUDF(col("value")))
155+
.agg(avg("id"), pythonUDF(col("value")))
156+
157+
checkAnswer(df, Row(0, 0.0, 0))
158+
}
142159
}

0 commit comments

Comments
 (0)