Skip to content

[SPARK-36063][SQL] Optimize OneRowRelation subqueries #33284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,13 @@ package object dsl {
condition: Option[Expression] = None): LogicalPlan =
Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE)

def lateralJoin(
otherPlan: LogicalPlan,
joinType: JoinType = Inner,
condition: Option[Expression] = None): LogicalPlan = {
LateralJoin(logicalPlan, LateralSubquery(otherPlan), joinType, condition)
}

def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: Encoder](
otherPlan: LogicalPlan,
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,15 @@ object SubExprUtils extends PredicateHelper {
/**
* Returns an expression after removing the OuterReference shell.
*/
def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r }
def stripOuterReference[E <: Expression](e: E): E = {
e.transform { case OuterReference(r) => r }.asInstanceOf[E]
}

/**
* Returns the list of expressions after removing the OuterReference shell from each of
* the expression.
*/
def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference)
def stripOuterReferences[E <: Expression](e: Seq[E]): Seq[E] = e.map(stripOuterReference)

/**
* Returns the logical plan after removing the OuterReference shell from all the expressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,23 @@ object DecorrelateInnerQuery extends PredicateHelper {
expressions.map(replaceOuterReference(_, outerReferenceMap))
}

/**
* Replace all outer references in the given named expressions and keep the output
* attributes unchanged.
*/
private def replaceOuterInNamedExpressions(
expressions: Seq[NamedExpression],
outerReferenceMap: AttributeMap[Attribute]): Seq[NamedExpression] = {
expressions.map { expr =>
val newExpr = replaceOuterReference(expr, outerReferenceMap)
if (!newExpr.toAttribute.semanticEquals(expr.toAttribute)) {
Alias(newExpr, expr.name)(expr.exprId)
} else {
newExpr
}
}
}

/**
* Return all references that are presented in the join conditions but not in the output
* of the given named expressions.
Expand Down Expand Up @@ -429,8 +446,9 @@ object DecorrelateInnerQuery extends PredicateHelper {
val newOuterReferences = parentOuterReferences ++ outerReferences
val (newChild, joinCond, outerReferenceMap) =
decorrelate(child, newOuterReferences, aggregated)
// Replace all outer references in the original project list.
val newProjectList = replaceOuterReferences(projectList, outerReferenceMap)
// Replace all outer references in the original project list and keep the output
// attributes unchanged.
val newProjectList = replaceOuterInNamedExpressions(projectList, outerReferenceMap)
// Preserve required domain attributes in the join condition by adding the missing
// references to the new project list.
val referencesToAdd = missingReferences(newProjectList, joinCond)
Expand All @@ -442,9 +460,10 @@ object DecorrelateInnerQuery extends PredicateHelper {
val newOuterReferences = parentOuterReferences ++ outerReferences
val (newChild, joinCond, outerReferenceMap) =
decorrelate(child, newOuterReferences, aggregated = true)
// Replace all outer references in grouping and aggregate expressions.
// Replace all outer references in grouping and aggregate expressions, and keep
// the output attributes unchanged.
val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap)
val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap)
val newAggExpr = replaceOuterInNamedExpressions(aggregateExpressions, outerReferenceMap)
// Add all required domain attributes to both grouping and aggregate expressions.
val referencesToAdd = missingReferences(newAggExpr, joinCond)
val newAggregate = a.copy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
// non-nullable when an empty relation child of a Union is removed
UpdateAttributeNullability) ::
Batch("Pullup Correlated Expressions", Once,
OptimizeOneRowRelationSubquery,
PullupCorrelatedPredicates) ::
// Subquery batch applies the optimizer rules recursively. Therefore, it makes no sense
// to enforce idempotence on it and we change this batch from Once to FixedPoint(1).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
Expand Down Expand Up @@ -711,3 +712,47 @@ object RewriteLateralSubquery extends Rule[LogicalPlan] {
Join(left, newRight, joinType, newCond, JoinHint.NONE)
}
}

/**
* This rule optimizes subqueries with OneRowRelation as leaf nodes.
*/
object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {

object OneRowSubquery {
def unapply(plan: LogicalPlan): Option[Seq[NamedExpression]] = {
CollapseProject(EliminateSubqueryAliases(plan)) match {
case Project(projectList, _: OneRowRelation) => Some(stripOuterReferences(projectList))
case _ => None
}
}
}

private def hasCorrelatedSubquery(plan: LogicalPlan): Boolean = {
plan.find(_.expressions.exists(SubqueryExpression.hasCorrelatedSubquery)).isDefined
}

/**
* Rewrite a subquery expression into one or more expressions. The rewrite can only be done
* if there is no nested subqueries in the subquery plan.
*/
private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to handle nested subqueries here? I think the rule OptimizeSubqueries will run this rule again to optimize nested subqueries.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason why we need to check subqueries is to deal with nested subqueries:

Project [scalar-subquery [a]]
:  +- Project [scalar-subquery [b]] <-- collapsible if transform with nested subqueries first
:     :  +- Project [outer(b) + 1]
:     :     +- OneRowRelation
:     +- Project [outer(a) as b]
:         +- OneRowRelation
+- Relation [a]

A subquery's plan should only be rewritten if it doesn't contain another correlated subquery. If we do not transform the nested subqueries first, we will miss out cases like the one above.

case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None)
if !hasCorrelatedSubquery(right.plan) && right.joinCond.isEmpty =>
Project(left.output ++ projectList, left)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the lateral join has a condition, can we just add a filter above project?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be fine for inner join but for left outer join, it's trickier. This also applies to subqueries after pulling out correlated filters as join conditions. Maybe this can be a separate optimization before RewriteCorrelatedScalarSubqueries / RewriteLateralSubqueries.

case p: LogicalPlan => p.transformExpressionsUpWithPruning(
_.containsPattern(SCALAR_SUBQUERY)) {
case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _)
if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
assert(projectList.size == 1)
projectList.head
}
}

def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY)) {
plan
} else {
rewrite(plan)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,23 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
subqueries ++ subqueries.flatMap(_.subqueriesAll)
}

/**
* Returns a copy of this node where the given partial function has been recursively applied
* first to the subqueries in this node's children, then this node's children, and finally
* this node itself (post-order). When the partial function does not apply to a given node,
* it is left unchanged.
*/
def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
transformUp { case plan =>
val transformed = plan transformExpressionsUp {
case planExpression: PlanExpression[PlanType] =>
val newPlan = planExpression.plan.transformUpWithSubqueries(f)
planExpression.withNewPlan(newPlan)
}
f.applyOrElse[PlanType, PlanType](transformed, identity)
}
}

/**
* A variant of `collect`. This method not only apply the given function to all elements in this
* plan, also considering all the plans in its (nested) subqueries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2593,6 +2593,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY =
buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery")
.internal()
.doc("When true, the optimizer will inline subqueries with OneRowRelation as leaf nodes.")
.version("3.2.0")
.booleanConf
.createWithDefault(true)

val TOP_K_SORT_FALLBACK_THRESHOLD =
buildConf("spark.sql.execution.topKSortFallbackThreshold")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
val x = AttributeReference("x", IntegerType)()
val y = AttributeReference("y", IntegerType)()
val z = AttributeReference("z", IntegerType)()
val t0 = OneRowRelation()
val testRelation = LocalRelation(a, b, c)
val testRelation2 = LocalRelation(x, y, z)

Expand Down Expand Up @@ -203,23 +204,24 @@ class DecorrelateInnerQuerySuite extends PlanTest {

test("correlated values in project") {
val outerPlan = testRelation2
val innerPlan = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation())
val correctAnswer = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation()))
val innerPlan = Project(Seq(OuterReference(x).as("x1"), OuterReference(y).as("y1")), t0)
val correctAnswer = Project(
Seq(x.as("x1"), y.as("y1"), x, y), DomainJoin(Seq(x, y), t0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will we optimize away the DomainJoin at the end?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now once the domain join is added, it will always be rewritten as an inner join because the join condition in the subquery might not be null: select (select c1 where c1 = c2 + 1) from t.

check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
}

test("correlated values in project with alias") {
val outerPlan = testRelation2
val innerPlan =
Project(Seq(OuterReference(x), 'y1, 'sum),
Project(Seq(OuterReference(x).as("x1"), 'y1, 'sum),
Project(Seq(
OuterReference(x),
OuterReference(y).as("y1"),
Add(OuterReference(x), OuterReference(y)).as("sum")),
testRelation)).analyze
val correctAnswer =
Project(Seq(x, 'y1, 'sum, y),
Project(Seq(x, y.as("y1"), (x + y).as("sum"), y),
Project(Seq(x.as("x1"), 'y1, 'sum, x, y),
Project(Seq(x.as(x.name), y.as("y1"), (x + y).as("sum"), x, y),
DomainJoin(Seq(x, y), testRelation))).analyze
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
}
Expand All @@ -228,28 +230,28 @@ class DecorrelateInnerQuerySuite extends PlanTest {
val outerPlan = testRelation2
val innerPlan =
Project(
Seq(OuterReference(x)),
Seq(OuterReference(x).as("x1")),
Filter(
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
testRelation
)
)
val correctAnswer = Project(Seq(a, c), Filter(b === 1, testRelation))
val correctAnswer = Project(Seq(a.as("x1"), a, c), Filter(b === 1, testRelation))
check(innerPlan, outerPlan, correctAnswer, Seq(x === a, x + y === c))
}

test("correlated values in project without correlated equality conditions in filter") {
val outerPlan = testRelation2
val innerPlan =
Project(
Seq(OuterReference(y)),
Seq(OuterReference(y).as("y1")),
Filter(
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
testRelation
)
)
val correctAnswer =
Project(Seq(y, a, c),
Project(Seq(y.as("y1"), y, a, c),
Filter(b === 1,
DomainJoin(Seq(y), testRelation)
)
Expand Down
Loading