Commit c148cfe

update test cases
1 parent 73afab1 commit c148cfe

4 files changed: +174 -52 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 7 additions & 0 deletions
@@ -390,6 +390,13 @@ package object dsl {
           condition: Option[Expression] = None): LogicalPlan =
         Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE)
 
+      def lateralJoin(
+          otherPlan: LogicalPlan,
+          joinType: JoinType = Inner,
+          condition: Option[Expression] = None): LogicalPlan = {
+        LateralJoin(logicalPlan, LateralSubquery(otherPlan), joinType, condition)
+      }
+
       def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: Encoder](
           otherPlan: LogicalPlan,
           func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
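
Not part of the diff, just for orientation: a minimal sketch of how the new lateralJoin helper can be used with the catalyst test DSL to build a lateral-join plan. The relation name rel below is hypothetical; the real usage is in the new OptimizeOneRowRelationSubquerySuite further down.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.LeftOuter
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}

// Hypothetical relation used only for this sketch.
val rel = LocalRelation('a.int, 'b.int)

// Roughly: SELECT * FROM rel LEFT OUTER JOIN LATERAL (SELECT a, b)
// Defaults from the new signature: joinType = Inner, condition = None.
val plan = rel.lateralJoin(OneRowRelation().select('a, 'b), LeftOuter)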

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 6 additions & 2 deletions
@@ -727,18 +727,22 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
     }
   }
 
+  private def hasCorrelatedSubquery(plan: LogicalPlan): Boolean = {
+    plan.find(_.expressions.exists(SubqueryExpression.hasCorrelatedSubquery)).isDefined
+  }
+
   /**
    * Rewrite a subquery expression into one or more expressions. The rewrite can only be done
    * if there is no nested subqueries in the subquery plan.
    */
   private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries {
     case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None)
-        if right.plan.subqueries.isEmpty && right.joinCond.isEmpty =>
+        if !hasCorrelatedSubquery(right.plan) && right.joinCond.isEmpty =>
       Project(left.output ++ projectList, left)
     case p: LogicalPlan => p.transformExpressionsUpWithPruning(
       _.containsPattern(SCALAR_SUBQUERY)) {
       case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _)
-          if s.plan.subqueries.isEmpty && s.joinCond.isEmpty =>
+          if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
         assert(projectList.size == 1)
         projectList.head
     }
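
A hedged reading of the change above, not part of the commit: the one-row-subquery rewrite now backs off only when the subquery plan contains a correlated subquery expression somewhere in its tree (the new hasCorrelatedSubquery walks the whole plan via plan.find), instead of whenever plan.subqueries was non-empty, so uncorrelated nested subqueries no longer block the rewrite. The sketch below reuses the DSL and the t0/t1 fixtures defined in the new suite further down and rebuilds the plan from its "Should not optimize subquery with nested subqueries" test: the nested subquery is correlated, so the guard trips and the plan is left to the normal decorrelation path (hence the DomainJoin assertion in that test).

// Sketch only; assumes the imports and the t0/t1 fixtures from the suite below.
val inner = t0.select('a).where('a === 1)          // 'a is an outer reference, so this subquery is correlated
val subquery = t0.select('a.as("a"))
  .select(ScalarSubquery(inner).as("s"))
  .select('s + 1)                                   // the root Project itself holds no subquery expression
val query = t1.select(ScalarSubquery(subquery).as("sub"))
// hasCorrelatedSubquery(subquery) is true, so the rewrite above does not collapse it.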

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, ScalarSubquery}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LocalRelation, LogicalPlan, OneRowRelation, Project}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf
+
+class OptimizeOneRowRelationSubquerySuite extends PlanTest {
+
+  private var optimizeOneRowRelationSubqueryEnabled: Boolean = _
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    optimizeOneRowRelationSubqueryEnabled =
+      SQLConf.get.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY)
+    SQLConf.get.setConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY, true)
+  }
+
+  protected override def afterAll(): Unit = {
+    SQLConf.get.setConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY,
+      optimizeOneRowRelationSubqueryEnabled)
+    super.afterAll()
+  }
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Subquery", Once,
+        OptimizeOneRowRelationSubquery,
+        PullupCorrelatedPredicates) :: Nil
+  }
+
+  private def assertHasDomainJoin(plan: LogicalPlan): Unit = {
+    assert(plan.collectWithSubqueries { case d: DomainJoin => d }.nonEmpty,
+      s"Plan does not contain DomainJoin:\n$plan")
+  }
+
+  val t0 = OneRowRelation()
+  val a = 'a.int
+  val b = 'b.int
+  val t1 = LocalRelation(a, b)
+  val t2 = LocalRelation('c.int, 'd.int)
+
+  test("Optimize scalar subquery with a single project") {
+    // SELECT (SELECT a) FROM t1
+    val query = t1.select(ScalarSubquery(t0.select('a)).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = t1.select('a.as("sub"))
+    comparePlans(optimized, correctAnswer.analyze)
+  }
+
+  test("Optimize lateral subquery with a single project") {
+    Seq(Inner, LeftOuter, Cross).foreach { joinType =>
+      // SELECT * FROM t1 JOIN LATERAL (SELECT a, b)
+      val query = t1.lateralJoin(t0.select('a, 'b), joinType, None)
+      val optimized = Optimize.execute(query.analyze)
+      val correctAnswer = t1.select('a, 'b, 'a.as("a"), 'b.as("b"))
+      comparePlans(optimized, correctAnswer.analyze)
+    }
+  }
+
+  test("Optimize subquery with subquery alias") {
+    val inner = t0.select('a).as("t2")
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = t1.select('a.as("sub"))
+    comparePlans(optimized, correctAnswer.analyze)
+  }
+
+  test("Optimize scalar subquery with multiple projects") {
+    // SELECT (SELECT a1 + b1 FROM (SELECT a AS a1, b AS b1)) FROM t1
+    val inner = t0.select('a.as("a1"), 'b.as("b1")).select(('a1 + 'b1).as("c"))
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = Project(Alias(Alias(a + b, "c")(), "sub")() :: Nil, t1)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Optimize lateral subquery with multiple projects") {
+    Seq(Inner, LeftOuter, Cross).foreach { joinType =>
+      val inner = t0.select('a.as("a1"), 'b.as("b1"))
+        .select(('a1 + 'b1).as("c1"), ('a1 - 'b1).as("c2"))
+      val query = t1.lateralJoin(inner, joinType, None)
+      val optimized = Optimize.execute(query.analyze)
+      val correctAnswer = t1.select('a, 'b, ('a + 'b).as("c1"), ('a - 'b).as("c2"))
+      comparePlans(optimized, correctAnswer.analyze)
+    }
+  }
+
+  test("Optimize subquery with nested correlated subqueries") {
+    // SELECT (SELECT (SELECT b) FROM (SELECT a AS b)) FROM t1
+    val inner = t0.select('a.as("b")).select(ScalarSubquery(t0.select('b)).as("s"))
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = Project(Alias(Alias(a, "s")(), "sub")() :: Nil, t1)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Batch should be idempotent") {
+    // SELECT (SELECT 1 WHERE a = a + 1) FROM t1
+    val inner = t0.select(1).where('a === 'a + 1)
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    val doubleOptimized = Optimize.execute(optimized)
+    comparePlans(optimized, doubleOptimized, checkAnalysis = false)
+  }
+
+  test("Should not optimize scalar subquery with operators other than project") {
+    // SELECT (SELECT a AS a1 WHERE a = 1) FROM t1
+    val inner = t0.where('a === 1).select('a.as("a1"))
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    assertHasDomainJoin(optimized)
+  }
+
+  test("Should not optimize subquery with non-deterministic expressions") {
+    // SELECT (SELECT r FROM (SELECT a + rand() AS r)) FROM t1
+    val inner = t0.select(('a + rand(0)).as("r")).select('r)
+    val query = t1.select(ScalarSubquery(inner).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    assertHasDomainJoin(optimized)
+  }
+
+  test("Should not optimize lateral join with non-empty join conditions") {
+    Seq(Inner, LeftOuter).foreach { joinType =>
+      // SELECT * FROM t1 JOIN LATERAL (SELECT a AS a1, b AS b1) ON a = b1
+      val query = t1.lateralJoin(t0.select('a.as("a1"), 'b.as("b1")), joinType, Some('a === 'b1))
+      val optimized = Optimize.execute(query.analyze)
+      assertHasDomainJoin(optimized)
+    }
+  }
+
+  test("Should not optimize subquery with nested subqueries") {
+    // SELECT (SELECT (SELECT a WHERE a = 1) FROM (SELECT a AS a)) FROM t1
+    val inner = t0.select('a).where('a === 1)
+    val subquery = t0.select('a.as("a"))
+      .select(ScalarSubquery(inner).as("s")).select('s + 1)
+    val query = t1.select(ScalarSubquery(subquery).as("sub"))
+    val optimized = Optimize.execute(query.analyze)
+    assertHasDomainJoin(optimized)
+  }
+}

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 0 additions & 50 deletions
@@ -1877,54 +1877,4 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
         "ReusedSubqueryExec should reuse an existing subquery")
     }
   }
-
-  test("SPARK-36063: optimize one row relation subqueries") {
-    withTempView("t") {
-      Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
-      Seq(
-        "select (select c1) from t",
-        "select (select a) from t as t(a, b)",
-        "select (select c from (select c from (select c1 as c))) from t",
-        "select (select (select a) from (select c1, c2) t(a, b)) from t",
-        "select s.c1 from t, lateral (select c1, c2) s"
-      ).foreach { query =>
-        Seq(true, false).foreach { enabled =>
-          withSQLConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> enabled.toString) {
-            val df = sql(query)
-            val plan = df.queryExecution.optimizedPlan
-            val joins = plan.collectWithSubqueries { case j: Join => j }
-            assert(joins.isEmpty == enabled)
-            checkAnswer(df, Row(0) :: Row(1) :: Nil)
-          }
-        }
-      }
-    }
-  }
-
-  test("SPARK-36063: optimize one row relation subqueries (negative case)") {
-    withTempView("t") {
-      Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
-      withSQLConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY.key -> "true") {
-        Seq(
-          // With additional operators
-          ("select (select c1 where c2 = 1) from t", Row(0) :: Row(null) :: Nil),
-          // With non-deterministic expressions
-          ("select (select floor(r) from (select c1 + rand() as r)) from t",
-            Row(0) :: Row(1) :: Nil),
-          // With non-empty lateral join condition
-          ("select * from t join lateral (select c1, c2) s on t.c1 = s.c2", Nil),
-          // With nested subqueries that cannot be optimized
-          ("select (select (select a where a = 1) from (select c1 as a)) from t",
-            Row(null) :: Row(1) :: Nil),
-          ("select * from t, lateral (select (select a where a = 1) from (select c1 as a))",
-            Row(0, 1, null) :: Row(1, 2, 1) :: Nil)
-        ).foreach { case (query, expected) =>
-          val df = sql(query)
-          val joins = df.queryExecution.optimizedPlan.collect { case j: Join => j }
-          assert(joins.nonEmpty)
-          checkAnswer(df, expected)
-        }
-      }
-    }
-  }
 }
