Skip to content

Commit 3b114af

Browse files
nemanjapetr-db authored and cloud-fan committed
[SPARK-50739][SQL] Recursive CTE. Analyzer changes to unravel and resolve the recursion components
### What changes were proposed in this pull request?
Instructions for reviewers: https://docs.google.com/document/d/1qcEJxqoXcr5cSt6HgIQjWQSqhfkSaVYkoDHsg5oxXp4/edit

Introduction of the UnionLoop and UnionLoopRef logical plan classes. Changes in ResolveWithCTE.scala to have the analyzer grok recursive anchors. Specifically, we substitute CTERelationRef with UnionLoopRef, and Union with UnionLoop. We untangle the dead loop in resolution where a recursive CTE reference refers to an unresolved CTE definition, which itself cannot be resolved because one of its descendants is an unresolved CTE reference.

### Why are the changes needed?
Support for recursive CTEs.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Proposed changes in this PR are no-op. Tested with:
./build/sbt "test:testOnly org.apache.spark.sql.SQLQueryTestSuite"
./build/sbt "test:testOnly *PlanParserSuite"

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49351 from nemanjapetr-db/nemanjapetr-db/rcte3.

Authored-by: Nemanja Petrovic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 2721a50 commit 3b114af

File tree

11 files changed

+587
-216
lines changed

11 files changed

+587
-216
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3105,6 +3105,12 @@
31053105
],
31063106
"sqlState" : "42613"
31073107
},
3108+
"INVALID_RECURSIVE_CTE" : {
3109+
"message" : [
3110+
"Invalid recursive definition found. Recursive queries must contain an UNION or an UNION ALL statement with 2 children. The first child needs to be the anchor term without any recursive references."
3111+
],
3112+
"sqlState" : "42836"
3113+
},
31083114
"INVALID_REGEXP_REPLACE" : {
31093115
"message" : [
31103116
"Could not perform regexp_replace for source = \"<source>\", pattern = \"<pattern>\", replacement = \"<replacement>\" and position = <position>."

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,9 @@ object CTESubstitution extends Rule[LogicalPlan] {
316316
// CTE definition can reference a previous one or itself if recursion allowed.
317317
val substituted = substituteCTE(innerCTEResolved, alwaysInline,
318318
resolvedCTERelations, recursiveCTERelation)
319-
val cteRelation = CTERelationDef(substituted)
319+
val cteRelation = recursiveCTERelation
320+
.map(_._2.copy(child = substituted))
321+
.getOrElse(CTERelationDef(substituted))
320322
if (!alwaysInline) {
321323
cteDefs += cteRelation
322324
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala

Lines changed: 127 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,22 @@ package org.apache.spark.sql.catalyst.analysis
2020
import scala.collection.mutable
2121

2222
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
23-
import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LogicalPlan, WithCTE}
23+
import org.apache.spark.sql.catalyst.plans.logical._
2424
import org.apache.spark.sql.catalyst.rules.Rule
2525
import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
26+
import org.apache.spark.sql.errors.QueryCompilationErrors
2627

2728
/**
2829
* Updates CTE references with the resolve output attributes of corresponding CTE definitions.
2930
*/
3031
object ResolveWithCTE extends Rule[LogicalPlan] {
3132
override def apply(plan: LogicalPlan): LogicalPlan = {
3233
if (plan.containsAllPatterns(CTE)) {
34+
// A helper map definitionID->Definition. Used for non-recursive CTE definitions only, either
35+
// inherently non-recursive or that became non-recursive due to recursive CTERelationRef->
36+
// UnionLoopRef substitution. Bridges the gap between two CTE resolutions (one for WithCTE,
37+
// another for reference) by materializing the resolved definitions in first pass and
38+
// consuming them in the second.
3339
val cteDefMap = mutable.HashMap.empty[Long, CTERelationDef]
3440
resolveWithCTE(plan, cteDefMap)
3541
} else {
@@ -41,16 +47,113 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
4147
plan: LogicalPlan,
4248
cteDefMap: mutable.HashMap[Long, CTERelationDef]): LogicalPlan = {
4349
plan.resolveOperatorsDownWithPruning(_.containsAllPatterns(CTE)) {
44-
case w @ WithCTE(_, cteDefs) =>
45-
cteDefs.foreach { cteDef =>
46-
if (cteDef.resolved) {
47-
cteDefMap.put(cteDef.id, cteDef)
48-
}
50+
case withCTE @ WithCTE(_, cteDefs) =>
51+
val newCTEDefs = cteDefs.map {
52+
// `cteDef.recursive` means "presence of a recursive CTERelationRef under cteDef". The
53+
// side effect of node substitution below is that after CTERelationRef substitution
54+
// its cteDef is no more considered `recursive`. This code path is common for `cteDef`
55+
// that were non-recursive from the get go, as well as those that are no more recursive
56+
// due to node substitution.
57+
case cteDef if !cteDef.recursive =>
58+
if (cteDef.resolved) {
59+
cteDefMap.put(cteDef.id, cteDef)
60+
}
61+
cteDef
62+
case cteDef =>
63+
cteDef.child match {
64+
// If it is a supported recursive CTE query pattern (4 so far), extract the anchor and
65+
// recursive plans from the Union and rewrite Union with UnionLoop. The recursive CTE
66+
// references inside UnionLoop's recursive plan will be rewritten as UnionLoopRef,
67+
// using the output of the resolved anchor plan. The side effect of recursive
68+
// CTERelationRef->UnionLoopRef substitution is that `cteDef` that was originally
69+
// considered `recursive` is no more in the context of `cteDef.recursive` method
70+
// definition.
71+
//
72+
// Simple case of duplicating (UNION ALL) clause.
73+
case alias @ SubqueryAlias(_, Union(Seq(anchor, recursion), false, false)) =>
74+
if (!anchor.resolved) {
75+
cteDef
76+
} else {
77+
val loop = UnionLoop(
78+
cteDef.id,
79+
anchor,
80+
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, None))
81+
cteDef.copy(child = alias.copy(child = loop))
82+
}
83+
84+
// The case of CTE name followed by a parenthesized list of column name(s), eg.
85+
// WITH RECURSIVE t(n).
86+
case alias @ SubqueryAlias(_,
87+
columnAlias @ UnresolvedSubqueryColumnAliases(
88+
colNames,
89+
Union(Seq(anchor, recursion), false, false)
90+
)) =>
91+
if (!anchor.resolved) {
92+
cteDef
93+
} else {
94+
val loop = UnionLoop(
95+
cteDef.id,
96+
anchor,
97+
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)))
98+
cteDef.copy(child = alias.copy(child = columnAlias.copy(child = loop)))
99+
}
100+
101+
// If the recursion is described with an UNION (deduplicating) clause then the
102+
// recursive term should not return those rows that have been calculated previously,
103+
// and we exclude those rows from the current iteration result.
104+
case alias @ SubqueryAlias(_,
105+
Distinct(Union(Seq(anchor, recursion), false, false))) =>
106+
if (!anchor.resolved) {
107+
cteDef
108+
} else {
109+
val loop = UnionLoop(
110+
cteDef.id,
111+
Distinct(anchor),
112+
Except(
113+
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, None),
114+
UnionLoopRef(cteDef.id, anchor.output, true),
115+
isAll = false
116+
)
117+
)
118+
cteDef.copy(child = alias.copy(child = loop))
119+
}
120+
121+
// The case of CTE name followed by a parenthesized list of column name(s).
122+
case alias @ SubqueryAlias(_,
123+
columnAlias@UnresolvedSubqueryColumnAliases(
124+
colNames,
125+
Distinct(Union(Seq(anchor, recursion), false, false))
126+
)) =>
127+
if (!anchor.resolved) {
128+
cteDef
129+
} else {
130+
val loop = UnionLoop(
131+
cteDef.id,
132+
Distinct(anchor),
133+
Except(
134+
rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)),
135+
UnionLoopRef(cteDef.id, anchor.output, true),
136+
isAll = false
137+
)
138+
)
139+
cteDef.copy(child = alias.copy(child = columnAlias.copy(child = loop)))
140+
}
141+
142+
case other =>
143+
// We do not support cases of sole Union (needs a SubqueryAlias above it), nor
144+
// Project (as UnresolvedSubqueryColumnAliases have not been substituted with the
145+
// Project yet), leaving us with cases of SubqueryAlias->Union and SubqueryAlias->
146+
// UnresolvedSubqueryColumnAliases->Union. The same applies to Distinct Union.
147+
throw QueryCompilationErrors.invalidRecursiveCteError(
148+
"Unsupported recursive CTE UNION placement.")
149+
}
49150
}
50-
w
151+
withCTE.copy(cteDefs = newCTEDefs)
51152

153+
// This is a non-recursive reference to a definition.
52154
case ref: CTERelationRef if !ref.resolved =>
53155
cteDefMap.get(ref.cteId).map { cteDef =>
156+
// cteDef is certainly resolved, otherwise it would not have been in the map.
54157
CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output, cteDef.isStreaming)
55158
}.getOrElse {
56159
ref
@@ -62,4 +165,21 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
62165
}
63166
}
64167
}
168+
169+
// Substitutes a recursive CTERelationRef with UnionLoopRef if CTERelationRef refers to cteDefId.
170+
// Assumes that `anchor` is already resolved. If `columnNames` is set (which happens if an
171+
// UnresolvedSubqueryColumnAliases is atop a Union) also places an UnresolvedSubqueryColumnAliases
172+
// node above UnionLoopRef. At some later stage UnresolvedSubqueryColumnAliases gets resolved to
173+
// a Project node.
174+
private def rewriteRecursiveCTERefs(
    recursion: LogicalPlan,
    anchor: LogicalPlan,
    cteDefId: Long,
    columnNames: Option[Seq[String]]) = {
  recursion.transformWithPruning(_.containsPattern(CTE)) {
    // Only self-references to the definition being rewritten are touched; every other node,
    // including references to other CTE definitions, is left as it is.
    case selfRef: CTERelationRef if selfRef.recursive && selfRef.cteId == cteDefId =>
      // The loop reference takes its output attributes from the anchor plan.
      val loopRef = UnionLoopRef(selfRef.cteId, anchor.output, false)
      // When explicit column names were supplied, wrap the loop reference in an
      // UnresolvedSubqueryColumnAliases node; otherwise use the loop reference directly.
      columnNames.fold[LogicalPlan](loopRef)(UnresolvedSubqueryColumnAliases(_, loopRef))
  }
}
65185
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {
122122
private def pushdownPredicatesAndAttributes(
123123
plan: LogicalPlan,
124124
cteMap: CTEMap): LogicalPlan = plan.transformWithSubqueries {
125-
case cteDef @ CTERelationDef(child, id, originalPlanWithPredicates, _, _) =>
125+
case cteDef @ CTERelationDef(child, id, originalPlanWithPredicates, _) =>
126126
val (_, _, newPreds, newAttrSet) = cteMap(id)
127127
val originalPlan = originalPlanWithPredicates.map(_._1).getOrElse(child)
128128
val preds = originalPlanWithPredicates.map(_._2).getOrElse(Seq.empty)
@@ -170,7 +170,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {
170170
object CleanUpTempCTEInfo extends Rule[LogicalPlan] {
171171
override def apply(plan: LogicalPlan): LogicalPlan =
172172
plan.transformWithPruning(_.containsPattern(CTE)) {
173-
case cteDef @ CTERelationDef(_, _, Some(_), _, _) =>
173+
case cteDef @ CTERelationDef(_, _, Some(_), _) =>
174174
cteDef.copy(originalPlanWithPredicates = None)
175175
}
176176
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,4 +774,23 @@ object QueryPlan extends PredicateHelper {
774774
case e: AnalysisException => append(e.toString)
775775
}
776776
}
777+
778+
/**
779+
* Generate detailed field string with different format based on type of input value. Supported
780+
* input values are sequences and strings. An empty sequence converts to []. A non-empty
781+
* sequence converts to square brackets-enclosed, comma-separated values, prefixed with a
782+
* sequence length. An empty string converts to None, while a non-empty string is verbatim
783+
* outputted. In all four cases, user-provided fieldName prefixes the output. Examples:
784+
* List("Hello", "World") -> <fieldName> [2]: [Hello, World]
785+
* List() -> <fieldName>: []
786+
* "hello_world" -> <fieldName>: hello_world
787+
* "" -> <fieldName>: None
788+
*/
789+
def generateFieldString(fieldName: String, values: Any): String = values match {
  // A typed pattern such as `case str: String` never matches null, so a null check placed in
  // its guard is dead code. Handle null explicitly and render it like the empty string.
  case null => s"${fieldName}: None"
  // Empty sequence: no length prefix, just empty brackets. `isEmpty` avoids traversing
  // collections whose `size` is not constant-time.
  case iter: Iterable[_] if iter.isEmpty => s"${fieldName}: []"
  // Non-empty sequence: "<fieldName> [<length>]: [v1, v2, ...]".
  case iter: Iterable[_] => s"${fieldName} [${iter.size}]: ${iter.mkString("[", ", ", "]")}"
  // Empty string is rendered as None; a non-empty string is emitted verbatim.
  case str: String if str.isEmpty => s"${fieldName}: None"
  case str: String => s"${fieldName}: ${str}"
  // Any other input type is a programming error on the caller's side.
  case _ => throw new IllegalArgumentException(s"Unsupported type for argument values: $values")
}
777796
}

0 commit comments

Comments
 (0)