From 08e47d2fc538838892bfabad3a1a93d85ec5228b Mon Sep 17 00:00:00 2001 From: linzebing Date: Mon, 22 Mar 2021 11:08:12 -0700 Subject: [PATCH 1/2] Left/Right outer broadcast nested loop join codegen --- .../joins/BroadcastNestedLoopJoinExec.scala | 39 ++++++++++++++++++- .../execution/WholeStageCodegenSuite.scala | 38 ++++++++++++++++++ .../sql/execution/joins/OuterJoinSuite.scala | 4 +- .../execution/metric/SQLMetricsSuite.scala | 6 +-- 4 files changed, 81 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 482c3a3091f86..a431662b71909 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -396,7 +396,8 @@ case class BroadcastNestedLoopJoinExec( } override def supportCodegen: Boolean = (joinType, buildSide) match { - case (_: InnerLike, _) | (LeftSemi | LeftAnti, BuildRight) => true + case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | + (LeftSemi | LeftAnti, BuildRight) => true case _ => false } @@ -413,6 +414,7 @@ case class BroadcastNestedLoopJoinExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { (joinType, buildSide) match { case (_: InnerLike, _) => codegenInner(ctx, input) + case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => codegenOuter(ctx, input) case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true) case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false) case _ => @@ -458,6 +460,41 @@ case class BroadcastNestedLoopJoinExec( """.stripMargin } + private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (_, buildRowArrayTerm) = prepareBroadcast(ctx) + val (buildRow, 
checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast) + val buildVars = genBuildSideVars(ctx, buildRow, broadcast) + + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + val arrayIndex = ctx.freshName("arrayIndex") + val shouldOutputRow = ctx.freshName("shouldOutputRow") + val foundMatch = ctx.freshName("foundMatch") + val numOutput = metricTerm(ctx, "numOutputRows") + + s""" + |boolean $foundMatch = false; + |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { + | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; + | boolean $shouldOutputRow = false; + | $checkCondition { + | $shouldOutputRow = true; + | $foundMatch = true; + | } + | if ($arrayIndex == $buildRowArrayTerm.length - 1 && !$foundMatch) { + | $buildRow = null; + | $shouldOutputRow = true; + | } + | if ($shouldOutputRow) { + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + |} + """.stripMargin + } + private def codegenLeftExistence( ctx: CodegenContext, input: Seq[ExprCode], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 8246bca1893a9..d23838351cd23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -211,6 +211,44 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } } + test("Left/Right outer BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") { + val df1 = spark.range(4).select($"id".as("k1")) + val df2 = spark.range(3).select($"id".as("k2")) + val df3 = spark.range(2).select($"id".as("k3")) + + Seq(true, false).foreach { codegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) { + // test left outer 
join + val leftOuterJoinDF = df1.join(df2, $"k1" > $"k2", "left_outer") + var hasJoinInCodegen = leftOuterJoinDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(leftOuterJoinDF, + Seq(Row(0, null), Row(1, 0), Row(2, 0), Row(2, 1), Row(3, 0), Row(3, 1), Row(3, 2))) + + // test right outer join + val rightOuterJoinDF = df1.join(df2, $"k1" < $"k2", "right_outer") + hasJoinInCodegen = rightOuterJoinDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(rightOuterJoinDF, Seq(Row(null, 0), Row(0, 1), Row(0, 2), Row(1, 2))) + + // test a combination of left outer and right outer joins + val twoJoinsDF = df1.join(df2, $"k1" > $"k2" + 1, "right_outer") + .join(df3, $"k1" <= $"k3", "left_outer") + hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(BroadcastNestedLoopJoinExec( + _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(twoJoinsDF, + Seq(Row(2, 0, null), Row(3, 0, null), Row(3, 1, null), Row(null, 2, null))) + } + } + } + test("Left semi/anti BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") { val df1 = spark.range(4).select($"id".as("k1")) val df2 = spark.range(3).select($"id".as("k2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 150d40d0301fc..810eeea5b9a60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -149,7 +149,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { } } - 
test(s"$testName using BroadcastNestedLoopJoin build left") { + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build left") { _ => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition)), @@ -158,7 +158,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { } } - test(s"$testName using BroadcastNestedLoopJoin build right") { + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build right") { _ => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index dd99368e3a87b..50f980643d2d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -452,11 +452,11 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a" val rightQuery = "SELECT * FROM testData2 RIGHT JOIN testDataForJoin ON " + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a" - Seq((leftQuery, false), (rightQuery, false), (leftQuery, true), (rightQuery, true)) - .foreach { case (query, enableWholeStage) => + Seq((leftQuery, 0L, false), (rightQuery, 0L, false), (leftQuery, 1L, true), + (rightQuery, 1L, true)).foreach { case (query, nodeId, enableWholeStage) => val df = spark.sql(query) testSparkPlanMetrics(df, 2, Map( - 0L -> (("BroadcastNestedLoopJoin", Map( + nodeId -> (("BroadcastNestedLoopJoin", Map( "number of output 
rows" -> 12L)))), enableWholeStage ) From 8ee75369cb66abf85ecf6f7bde98cbdd3f1287b9 Mon Sep 17 00:00:00 2001 From: linzebing Date: Mon, 22 Mar 2021 18:28:58 -0700 Subject: [PATCH 2/2] Handle an empty broadcast side and add a test to cover it --- .../joins/BroadcastNestedLoopJoinExec.scala | 48 +++++++++++-------- .../execution/WholeStageCodegenSuite.scala | 11 ++++- 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index a431662b71909..fa1a57a8ae3a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -461,7 +461,7 @@ case class BroadcastNestedLoopJoinExec( } private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { - val (_, buildRowArrayTerm) = prepareBroadcast(ctx) + val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx) val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast) val buildVars = genBuildSideVars(ctx, buildRow, broadcast) @@ -474,25 +474,33 @@ case class BroadcastNestedLoopJoinExec( val foundMatch = ctx.freshName("foundMatch") val numOutput = metricTerm(ctx, "numOutputRows") - s""" - |boolean $foundMatch = false; - |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { - | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; - | boolean $shouldOutputRow = false; - | $checkCondition { - | $shouldOutputRow = true; - | $foundMatch = true; - | } - | if ($arrayIndex == $buildRowArrayTerm.length - 1 && !$foundMatch) { - | $buildRow = null; - | $shouldOutputRow = true; - | } - | if ($shouldOutputRow) { - | $numOutput.add(1); - | ${consume(ctx, resultVars)} - | } - |} - """.stripMargin + if 
(buildRowArray.isEmpty) { + s""" + |UnsafeRow $buildRow = null; + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin + } else { + s""" + |boolean $foundMatch = false; + |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { + | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; + | boolean $shouldOutputRow = false; + | $checkCondition { + | $shouldOutputRow = true; + | $foundMatch = true; + | } + | if ($arrayIndex == $buildRowArrayTerm.length - 1 && !$foundMatch) { + | $buildRow = null; + | $shouldOutputRow = true; + | } + | if ($shouldOutputRow) { + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + |} + """.stripMargin + } } private def codegenLeftExistence( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index d23838351cd23..b66308c4f880f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -215,6 +215,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val df1 = spark.range(4).select($"id".as("k1")) val df2 = spark.range(3).select($"id".as("k2")) val df3 = spark.range(2).select($"id".as("k3")) + val df4 = spark.range(0).select($"id".as("k4")) Seq(true, false).foreach { codegenEnabled => withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) { @@ -240,11 +241,19 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .join(df3, $"k1" <= $"k3", "left_outer") hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect { case WholeStageCodegenExec(BroadcastNestedLoopJoinExec( - _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true + _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true }.size === 1 assert(hasJoinInCodegen == codegenEnabled) 
checkAnswer(twoJoinsDF, Seq(Row(2, 0, null), Row(3, 0, null), Row(3, 1, null), Row(null, 2, null))) + + // test build side is empty + val buildSideIsEmptyDF = df3.join(df4, $"k3" > $"k4", "left_outer") + hasJoinInCodegen = buildSideIsEmptyDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(buildSideIsEmptyDF, Seq(Row(0, null), Row(1, null))) } } }