-
Notifications
You must be signed in to change notification settings - Fork 28.7k
[SPARK-35351][SQL] Add code-gen for left anti sort merge join #32547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -364,7 +364,7 @@ case class SortMergeJoinExec( | |
} | ||
|
||
private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match { | ||
case _: InnerLike | LeftOuter | LeftSemi => ((left, leftKeys), (right, rightKeys)) | ||
case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => ((left, leftKeys), (right, rightKeys)) | ||
case RightOuter => ((right, rightKeys), (left, leftKeys)) | ||
case x => | ||
throw new IllegalArgumentException( | ||
|
@@ -375,7 +375,7 @@ case class SortMergeJoinExec( | |
private lazy val bufferedOutput = bufferedPlan.output | ||
|
||
override def supportCodegen: Boolean = joinType match { | ||
case _: InnerLike | LeftOuter | RightOuter | LeftSemi => true | ||
case _: InnerLike | LeftOuter | RightOuter | LeftSemi | LeftAnti => true | ||
case _ => false | ||
} | ||
|
||
|
@@ -453,7 +453,7 @@ case class SortMergeJoinExec( | |
|$streamedRow = null; | ||
|continue; | ||
""".stripMargin | ||
case LeftOuter | RightOuter => | ||
case LeftOuter | RightOuter | LeftAnti => | ||
// Eagerly return streamed row. Only call `matches.clear()` when `matches.isEmpty()` is | ||
// false, to reduce unnecessary computation. | ||
s""" | ||
|
@@ -472,7 +472,7 @@ case class SortMergeJoinExec( | |
case _: InnerLike | LeftSemi => | ||
// Skip streamed row. | ||
s"$streamedRow = null;" | ||
case LeftOuter | RightOuter => | ||
case LeftOuter | RightOuter | LeftAnti => | ||
// Eagerly return with streamed row. | ||
"return false;" | ||
case x => | ||
|
@@ -509,17 +509,17 @@ case class SortMergeJoinExec( | |
// 1. Inner and Left Semi join: skip the row. `matches` will be cleared later when | ||
// hitting the next `streamedRow` with non-null join | ||
// keys. | ||
// 2. Left/Right Outer join: clear the previous `matches` if needed, keep the row, | ||
// and return false. | ||
// 2. Left/Right Outer and Left Anti join: clear the previous `matches` if needed, | ||
// keep the row, and return false. | ||
// | ||
// - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`. | ||
// Clear `matches` if we hit a new `streamedRow`, as we need to find new matches. | ||
// Use `bufferedRow` to iterate buffered side to put all matched rows into | ||
// `matches` (`addRowToBuffer`). Return true when getting all matched rows. | ||
// For `streamedRow` without `matches` (`handleStreamedWithoutMatch`): | ||
// 1. Inner and Left Semi join: skip the row. | ||
// 2. Left/Right Outer join: keep the row and return false (with `matches` being | ||
// empty). | ||
// 2. Left/Right Outer and Left Anti join: keep the row and return false (with | ||
// `matches` being empty). | ||
val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows") | ||
ctx.addNewFunction(findNextJoinRowsFuncName, | ||
s""" | ||
|
@@ -664,14 +664,14 @@ case class SortMergeJoinExec( | |
streamedVars ++ bufferedVars | ||
case RightOuter => | ||
bufferedVars ++ streamedVars | ||
case LeftSemi => | ||
case LeftSemi | LeftAnti => | ||
streamedVars | ||
case x => | ||
throw new IllegalArgumentException( | ||
s"SortMergeJoin.doProduce should not take $x as the JoinType") | ||
} | ||
|
||
val (streamedBeforeLoop, condCheck) = if (condition.isDefined) { | ||
val (streamedBeforeLoop, condCheck, loadStreamed) = if (condition.isDefined) { | ||
// Split the code of creating variables based on whether it's used by condition or not. | ||
val loaded = ctx.freshName("loaded") | ||
val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars) | ||
|
@@ -680,13 +680,36 @@ case class SortMergeJoinExec( | |
ctx.currentVars = streamedVars ++ bufferedVars | ||
val cond = BindReferences.bindReference( | ||
condition.get, streamedPlan.output ++ bufferedPlan.output).genCode(ctx) | ||
// evaluate the columns those used by condition before loop | ||
// Evaluate the columns those used by condition before loop | ||
val before = | ||
s""" | ||
|boolean $loaded = false; | ||
|$streamedBefore | ||
""".stripMargin | ||
|
||
val loadStreamed = | ||
s""" | ||
|if (!$loaded) { | ||
| $loaded = true; | ||
| $streamedAfter | ||
|} | ||
""".stripMargin | ||
|
||
val loadStreamedAfterCondition = joinType match { | ||
case LeftAnti => | ||
// No need to evaluate columns not used by condition from streamed side, as for Left Anti | ||
// join, streamed row with match is not outputted. | ||
"" | ||
case _ => loadStreamed | ||
} | ||
|
||
val loadBufferedAfterCondition = joinType match { | ||
case LeftSemi | LeftAnti => | ||
// No need to evaluate columns not used by condition from buffered side | ||
"" | ||
case _ => bufferedAfter | ||
} | ||
|
||
val checking = | ||
s""" | ||
|$bufferedBefore | ||
|
@@ -696,15 +719,12 @@ case class SortMergeJoinExec( | |
| continue; | ||
| } | ||
|} | ||
|if (!$loaded) { | ||
| $loaded = true; | ||
| $streamedAfter | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For left anti,
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cloud-fan - good call for code size. Actually I just figured we don't need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cloud-fan - updated. Also avoid unnecessary code for |
||
|} | ||
|$bufferedAfter | ||
|$loadStreamedAfterCondition | ||
|$loadBufferedAfterCondition | ||
""".stripMargin | ||
(before, checking.trim) | ||
(before, checking.trim, loadStreamed) | ||
} else { | ||
(evaluateVariables(streamedVars), "") | ||
(evaluateVariables(streamedVars), "", "") | ||
} | ||
|
||
val beforeLoop = | ||
|
@@ -732,6 +752,9 @@ case class SortMergeJoinExec( | |
case LeftSemi => | ||
codegenSemi(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, | ||
ctx.freshName("hasOutputRow"), outputRow, eagerCleanup) | ||
case LeftAnti => | ||
codegenAnti(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, | ||
loadStreamed, ctx.freshName("hasMatchedRow"), outputRow, eagerCleanup) | ||
case x => | ||
throw new IllegalArgumentException( | ||
s"SortMergeJoin.doProduce should not take $x as the JoinType") | ||
|
@@ -825,6 +848,44 @@ case class SortMergeJoinExec( | |
""".stripMargin | ||
} | ||
|
||
/** | ||
* Generates the code for Left Anti join. | ||
*/ | ||
private def codegenAnti( | ||
streamedInput: String, | ||
findNextJoinRows: String, | ||
beforeLoop: String, | ||
matchIterator: String, | ||
bufferedRow: String, | ||
conditionCheck: String, | ||
loadStreamed: String, | ||
hasMatchedRow: String, | ||
outputRow: String, | ||
eagerCleanup: String): String = { | ||
s""" | ||
|while ($streamedInput.hasNext()) { | ||
| $findNextJoinRows; | ||
| $beforeLoop | ||
| boolean $hasMatchedRow = false; | ||
| | ||
| while (!$hasMatchedRow && $matchIterator.hasNext()) { | ||
| InternalRow $bufferedRow = (InternalRow) $matchIterator.next(); | ||
| $conditionCheck | ||
| $hasMatchedRow = true; | ||
| } | ||
| | ||
| if (!$hasMatchedRow) { | ||
| // load all values of streamed row, because the values not in join condition are not | ||
| // loaded yet. | ||
| $loadStreamed | ||
| $outputRow | ||
| } | ||
| if (shouldStop()) return; | ||
|} | ||
|$eagerCleanup | ||
""".stripMargin | ||
} | ||
|
||
override protected def withNewChildrenInternal( | ||
newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec = | ||
copy(left = newLeft, right = newRight) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: seems
loaded
is not needed forLeftAnti
case.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
loadStreamed
is not used byLeftAnti
.I think you are referring to
boolean $loaded = false;
inbefore
should not be needed, right?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea, looks like for
LeftAnti
, it doesn't rely onloaded
to dostreamedAfter
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Created #32681 as followup, thanks.