
Commit c1e995a

c21 authored and cloud-fan committed
[SPARK-35350][SQL] Add code-gen for left semi sort merge join
### What changes were proposed in this pull request?

As the title says, this PR adds code-gen support for LEFT SEMI sort merge join. The main change is to add a `semiJoin` code path in `SortMergeJoinExec.doProduce()` and to introduce `onlyBufferFirstMatchedRow` in `SortMergeJoinExec.genScanner()`. The latter is for left semi sort merge join without a join condition: for that kind of query we don't need to buffer all matched rows, only the first one (the same as the non-code-gen code path).

Example query:

```
val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(4).select($"id".as("k2"))
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi")
```

Example of generated code for the query:

```
== Subtree 5 / 5 (maxMethodCodeSize:302; maxConstantPoolSize:156(0.24% used); numInnerClasses:0) ==
*(5) Project [id#0L AS k1#2L]
+- *(5) SortMergeJoin [id#0L], [k2#6L], LeftSemi
   :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#0L, 5), ENSURE_REQUIREMENTS, [id=#27]
   :     +- *(1) Range (0, 10, step=1, splits=2)
   +- *(4) Sort [k2#6L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(k2#6L, 5), ENSURE_REQUIREMENTS, [id=#33]
         +- *(3) Project [id#4L AS k2#6L]
            +- *(3) Range (0, 4, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage5(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=5
/* 006 */ final class GeneratedIteratorForCodegenStage5 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator smj_streamedInput_0;
/* 010 */   private scala.collection.Iterator smj_bufferedInput_0;
/* 011 */   private InternalRow smj_streamedRow_0;
/* 012 */   private InternalRow smj_bufferedRow_0;
/* 013 */   private long smj_value_2;
/* 014 */   private org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray smj_matches_0;
/* 015 */   private long smj_value_3;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] smj_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage5(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */     smj_streamedInput_0 = inputs[0];
/* 026 */     smj_bufferedInput_0 = inputs[1];
/* 027 */
/* 028 */     smj_matches_0 = new org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray(1, 2147483647);
/* 029 */     smj_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     smj_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */
/* 032 */   }
/* 033 */
/* 034 */   private boolean findNextJoinRows(
/* 035 */     scala.collection.Iterator streamedIter,
/* 036 */     scala.collection.Iterator bufferedIter) {
/* 037 */     smj_streamedRow_0 = null;
/* 038 */     int comp = 0;
/* 039 */     while (smj_streamedRow_0 == null) {
/* 040 */       if (!streamedIter.hasNext()) return false;
/* 041 */       smj_streamedRow_0 = (InternalRow) streamedIter.next();
/* 042 */       long smj_value_0 = smj_streamedRow_0.getLong(0);
/* 043 */       if (false) {
/* 044 */         smj_streamedRow_0 = null;
/* 045 */         continue;
/* 046 */
/* 047 */       }
/* 048 */       if (!smj_matches_0.isEmpty()) {
/* 049 */         comp = 0;
/* 050 */         if (comp == 0) {
/* 051 */           comp = (smj_value_0 > smj_value_3 ? 1 : smj_value_0 < smj_value_3 ? -1 : 0);
/* 052 */         }
/* 053 */
/* 054 */         if (comp == 0) {
/* 055 */           return true;
/* 056 */         }
/* 057 */         smj_matches_0.clear();
/* 058 */       }
/* 059 */
/* 060 */       do {
/* 061 */         if (smj_bufferedRow_0 == null) {
/* 062 */           if (!bufferedIter.hasNext()) {
/* 063 */             smj_value_3 = smj_value_0;
/* 064 */             return !smj_matches_0.isEmpty();
/* 065 */           }
/* 066 */           smj_bufferedRow_0 = (InternalRow) bufferedIter.next();
/* 067 */           long smj_value_1 = smj_bufferedRow_0.getLong(0);
/* 068 */           if (false) {
/* 069 */             smj_bufferedRow_0 = null;
/* 070 */             continue;
/* 071 */           }
/* 072 */           smj_value_2 = smj_value_1;
/* 073 */         }
/* 074 */
/* 075 */         comp = 0;
/* 076 */         if (comp == 0) {
/* 077 */           comp = (smj_value_0 > smj_value_2 ? 1 : smj_value_0 < smj_value_2 ? -1 : 0);
/* 078 */         }
/* 079 */
/* 080 */         if (comp > 0) {
/* 081 */           smj_bufferedRow_0 = null;
/* 082 */         } else if (comp < 0) {
/* 083 */           if (!smj_matches_0.isEmpty()) {
/* 084 */             smj_value_3 = smj_value_0;
/* 085 */             return true;
/* 086 */           } else {
/* 087 */             smj_streamedRow_0 = null;
/* 088 */           }
/* 089 */         } else {
/* 090 */           if (smj_matches_0.isEmpty()) {
/* 091 */             smj_matches_0.add((UnsafeRow) smj_bufferedRow_0);
/* 092 */           }
/* 093 */
/* 094 */           smj_bufferedRow_0 = null;
/* 095 */         }
/* 096 */       } while (smj_streamedRow_0 != null);
/* 097 */     }
/* 098 */     return false; // unreachable
/* 099 */   }
/* 100 */
/* 101 */   protected void processNext() throws java.io.IOException {
/* 102 */     while (findNextJoinRows(smj_streamedInput_0, smj_bufferedInput_0)) {
/* 103 */       long smj_value_4 = -1L;
/* 104 */       smj_value_4 = smj_streamedRow_0.getLong(0);
/* 105 */       scala.collection.Iterator<UnsafeRow> smj_iterator_0 = smj_matches_0.generateIterator();
/* 106 */       boolean smj_hasOutputRow_0 = false;
/* 107 */
/* 108 */       while (!smj_hasOutputRow_0 && smj_iterator_0.hasNext()) {
/* 109 */         InternalRow smj_bufferedRow_1 = (InternalRow) smj_iterator_0.next();
/* 110 */
/* 111 */         smj_hasOutputRow_0 = true;
/* 112 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 113 */
/* 114 */         // common sub-expressions
/* 115 */
/* 116 */         smj_mutableStateArray_0[1].reset();
/* 117 */
/* 118 */         smj_mutableStateArray_0[1].write(0, smj_value_4);
/* 119 */         append((smj_mutableStateArray_0[1].getRow()).copy());
/* 120 */
/* 121 */       }
/* 122 */       if (shouldStop()) return;
/* 123 */     }
/* 124 */     ((org.apache.spark.sql.execution.joins.SortMergeJoinExec) references[1] /* plan */).cleanupResources();
/* 125 */   }
/* 126 */
/* 127 */ }
```

### Why are the changes needed?

Improve query CPU performance. Tested with one benchmark query:

```
def sortMergeJoin(): Unit = {
  val N = 2 << 20
  codegenBenchmark("left semi sort merge join", N) {
    val df1 = spark.range(N).selectExpr(s"id * 2 as k1")
    val df2 = spark.range(N).selectExpr(s"id * 3 as k2")
    val df = df1.join(df2, col("k1") === col("k2"), "left_semi")
    assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
    df.noop()
  }
}
```

This shows roughly a 30% run-time improvement:

```
Running benchmark: left semi sort merge join
  Running case: left semi sort merge join code-gen off
  Stopped after 2 iterations, 1369 ms
  Running case: left semi sort merge join code-gen on
  Stopped after 5 iterations, 2743 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16
Intel(R) Core(TM) i9-9980HK CPU 2.40GHz
left semi sort merge join:                Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
left semi sort merge join code-gen off              676            685          13         3.1         322.2       1.0X
left semi sort merge join code-gen on               524            549          32         4.0         249.7       1.3X
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit tests in `WholeStageCodegenSuite.scala` and `ExistenceJoinSuite.scala`.

Closes #32528 from c21/smj-left-semi.

Authored-by: Cheng Su <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
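For readers skimming the change, here is a plain-Scala sketch (not Spark code, just an illustration assuming two already-sorted key sequences; `leftSemiSortMerge` is a hypothetical helper) of the merge logic the left-semi path implements. It shows why, when there is no extra join condition, remembering a single matched buffered row per streamed key is enough, which is what `onlyBufferFirstMatchedRow` exploits:

```
// Standalone illustration: left semi join over two sorted key sequences.
// For each streamed key we only need to know whether *some* buffered key
// matches, so the first match is sufficient -- there is no need to buffer
// every matching buffered row.
def leftSemiSortMerge(streamed: Seq[Long], buffered: Seq[Long]): Seq[Long] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[Long]
  var i = 0
  var j = 0
  while (i < streamed.length && j < buffered.length) {
    val cmp = java.lang.Long.compare(streamed(i), buffered(j))
    if (cmp < 0) i += 1                   // streamed key has no match: skip it
    else if (cmp > 0) j += 1              // buffered key too small: advance buffered side
    else { out += streamed(i); i += 1 }   // first match found: emit the streamed row once
  }
  out.toSeq
}

// leftSemiSortMerge(Seq(0L, 2L, 4L, 6L), Seq(0L, 3L, 6L, 9L)) == Seq(0L, 6L)
```

When a join condition is present, several buffered rows with the same key may need to be re-checked against the condition, which is why the patch only caps the match buffer at one row for `LeftExistence` joins with `condition == None`.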
1 parent 5181543 commit c1e995a

File tree

47 files changed (+1797 / -1587 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 150 additions & 62 deletions
```diff
@@ -105,8 +105,18 @@ case class SortMergeJoinExec(
     sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
   }
 
+  // Flag to only buffer first matched row, to avoid buffering unnecessary rows.
+  private val onlyBufferFirstMatchedRow = (joinType, condition) match {
+    case (LeftExistence(_), None) => true
+    case _ => false
+  }
+
   private def getInMemoryThreshold: Int = {
-    sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
+    if (onlyBufferFirstMatchedRow) {
+      1
+    } else {
+      sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
+    }
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
@@ -236,7 +246,7 @@ case class SortMergeJoinExec(
               inMemoryThreshold,
               spillThreshold,
               cleanupResources,
-              condition.isEmpty
+              onlyBufferFirstMatchedRow
             )
             private[this] val joinRow = new JoinedRow
 
@@ -273,7 +283,7 @@ case class SortMergeJoinExec(
               inMemoryThreshold,
               spillThreshold,
               cleanupResources,
-              condition.isEmpty
+              onlyBufferFirstMatchedRow
             )
             private[this] val joinRow = new JoinedRow
 
@@ -317,7 +327,7 @@ case class SortMergeJoinExec(
               inMemoryThreshold,
               spillThreshold,
               cleanupResources,
-              condition.isEmpty
+              onlyBufferFirstMatchedRow
             )
             private[this] val joinRow = new JoinedRow
 
@@ -354,7 +364,7 @@ case class SortMergeJoinExec(
   }
 
   private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
-    case _: InnerLike | LeftOuter => ((left, leftKeys), (right, rightKeys))
+    case _: InnerLike | LeftOuter | LeftSemi => ((left, leftKeys), (right, rightKeys))
     case RightOuter => ((right, rightKeys), (left, leftKeys))
     case x =>
       throw new IllegalArgumentException(
@@ -365,7 +375,7 @@ case class SortMergeJoinExec(
   private lazy val bufferedOutput = bufferedPlan.output
 
   override def supportCodegen: Boolean = joinType match {
-    case _: InnerLike | LeftOuter | RightOuter => true
+    case _: InnerLike | LeftOuter | RightOuter | LeftSemi => true
     case _ => false
   }
 
@@ -435,7 +445,7 @@ case class SortMergeJoinExec(
 
     // Handle the case when streamed rows has any NULL keys.
     val handleStreamedAnyNull = joinType match {
-      case _: InnerLike =>
+      case _: InnerLike | LeftSemi =>
        // Skip streamed row.
        s"""
          |$streamedRow = null;
@@ -457,7 +467,7 @@ case class SortMergeJoinExec(
 
    // Handle the case when streamed keys has no match with buffered side.
    val handleStreamedWithoutMatch = joinType match {
-      case _: InnerLike =>
+      case _: InnerLike | LeftSemi =>
        // Skip streamed row.
        s"$streamedRow = null;"
      case LeftOuter | RightOuter =>
@@ -468,6 +478,17 @@ case class SortMergeJoinExec(
          s"SortMergeJoin.genScanner should not take $x as the JoinType")
    }
 
+    val addRowToBuffer =
+      if (onlyBufferFirstMatchedRow) {
+        s"""
+           |if ($matches.isEmpty()) {
+           |  $matches.add((UnsafeRow) $bufferedRow);
+           |}
+         """.stripMargin
+      } else {
+        s"$matches.add((UnsafeRow) $bufferedRow);"
+      }
+
    // Generate a function to scan both streamed and buffered sides to find a match.
    // Return whether a match is found.
    //
@@ -483,17 +504,18 @@ case class SortMergeJoinExec(
    // The function has the following step:
    //  - Step 1: Find the next `streamedRow` with non-null join keys.
    //            For `streamedRow` with null join keys (`handleStreamedAnyNull`):
-    //            1. Inner join: skip the row. `matches` will be cleared later when hitting the
-    //               next `streamedRow` with non-null join keys.
+    //            1. Inner and Left Semi join: skip the row. `matches` will be cleared later when
+    //               hitting the next `streamedRow` with non-null join
+    //               keys.
    //            2. Left/Right Outer join: clear the previous `matches` if needed, keep the row,
    //               and return false.
    //
    //  - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`.
    //            Clear `matches` if we hit a new `streamedRow`, as we need to find new matches.
    //            Use `bufferedRow` to iterate buffered side to put all matched rows into
-    //            `matches`. Return true when getting all matched rows.
+    //            `matches` (`addRowToBuffer`). Return true when getting all matched rows.
    //            For `streamedRow` without `matches` (`handleStreamedWithoutMatch`):
-    //            1. Inner join: skip the row.
+    //            1. Inner and Left Semi join: skip the row.
    //            2. Left/Right Outer join: keep the row and return false (with `matches` being
    //               empty).
    ctx.addNewFunction("findNextJoinRows",
@@ -543,7 +565,7 @@ case class SortMergeJoinExec(
          |          $handleStreamedWithoutMatch
          |        }
          |      } else {
-         |        $matches.add((UnsafeRow) $bufferedRow);
+         |        $addRowToBuffer
          |        $bufferedRow = null;
          |      }
          |    } while ($streamedRow != null);
@@ -639,19 +661,22 @@ case class SortMergeJoinExec(
        streamedVars ++ bufferedVars
      case RightOuter =>
        bufferedVars ++ streamedVars
+      case LeftSemi =>
+        streamedVars
      case x =>
        throw new IllegalArgumentException(
          s"SortMergeJoin.doProduce should not take $x as the JoinType")
    }
 
-    val (beforeLoop, condCheck) = if (condition.isDefined) {
+    val (streamedBeforeLoop, condCheck) = if (condition.isDefined) {
      // Split the code of creating variables based on whether it's used by condition or not.
      val loaded = ctx.freshName("loaded")
      val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
      val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
      // Generate code for condition
-      ctx.currentVars = resultVars
-      val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
+      ctx.currentVars = streamedVars ++ bufferedVars
+      val cond = BindReferences.bindReference(
+        condition.get, streamedPlan.output ++ bufferedPlan.output).genCode(ctx)
      // evaluate the columns those used by condition before loop
      val before =
        s"""
@@ -674,65 +699,129 @@ case class SortMergeJoinExec(
         |}
         |$bufferedAfter
        """.stripMargin
-      (before, checking)
+      (before, checking.trim)
    } else {
      (evaluateVariables(streamedVars), "")
    }
 
-    val thisPlan = ctx.addReferenceObj("plan", this)
-    val eagerCleanup = s"$thisPlan.cleanupResources();"
-
-    lazy val innerJoin =
+    val beforeLoop =
      s"""
-         |while (findNextJoinRows($streamedInput, $bufferedInput)) {
-         |  ${streamedVarDecl.mkString("\n")}
-         |  ${beforeLoop.trim}
-         |  scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
-         |  while ($iterator.hasNext()) {
-         |    InternalRow $bufferedRow = (InternalRow) $iterator.next();
-         |    ${condCheck.trim}
-         |    $numOutput.add(1);
-         |    ${consume(ctx, resultVars)}
-         |  }
-         |  if (shouldStop()) return;
-         |}
-         |$eagerCleanup
-       """.stripMargin
-
-    lazy val outerJoin = {
-      val hasOutputRow = ctx.freshName("hasOutputRow")
+         |${streamedVarDecl.mkString("\n")}
+         |${streamedBeforeLoop.trim}
+         |scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
+       """.stripMargin
+    val outputRow =
      s"""
-         |while ($streamedInput.hasNext()) {
-         |  findNextJoinRows($streamedInput, $bufferedInput);
-         |  ${streamedVarDecl.mkString("\n")}
-         |  ${beforeLoop.trim}
-         |  scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
-         |  boolean $hasOutputRow = false;
-         |
-         |  // the last iteration of this loop is to emit an empty row if there is no matched rows.
-         |  while ($iterator.hasNext() || !$hasOutputRow) {
-         |    InternalRow $bufferedRow = $iterator.hasNext() ?
-         |      (InternalRow) $iterator.next() : null;
-         |    ${condCheck.trim}
-         |    $hasOutputRow = true;
-         |    $numOutput.add(1);
-         |    ${consume(ctx, resultVars)}
-         |  }
-         |  if (shouldStop()) return;
-         |}
-         |$eagerCleanup
+         |$numOutput.add(1);
+         |${consume(ctx, resultVars)}
       """.stripMargin
-    }
+    val findNextJoinRows = s"findNextJoinRows($streamedInput, $bufferedInput)"
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    val eagerCleanup = s"$thisPlan.cleanupResources();"
 
    joinType match {
-      case _: InnerLike => innerJoin
-      case LeftOuter | RightOuter => outerJoin
+      case _: InnerLike =>
+        codegenInner(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, outputRow,
+          eagerCleanup)
+      case LeftOuter | RightOuter =>
+        codegenOuter(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
+          ctx.freshName("hasOutputRow"), outputRow, eagerCleanup)
+      case LeftSemi =>
+        codegenSemi(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
+          ctx.freshName("hasOutputRow"), outputRow, eagerCleanup)
      case x =>
        throw new IllegalArgumentException(
          s"SortMergeJoin.doProduce should not take $x as the JoinType")
    }
  }
 
+  /**
+   * Generates the code for Inner join.
+   */
+  private def codegenInner(
+      findNextJoinRows: String,
+      beforeLoop: String,
+      matchIterator: String,
+      bufferedRow: String,
+      conditionCheck: String,
+      outputRow: String,
+      eagerCleanup: String): String = {
+    s"""
+       |while ($findNextJoinRows) {
+       |  $beforeLoop
+       |  while ($matchIterator.hasNext()) {
+       |    InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
+       |    $conditionCheck
+       |    $outputRow
+       |  }
+       |  if (shouldStop()) return;
+       |}
+       |$eagerCleanup
+     """.stripMargin
+  }
+
+  /**
+   * Generates the code for Left or Right Outer join.
+   */
+  private def codegenOuter(
+      streamedInput: String,
+      findNextJoinRows: String,
+      beforeLoop: String,
+      matchIterator: String,
+      bufferedRow: String,
+      conditionCheck: String,
+      hasOutputRow: String,
+      outputRow: String,
+      eagerCleanup: String): String = {
+    s"""
+       |while ($streamedInput.hasNext()) {
+       |  $findNextJoinRows;
+       |  $beforeLoop
+       |  boolean $hasOutputRow = false;
+       |
+       |  // the last iteration of this loop is to emit an empty row if there is no matched rows.
+       |  while ($matchIterator.hasNext() || !$hasOutputRow) {
+       |    InternalRow $bufferedRow = $matchIterator.hasNext() ?
+       |      (InternalRow) $matchIterator.next() : null;
+       |    $conditionCheck
+       |    $hasOutputRow = true;
+       |    $outputRow
+       |  }
+       |  if (shouldStop()) return;
+       |}
+       |$eagerCleanup
+     """.stripMargin
+  }
+
+  /**
+   * Generates the code for Left Semi join.
+   */
+  private def codegenSemi(
+      findNextJoinRows: String,
+      beforeLoop: String,
+      matchIterator: String,
+      bufferedRow: String,
+      conditionCheck: String,
+      hasOutputRow: String,
+      outputRow: String,
+      eagerCleanup: String): String = {
+    s"""
+       |while ($findNextJoinRows) {
+       |  $beforeLoop
+       |  boolean $hasOutputRow = false;
+       |
+       |  while (!$hasOutputRow && $matchIterator.hasNext()) {
+       |    InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
+       |    $conditionCheck
+       |    $hasOutputRow = true;
+       |    $outputRow
+       |  }
+       |  if (shouldStop()) return;
+       |}
+       |$eagerCleanup
+     """.stripMargin
+  }
+
  override protected def withNewChildrenInternal(
      newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec =
    copy(left = newLeft, right = newRight)
@@ -783,8 +872,7 @@ private[joins] class SortMergeJoinScanner(
  private[this] var matchJoinKey: InternalRow = _
  /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
  private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray =
-    new ExternalAppendOnlyUnsafeRowArray(if (onlyBufferFirstMatch) 1 else inMemoryThreshold,
-      spillThreshold)
+    new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
 
  // Initialization (note: do _not_ want to advance streamed here).
  advancedBufferedToRowWithNullFreeJoinKey()
```
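As a quick end-to-end check of the change above, a spark-shell snippet along these lines (a sketch only; it assumes a Spark build that includes this commit, and plan shapes can differ when adaptive execution is enabled) can confirm that the left-semi sort merge join now lands inside a whole-stage-codegen subtree, with `debugCodegen()` printing code similar to the listing in the commit message:

```
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.execution.debug._   // adds debugCodegen() on Dataset
import spark.implicits._

// Keep the physical plan static so codegen nodes are visible before execution.
spark.conf.set("spark.sql.adaptive.enabled", "false")

val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(4).select($"id".as("k2"))
val joined = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi")

// The SortMergeJoin should now be wrapped by a WholeStageCodegen node.
val codegendSmj = joined.queryExecution.executedPlan.collect {
  case w: WholeStageCodegenExec if w.collectFirst { case s: SortMergeJoinExec => s }.isDefined => w
}
assert(codegendSmj.nonEmpty)

// Print the generated Java code for each codegen subtree.
joined.debugCodegen()
```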
