
Commit e00b81e

peter-toth and attilapiros authored and committed
[SPARK-34079][SQL] Merge non-correlated scalar subqueries
### What changes were proposed in this pull request?

This PR adds a new optimizer rule `MergeScalarSubqueries` to merge multiple non-correlated `ScalarSubquery`s so that multiple scalar values are computed once. E.g. the following query:

```
SELECT (SELECT avg(a) FROM t), (SELECT sum(b) FROM t)
```

is optimized from:

```
== Optimized Logical Plan ==
Project [scalar-subquery#242 [] AS scalarsubquery()#253, scalar-subquery#243 [] AS scalarsubquery()#254L]
:  :- Aggregate [avg(a#244) AS avg(a)#247]
:  :  +- Project [a#244]
:  :     +- Relation default.t[a#244,b#245] parquet
:  +- Aggregate [sum(a#251) AS sum(a)#250L]
:     +- Project [a#251]
:        +- Relation default.t[a#251,b#252] parquet
+- OneRowRelation
```

to:

```
== Optimized Logical Plan ==
Project [scalar-subquery#242 [].avg(a) AS scalarsubquery()#253, scalar-subquery#243 [].sum(a) AS scalarsubquery()#254L]
:  :- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
:  :  +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L]
:  :     +- Project [a#244]
:  :        +- Relation default.t[a#244,b#245] parquet
:  +- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
:     +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L]
:        +- Project [a#244]
:           +- Relation default.t[a#244,b#245] parquet
+- OneRowRelation
```

and in the physical plan subqueries are reused:

```
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
   *(1) Project [Subquery subquery#242, [id=#113].avg(a) AS scalarsubquery()#253, ReusedSubquery Subquery subquery#242, [id=#113].sum(a) AS scalarsubquery()#254L]
   :  :- Subquery subquery#242, [id=#113]
   :  :  +- AdaptiveSparkPlan isFinalPlan=true
   :  :     +- == Final Plan ==
   :  :        *(2) Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
   :  :        +- *(2) HashAggregate(keys=[], functions=[avg(a#244), sum(a#244)], output=[avg(a)#247, sum(a)#250L])
   :  :           +- ShuffleQueryStage 0
   :  :              +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#158]
   :  :                 +- *(1) HashAggregate(keys=[], functions=[partial_avg(a#244), partial_sum(a#244)], output=[sum#262, count#263L, sum#264L])
   :  :                    +- *(1) ColumnarToRow
   :  :                       +- FileScan parquet default.t[a#244] Batched: true, DataFilters: [], Format: Parquet, Location: ..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<a:int>
   :  :     +- == Initial Plan ==
   :  :        Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
   :  :        +- HashAggregate(keys=[], functions=[avg(a#244), sum(a#244)], output=[avg(a)#247, sum(a)#250L])
   :  :           +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#110]
   :  :              +- HashAggregate(keys=[], functions=[partial_avg(a#244), partial_sum(a#244)], output=[sum#262, count#263L, sum#264L])
   :  :                 +- FileScan parquet default.t[a#244] Batched: true, DataFilters: [], Format: Parquet, Location: ..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<a:int>
   :  +- ReusedSubquery Subquery subquery#242, [id=#113]
   +- *(1) Scan OneRowRelation[]
+- == Initial Plan ==
   ...
```

Please note that the above simple example could easily be optimized into a common select expression without a reuse node, but this PR can handle more complex queries as well.
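For illustration, a minimal way to reproduce the plans above in `spark-shell` (a sketch; the `CREATE TABLE` is an assumption matching the example's `default.t(a, b)` Parquet table, not part of this PR):

```scala
// Sketch: inspect the merged plan on a build that includes this PR.
spark.sql("CREATE TABLE IF NOT EXISTS t(a INT, b INT) USING parquet")
val df = spark.sql("SELECT (SELECT avg(a) FROM t), (SELECT sum(b) FROM t)")
df.explain(true) // prints the optimized logical plan and physical plan shown above
```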
### Why are the changes needed?

Performance improvement.

```
[info] TPCDS Snappy:                             Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] q9 - MergeScalarSubqueries off                    50798          52521        1423        0.0      Infinity       1.0X
[info] q9 - MergeScalarSubqueries on                     19484          19675         226        0.0      Infinity       2.6X

[info] TPCDS Snappy:                             Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] q9b - MergeScalarSubqueries off                   15430          17803         NaN        0.0      Infinity       1.0X
[info] q9b - MergeScalarSubqueries on                     3862           4002         196        0.0      Infinity       4.0X
```

Please find `q9b` in the description of SPARK-34079. It is a variant of [q9.sql](https://github.com/apache/spark/blob/master/sql/core/src/test/resources/tpcds/q9.sql) using CTEs. The performance improvement for `q9` comes from merging 15 subqueries into 5; for `q9b` it comes from merging 5 subqueries into 1.

### Does this PR introduce _any_ user-facing change?

No. But this optimization can be disabled with the `spark.sql.optimizer.excludedRules` config.

### How was this patch tested?

Existing and new UTs.

Closes #32298 from peter-toth/SPARK-34079-multi-column-scalar-subquery.

Lead-authored-by: Peter Toth <[email protected]>
Co-authored-by: attilapiros <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
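A sketch of disabling the rule per the note above. The fully-qualified rule name is inferred from the path of the new file (`sql/catalyst/.../optimizer/MergeScalarSubqueries.scala`), so treat it as an assumption:

```scala
// Exclude the new rule from optimization for the current session.
spark.conf.set(
  "spark.sql.optimizer.excludedRules",
  "org.apache.spark.sql.catalyst.optimizer.MergeScalarSubqueries")
```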
1 parent 21e48b7 commit e00b81e

File tree

19 files changed: +1706 -1600 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala

Lines changed: 3 additions & 0 deletions
```diff
@@ -56,6 +56,9 @@ case class BloomFilterMightContain(
       case e : Expression if e.foldable => TypeCheckResult.TypeCheckSuccess
       case subquery : PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) =>
         TypeCheckResult.TypeCheckSuccess
+      case GetStructField(subquery: PlanExpression[_], _, _)
+          if !subquery.containsPattern(OUTER_REFERENCE) =>
+        TypeCheckResult.TypeCheckSuccess
       case _ =>
         TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " +
           "should be either a constant value or a scalar subquery expression")
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala

Lines changed: 389 additions & 0 deletions
Large diffs are not rendered by default.
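Since the 389-line rule body is not rendered here, the following is only a toy model of the core idea described in the PR summary (plain Scala with invented names `ScalarAgg` and `MergedPlan`, not the actual Catalyst code): subqueries over the same relation are collected into one merged plan producing all values at once, and each original subquery becomes a field reference into it.

```scala
case class ScalarAgg(relation: String, expr: String)          // e.g. ScalarAgg("t", "avg(a)")
case class MergedPlan(relation: String, exprs: Vector[String]) // computes a struct of all exprs

// Returns the merged plans plus, for each input subquery, the index of the
// merged plan it maps to and the ordinal of its value inside that plan.
def mergeScalarSubqueries(
    subqueries: Seq[ScalarAgg]): (Seq[MergedPlan], Seq[(ScalarAgg, Int, Int)]) = {
  val plans = scala.collection.mutable.ArrayBuffer.empty[MergedPlan]
  val refs = subqueries.map { sq =>
    plans.indexWhere(_.relation == sq.relation) match {
      case -1 =>
        plans += MergedPlan(sq.relation, Vector(sq.expr))     // first subquery over this relation
        (sq, plans.size - 1, 0)
      case i =>
        val p = plans(i)
        val ord = p.exprs.indexOf(sq.expr) match {
          case -1 =>
            plans(i) = p.copy(exprs = p.exprs :+ sq.expr)     // merge a new value into the plan
            p.exprs.size
          case j => j                                         // identical value already computed
        }
        (sq, i, ord)
    }
  }
  (plans.toSeq, refs)
}

// The two subqueries from the example merge into one plan over `t`; each
// consumer then extracts its own field, as the named_struct plan above shows.
val (plans, refs) =
  mergeScalarSubqueries(Seq(ScalarAgg("t", "avg(a)"), ScalarAgg("t", "sum(b)")))
assert(plans == Seq(MergedPlan("t", Vector("avg(a)", "sum(b)"))))
```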

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -121,7 +121,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {
   private def pushdownPredicatesAndAttributes(
       plan: LogicalPlan,
       cteMap: CTEMap): LogicalPlan = plan.transformWithSubqueries {
-    case cteDef @ CTERelationDef(child, id, originalPlanWithPredicates) =>
+    case cteDef @ CTERelationDef(child, id, originalPlanWithPredicates, _) =>
       val (_, _, newPreds, newAttrSet) = cteMap(id)
       val originalPlan = originalPlanWithPredicates.map(_._1).getOrElse(child)
       val preds = originalPlanWithPredicates.map(_._2).getOrElse(Seq.empty)
@@ -169,7 +169,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {
 object CleanUpTempCTEInfo extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan =
     plan.transformWithPruning(_.containsPattern(CTE)) {
-      case cteDef @ CTERelationDef(_, _, Some(_)) =>
+      case cteDef @ CTERelationDef(_, _, Some(_), _) =>
         cteDef.copy(originalPlanWithPredicates = None)
     }
 }
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceCTERefWithRepartition.scala

Lines changed: 8 additions & 7 deletions
```diff
@@ -47,13 +47,14 @@ object ReplaceCTERefWithRepartition extends Rule[LogicalPlan] {
     case WithCTE(child, cteDefs) =>
       cteDefs.foreach { cteDef =>
         val inlined = replaceWithRepartition(cteDef.child, cteMap)
-        val withRepartition = if (inlined.isInstanceOf[RepartitionOperation]) {
-          // If the CTE definition plan itself is a repartition operation, we do not need to add an
-          // extra repartition shuffle.
-          inlined
-        } else {
-          Repartition(conf.numShufflePartitions, shuffle = true, inlined)
-        }
+        val withRepartition =
+          if (inlined.isInstanceOf[RepartitionOperation] || cteDef.underSubquery) {
+            // If the CTE definition plan itself is a repartition operation or if it hosts a merged
+            // scalar subquery, we do not need to add an extra repartition shuffle.
+            inlined
+          } else {
+            Repartition(conf.numShufflePartitions, shuffle = true, inlined)
+          }
         cteMap.put(cteDef.id, withRepartition)
       }
       replaceWithRepartition(child, cteMap)
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 31 additions & 8 deletions
```diff
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, MultiInstanceRe
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
 import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -663,11 +663,14 @@ case class UnresolvedWith(
  *                                   predicates that have been pushed down into `child`. This is
  *                                   a temporary field used by optimization rules for CTE predicate
  *                                   pushdown to help ensure rule idempotency.
+ * @param underSubquery If true, it means we don't need to add a shuffle for this CTE relation as
+ *                      subquery reuse will be applied to reuse CTE relation output.
  */
 case class CTERelationDef(
     child: LogicalPlan,
     id: Long = CTERelationDef.newId,
-    originalPlanWithPredicates: Option[(LogicalPlan, Seq[Expression])] = None) extends UnaryNode {
+    originalPlanWithPredicates: Option[(LogicalPlan, Seq[Expression])] = None,
+    underSubquery: Boolean = false) extends UnaryNode {
 
   final override val nodePatterns: Seq[TreePattern] = Seq(CTE)
 
@@ -678,17 +681,19 @@ case class CTERelationDef(
 }
 
 object CTERelationDef {
-  private val curId = new java.util.concurrent.atomic.AtomicLong()
+  private[sql] val curId = new java.util.concurrent.atomic.AtomicLong()
   def newId: Long = curId.getAndIncrement()
 }
 
 /**
  * Represents the relation of a CTE reference.
- * @param cteId The ID of the corresponding CTE definition.
- * @param _resolved Whether this reference is resolved.
- * @param output The output attributes of this CTE reference, which can be different from
- *               the output of its corresponding CTE definition after attribute de-duplication.
- * @param statsOpt The optional statistics inferred from the corresponding CTE definition.
+ * @param cteId     The ID of the corresponding CTE definition.
+ * @param _resolved Whether this reference is resolved.
+ * @param output    The output attributes of this CTE reference, which can be different
+ *                  from the output of its corresponding CTE definition after attribute
+ *                  de-duplication.
+ * @param statsOpt  The optional statistics inferred from the corresponding CTE
+ *                  definition.
  */
 case class CTERelationRef(
     cteId: Long,
@@ -1014,6 +1019,24 @@ case class Aggregate(
   }
 }
 
+object Aggregate {
+  def isAggregateBufferMutable(schema: StructType): Boolean = {
+    schema.forall(f => UnsafeRow.isMutable(f.dataType))
+  }
+
+  def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+    val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
+    isAggregateBufferMutable(aggregationBufferSchema)
+  }
+
+  def supportsObjectHashAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
+    aggregateExpressions.map(_.aggregateFunction).exists {
+      case _: TypedImperativeAggregate[_] => true
+      case _ => false
+    }
+  }
+}
+
 case class Window(
     windowExpressions: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
```
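Making these helpers available on the logical `Aggregate` companion object (rather than only in physical planning) presumably lets optimizer rules such as the new `MergeScalarSubqueries` check whether a merged aggregate would still support hash aggregation. A hedged sketch of a call site (illustrative only; `hashAggregateSupported` is a hypothetical helper, not code from this PR):

```scala
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.Aggregate

// Could a (merged) set of aggregate expressions still run as a hash-based
// aggregate, either with mutable unsafe-row buffers or the object hash path?
def hashAggregateSupported(aggExprs: Seq[AggregateExpression]): Boolean = {
  val bufferAttrs = aggExprs.flatMap(_.aggregateFunction.aggBufferAttributes)
  Aggregate.supportsHashAggregate(bufferAttrs) ||
    Aggregate.supportsObjectHashAggregate(aggExprs)
}
```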

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -77,6 +77,7 @@ object TreePattern extends Enumeration {
   val REGEXP_REPLACE: Value = Value
   val RUNTIME_REPLACEABLE: Value = Value
   val SCALAR_SUBQUERY: Value = Value
+  val SCALAR_SUBQUERY_REFERENCE: Value = Value
   val SCALA_UDF: Value = Value
   val SORT: Value = Value
   val SUBQUERY_ALIAS: Value = Value
```
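Tree patterns let rules skip whole subtrees cheaply via precomputed pattern bits. A sketch of the typical Catalyst idiom for consuming the new pattern (the exact call sites inside `MergeScalarSubqueries` are not rendered in this diff):

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreePattern.SCALAR_SUBQUERY_REFERENCE

// Only descend into subtrees whose pattern bits say a reference may be present.
def cleanUpReferences(plan: LogicalPlan): LogicalPlan =
  plan.transformUpWithPruning(_.containsPattern(SCALAR_SUBQUERY_REFERENCE)) {
    // A real rule would rewrite the reference nodes here; this placeholder
    // only illustrates the pruning idiom.
    case p => p
  }
```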
