-
Notifications
You must be signed in to change notification settings - Fork 28.7k
[SPARK-36063][SQL] Optimize OneRowRelation subqueries #33284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer | |
|
||
import scala.collection.mutable.ArrayBuffer | ||
|
||
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases | ||
import org.apache.spark.sql.catalyst.expressions._ | ||
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._ | ||
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ | ||
|
@@ -711,3 +712,47 @@ object RewriteLateralSubquery extends Rule[LogicalPlan] { | |
Join(left, newRight, joinType, newCond, JoinHint.NONE) | ||
} | ||
} | ||
|
||
/** | ||
* This rule optimizes subqueries with OneRowRelation as leaf nodes. | ||
*/ | ||
object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { | ||
|
||
object OneRowSubquery { | ||
def unapply(plan: LogicalPlan): Option[Seq[NamedExpression]] = { | ||
CollapseProject(EliminateSubqueryAliases(plan)) match { | ||
case Project(projectList, _: OneRowRelation) => Some(stripOuterReferences(projectList)) | ||
case _ => None | ||
} | ||
} | ||
} | ||
|
||
private def hasCorrelatedSubquery(plan: LogicalPlan): Boolean = { | ||
plan.find(_.expressions.exists(SubqueryExpression.hasCorrelatedSubquery)).isDefined | ||
} | ||
|
||
/** | ||
* Rewrite a subquery expression into one or more expressions. The rewrite can only be done | ||
* if there is no nested subqueries in the subquery plan. | ||
*/ | ||
private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries { | ||
case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None) | ||
if !hasCorrelatedSubquery(right.plan) && right.joinCond.isEmpty => | ||
Project(left.output ++ projectList, left) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if the lateral join has a condition, can we just add a filter above project? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be fine for inner join but for left outer join, it's trickier. This also applies to subqueries after pulling out correlated filters as join conditions. Maybe this can be a separate optimization before RewriteCorrelatedScalarSubqueries / RewriteLateralSubqueries. |
||
case p: LogicalPlan => p.transformExpressionsUpWithPruning( | ||
_.containsPattern(SCALAR_SUBQUERY)) { | ||
case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _) | ||
if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty => | ||
assert(projectList.size == 1) | ||
projectList.head | ||
} | ||
} | ||
|
||
def apply(plan: LogicalPlan): LogicalPlan = { | ||
if (!conf.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY)) { | ||
plan | ||
} else { | ||
rewrite(plan) | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ class DecorrelateInnerQuerySuite extends PlanTest { | |
val x = AttributeReference("x", IntegerType)() | ||
val y = AttributeReference("y", IntegerType)() | ||
val z = AttributeReference("z", IntegerType)() | ||
val t0 = OneRowRelation() | ||
val testRelation = LocalRelation(a, b, c) | ||
val testRelation2 = LocalRelation(x, y, z) | ||
|
||
|
@@ -203,23 +204,24 @@ class DecorrelateInnerQuerySuite extends PlanTest { | |
|
||
test("correlated values in project") { | ||
val outerPlan = testRelation2 | ||
val innerPlan = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation()) | ||
val correctAnswer = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation())) | ||
val innerPlan = Project(Seq(OuterReference(x).as("x1"), OuterReference(y).as("y1")), t0) | ||
val correctAnswer = Project( | ||
Seq(x.as("x1"), y.as("y1"), x, y), DomainJoin(Seq(x, y), t0)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will we optimize away the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For now once the domain join is added, it will always be rewritten as an inner join because the join condition in the subquery might not be null: |
||
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) | ||
} | ||
|
||
test("correlated values in project with alias") { | ||
val outerPlan = testRelation2 | ||
val innerPlan = | ||
Project(Seq(OuterReference(x), 'y1, 'sum), | ||
Project(Seq(OuterReference(x).as("x1"), 'y1, 'sum), | ||
Project(Seq( | ||
OuterReference(x), | ||
OuterReference(y).as("y1"), | ||
Add(OuterReference(x), OuterReference(y)).as("sum")), | ||
testRelation)).analyze | ||
val correctAnswer = | ||
Project(Seq(x, 'y1, 'sum, y), | ||
Project(Seq(x, y.as("y1"), (x + y).as("sum"), y), | ||
Project(Seq(x.as("x1"), 'y1, 'sum, x, y), | ||
Project(Seq(x.as(x.name), y.as("y1"), (x + y).as("sum"), x, y), | ||
DomainJoin(Seq(x, y), testRelation))).analyze | ||
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) | ||
} | ||
|
@@ -228,28 +230,28 @@ class DecorrelateInnerQuerySuite extends PlanTest { | |
val outerPlan = testRelation2 | ||
val innerPlan = | ||
Project( | ||
Seq(OuterReference(x)), | ||
Seq(OuterReference(x).as("x1")), | ||
Filter( | ||
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)), | ||
testRelation | ||
) | ||
) | ||
val correctAnswer = Project(Seq(a, c), Filter(b === 1, testRelation)) | ||
val correctAnswer = Project(Seq(a.as("x1"), a, c), Filter(b === 1, testRelation)) | ||
check(innerPlan, outerPlan, correctAnswer, Seq(x === a, x + y === c)) | ||
} | ||
|
||
test("correlated values in project without correlated equality conditions in filter") { | ||
val outerPlan = testRelation2 | ||
val innerPlan = | ||
Project( | ||
Seq(OuterReference(y)), | ||
Seq(OuterReference(y).as("y1")), | ||
Filter( | ||
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)), | ||
testRelation | ||
) | ||
) | ||
val correctAnswer = | ||
Project(Seq(y, a, c), | ||
Project(Seq(y.as("y1"), y, a, c), | ||
Filter(b === 1, | ||
DomainJoin(Seq(y), testRelation) | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to handle nested subqueries here? I think the rule
OptimizeSubqueries
will run this rule again to optimize nested subqueries.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason why we need to check subqueries is to deal with nested subqueries:
A subquery's plan should only be rewritten if it doesn't contain another correlated subquery. If we do not transform the nested subqueries first, we will miss out cases like the one above.