Skip to content

Commit 8ee7536

Browse files
committed
Handle the case where the broadcast side is empty, and add a test to cover it
1 parent 08e47d2 commit 8ee7536

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala

Lines changed: 28 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -461,7 +461,7 @@ case class BroadcastNestedLoopJoinExec(
461461
}
462462

463463
private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
464-
val (_, buildRowArrayTerm) = prepareBroadcast(ctx)
464+
val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx)
465465
val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast)
466466
val buildVars = genBuildSideVars(ctx, buildRow, broadcast)
467467

@@ -474,25 +474,33 @@ case class BroadcastNestedLoopJoinExec(
474474
val foundMatch = ctx.freshName("foundMatch")
475475
val numOutput = metricTerm(ctx, "numOutputRows")
476476

477-
s"""
478-
|boolean $foundMatch = false;
479-
|for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
480-
| UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
481-
| boolean $shouldOutputRow = false;
482-
| $checkCondition {
483-
| $shouldOutputRow = true;
484-
| $foundMatch = true;
485-
| }
486-
| if ($arrayIndex == $buildRowArrayTerm.length - 1 && !$foundMatch) {
487-
| $buildRow = null;
488-
| $shouldOutputRow = true;
489-
| }
490-
| if ($shouldOutputRow) {
491-
| $numOutput.add(1);
492-
| ${consume(ctx, resultVars)}
493-
| }
494-
|}
495-
""".stripMargin
477+
if (buildRowArray.isEmpty) {
478+
s"""
479+
|UnsafeRow $buildRow = null;
480+
|$numOutput.add(1);
481+
|${consume(ctx, resultVars)}
482+
""".stripMargin
483+
} else {
484+
s"""
485+
|boolean $foundMatch = false;
486+
|for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
487+
| UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
488+
| boolean $shouldOutputRow = false;
489+
| $checkCondition {
490+
| $shouldOutputRow = true;
491+
| $foundMatch = true;
492+
| }
493+
| if ($arrayIndex == $buildRowArrayTerm.length - 1 && !$foundMatch) {
494+
| $buildRow = null;
495+
| $shouldOutputRow = true;
496+
| }
497+
| if ($shouldOutputRow) {
498+
| $numOutput.add(1);
499+
| ${consume(ctx, resultVars)}
500+
| }
501+
|}
502+
""".stripMargin
503+
}
496504
}
497505

498506
private def codegenLeftExistence(

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
215215
val df1 = spark.range(4).select($"id".as("k1"))
216216
val df2 = spark.range(3).select($"id".as("k2"))
217217
val df3 = spark.range(2).select($"id".as("k3"))
218+
val df4 = spark.range(0).select($"id".as("k4"))
218219

219220
Seq(true, false).foreach { codegenEnabled =>
220221
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {
@@ -240,11 +241,19 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
240241
.join(df3, $"k1" <= $"k3", "left_outer")
241242
hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect {
242243
case WholeStageCodegenExec(BroadcastNestedLoopJoinExec(
243-
_: BroadcastNestedLoopJoinExec, _, _, _, _)) => true
244+
_: BroadcastNestedLoopJoinExec, _, _, _, _)) => true
244245
}.size === 1
245246
assert(hasJoinInCodegen == codegenEnabled)
246247
checkAnswer(twoJoinsDF,
247248
Seq(Row(2, 0, null), Row(3, 0, null), Row(3, 1, null), Row(null, 2, null)))
249+
250+
// test the case where the build side is empty
251+
val buildSideIsEmptyDF = df3.join(df4, $"k3" > $"k4", "left_outer")
252+
hasJoinInCodegen = buildSideIsEmptyDF.queryExecution.executedPlan.collect {
253+
case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true
254+
}.size === 1
255+
assert(hasJoinInCodegen == codegenEnabled)
256+
checkAnswer(buildSideIsEmptyDF, Seq(Row(0, null), Row(1, null)))
248257
}
249258
}
250259
}

0 commit comments

Comments
 (0)