Skip to content

Commit 08e47d2

Browse files
committed
Left/Right outer broadcast nested loop join codegen
1 parent 31da907 commit 08e47d2

File tree

4 files changed

+81
-6
lines changed

4 files changed

+81
-6
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,8 @@ case class BroadcastNestedLoopJoinExec(
396396
}
397397

398398
override def supportCodegen: Boolean = (joinType, buildSide) match {
399-
case (_: InnerLike, _) | (LeftSemi | LeftAnti, BuildRight) => true
399+
case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) |
400+
(LeftSemi | LeftAnti, BuildRight) => true
400401
case _ => false
401402
}
402403

@@ -413,6 +414,7 @@ case class BroadcastNestedLoopJoinExec(
413414
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
414415
(joinType, buildSide) match {
415416
case (_: InnerLike, _) => codegenInner(ctx, input)
417+
case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => codegenOuter(ctx, input)
416418
case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true)
417419
case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false)
418420
case _ =>
@@ -458,6 +460,41 @@ case class BroadcastNestedLoopJoinExec(
458460
""".stripMargin
459461
}
460462

463+
/**
 * Generates the Java code that joins one streamed row against the broadcast (build) side
 * for a left outer (build right) or right outer (build left) join.
 *
 * For a non-empty build side, the generated code loops over every build row and emits one
 * output row per build row that passes the join condition. If no build row matched, it emits
 * a single extra row on the last iteration with the build row set to `null`.
 *
 * For an empty build side, that loop would never execute and the streamed row would be lost,
 * so the generated code instead emits the streamed row once with a `null` build row.
 *
 * @param ctx   codegen context used to allocate fresh variable names and metric terms
 * @param input generated variables holding the current streamed row's columns
 * @return the Java source fragment to splice into the whole-stage-codegen pipeline
 */
private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
  // NOTE(review): assumes the first element returned by prepareBroadcast is the materialized
  // build-side row array (so `isEmpty` is known while generating code) -- confirm.
  val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx)
  val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast)
  val buildVars = genBuildSideVars(ctx, buildRow, broadcast)

  // Output columns always follow the logical left ++ right order, whichever side is built.
  val resultVars = buildSide match {
    case BuildLeft => buildVars ++ input
    case BuildRight => input ++ buildVars
  }
  val arrayIndex = ctx.freshName("arrayIndex")
  val shouldOutputRow = ctx.freshName("shouldOutputRow")
  val foundMatch = ctx.freshName("foundMatch")
  val numOutput = metricTerm(ctx, "numOutputRows")

  if (buildRowArray.isEmpty) {
    // Nothing to loop over: every streamed row is unmatched and must still be emitted once.
    // This relies on the build-side variables tolerating a null build row, exactly as the
    // no-match path of the loop below already does.
    s"""
       |UnsafeRow $buildRow = null;
       |$numOutput.add(1);
       |${consume(ctx, resultVars)}
     """.stripMargin
  } else {
    s"""
       |boolean $foundMatch = false;
       |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
       |  UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
       |  boolean $shouldOutputRow = false;
       |  $checkCondition {
       |    $shouldOutputRow = true;
       |    $foundMatch = true;
       |  }
       |  if ($arrayIndex == $buildRowArrayTerm.length - 1 && !$foundMatch) {
       |    $buildRow = null;
       |    $shouldOutputRow = true;
       |  }
       |  if ($shouldOutputRow) {
       |    $numOutput.add(1);
       |    ${consume(ctx, resultVars)}
       |  }
       |}
     """.stripMargin
  }
}
497+
461498
private def codegenLeftExistence(
462499
ctx: CodegenContext,
463500
input: Seq[ExprCode],

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,44 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
211211
}
212212
}
213213

214+
// Verifies that left outer and right outer broadcast nested loop joins sit directly under a
// WholeStageCodegenExec node exactly when whole-stage codegen is enabled, and that the query
// answers are the same with codegen on and off.
test("Left/Right outer BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") {
  val k1DF = spark.range(4).select($"id".as("k1"))
  val k2DF = spark.range(3).select($"id".as("k2"))
  val k3DF = spark.range(2).select($"id".as("k3"))

  for (codegenEnabled <- Seq(true, false)) {
    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {
      // Left outer join: k1 = 0 has no smaller k2, so it is padded with null.
      val leftOuterJoinDF = k1DF.join(k2DF, $"k1" > $"k2", "left_outer")
      val leftOuterCodegenNodes = leftOuterJoinDF.queryExecution.executedPlan.collect {
        case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true
      }
      val leftOuterInCodegen = leftOuterCodegenNodes.size === 1
      assert(leftOuterInCodegen == codegenEnabled)
      val expectedLeftOuter = Seq(
        Row(0, null), Row(1, 0), Row(2, 0), Row(2, 1), Row(3, 0), Row(3, 1), Row(3, 2))
      checkAnswer(leftOuterJoinDF, expectedLeftOuter)

      // Right outer join: k2 = 0 has no smaller k1, so it is padded with null.
      val rightOuterJoinDF = k1DF.join(k2DF, $"k1" < $"k2", "right_outer")
      val rightOuterCodegenNodes = rightOuterJoinDF.queryExecution.executedPlan.collect {
        case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true
      }
      val rightOuterInCodegen = rightOuterCodegenNodes.size === 1
      assert(rightOuterInCodegen == codegenEnabled)
      val expectedRightOuter = Seq(Row(null, 0), Row(0, 1), Row(0, 2), Row(1, 2))
      checkAnswer(rightOuterJoinDF, expectedRightOuter)

      // Right outer join feeding a left outer join: both joins should share one codegen stage.
      val chainedJoinsDF = k1DF
        .join(k2DF, $"k1" > $"k2" + 1, "right_outer")
        .join(k3DF, $"k1" <= $"k3", "left_outer")
      val chainedCodegenNodes = chainedJoinsDF.queryExecution.executedPlan.collect {
        case WholeStageCodegenExec(BroadcastNestedLoopJoinExec(
          _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true
      }
      val chainedInCodegen = chainedCodegenNodes.size === 1
      assert(chainedInCodegen == codegenEnabled)
      val expectedChained = Seq(
        Row(2, 0, null), Row(3, 0, null), Row(3, 1, null), Row(null, 2, null))
      checkAnswer(chainedJoinsDF, expectedChained)
    }
  }
}
251+
214252
test("Left semi/anti BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") {
215253
val df1 = spark.range(4).select($"id".as("k1"))
216254
val df2 = spark.range(3).select($"id".as("k2"))

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
149149
}
150150
}
151151

152-
test(s"$testName using BroadcastNestedLoopJoin build left") {
152+
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build left") { _ =>
153153
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
154154
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
155155
BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition)),
@@ -158,7 +158,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
158158
}
159159
}
160160

161-
test(s"$testName using BroadcastNestedLoopJoin build right") {
161+
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build right") { _ =>
162162
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
163163
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
164164
BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition)),

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -452,11 +452,11 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
452452
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a"
453453
val rightQuery = "SELECT * FROM testData2 RIGHT JOIN testDataForJoin ON " +
454454
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a"
455-
Seq((leftQuery, false), (rightQuery, false), (leftQuery, true), (rightQuery, true))
456-
.foreach { case (query, enableWholeStage) =>
455+
Seq((leftQuery, 0L, false), (rightQuery, 0L, false), (leftQuery, 1L, true),
456+
(rightQuery, 1L, true)).foreach { case (query, nodeId, enableWholeStage) =>
457457
val df = spark.sql(query)
458458
testSparkPlanMetrics(df, 2, Map(
459-
0L -> (("BroadcastNestedLoopJoin", Map(
459+
nodeId -> (("BroadcastNestedLoopJoin", Map(
460460
"number of output rows" -> 12L)))),
461461
enableWholeStage
462462
)

0 commit comments

Comments
 (0)