@@ -105,8 +105,18 @@ case class SortMergeJoinExec(
105
105
sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
106
106
}
107
107
108
+ // Flag to only buffer the first matched row, to avoid buffering unnecessary rows. Set only when the join type matches `LeftExistence` and there is no join condition.
109
+ private val onlyBufferFirstMatchedRow = (joinType, condition) match {
110
+ case (LeftExistence (_), None ) => true // join type matches `LeftExistence` and no join condition is present
111
+ case _ => false // any other join type, or a join condition is present
112
+ }
113
+
108
114
private def getInMemoryThreshold : Int = {
109
- sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
115
+ if (onlyBufferFirstMatchedRow) { // first-match-only mode: a single buffered row is all that is needed
116
+ 1
117
+ } else { // otherwise fall back to the configured in-memory threshold
118
+ sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
119
+ }
110
120
}
111
121
112
122
protected override def doExecute (): RDD [InternalRow ] = {
@@ -236,7 +246,7 @@ case class SortMergeJoinExec(
236
246
inMemoryThreshold,
237
247
spillThreshold,
238
248
cleanupResources,
239
- condition.isEmpty
249
+ onlyBufferFirstMatchedRow
240
250
)
241
251
private [this ] val joinRow = new JoinedRow
242
252
@@ -273,7 +283,7 @@ case class SortMergeJoinExec(
273
283
inMemoryThreshold,
274
284
spillThreshold,
275
285
cleanupResources,
276
- condition.isEmpty
286
+ onlyBufferFirstMatchedRow
277
287
)
278
288
private [this ] val joinRow = new JoinedRow
279
289
@@ -317,7 +327,7 @@ case class SortMergeJoinExec(
317
327
inMemoryThreshold,
318
328
spillThreshold,
319
329
cleanupResources,
320
- condition.isEmpty
330
+ onlyBufferFirstMatchedRow
321
331
)
322
332
private [this ] val joinRow = new JoinedRow
323
333
@@ -354,7 +364,7 @@ case class SortMergeJoinExec(
354
364
}
355
365
356
366
private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
357
- case _ : InnerLike | LeftOuter => ((left, leftKeys), (right, rightKeys))
367
+ case _ : InnerLike | LeftOuter | LeftSemi => ((left, leftKeys), (right, rightKeys))
358
368
case RightOuter => ((right, rightKeys), (left, leftKeys))
359
369
case x =>
360
370
throw new IllegalArgumentException (
@@ -365,7 +375,7 @@ case class SortMergeJoinExec(
365
375
private lazy val bufferedOutput = bufferedPlan.output
366
376
367
377
override def supportCodegen : Boolean = joinType match {
368
- case _ : InnerLike | LeftOuter | RightOuter => true
378
+ case _ : InnerLike | LeftOuter | RightOuter | LeftSemi => true
369
379
case _ => false
370
380
}
371
381
@@ -435,7 +445,7 @@ case class SortMergeJoinExec(
435
445
436
446
// Handle the case when the streamed row has any NULL keys.
437
447
val handleStreamedAnyNull = joinType match {
438
- case _ : InnerLike =>
448
+ case _ : InnerLike | LeftSemi =>
439
449
// Skip streamed row.
440
450
s """
441
451
| $streamedRow = null;
@@ -457,7 +467,7 @@ case class SortMergeJoinExec(
457
467
458
468
// Handle the case when the streamed keys have no match on the buffered side.
459
469
val handleStreamedWithoutMatch = joinType match {
460
- case _ : InnerLike =>
470
+ case _ : InnerLike | LeftSemi =>
461
471
// Skip streamed row.
462
472
s " $streamedRow = null; "
463
473
case LeftOuter | RightOuter =>
@@ -468,6 +478,17 @@ case class SortMergeJoinExec(
468
478
s " SortMergeJoin.genScanner should not take $x as the JoinType " )
469
479
}
470
480
481
+ val addRowToBuffer =
482
+ if (onlyBufferFirstMatchedRow) {
483
+ s """
484
+ |if ( $matches.isEmpty()) {
485
+ | $matches.add((UnsafeRow) $bufferedRow);
486
+ |}
487
+ """ .stripMargin
488
+ } else {
489
+ s " $matches.add((UnsafeRow) $bufferedRow); "
490
+ }
491
+
471
492
// Generate a function to scan both streamed and buffered sides to find a match.
472
493
// Return whether a match is found.
473
494
//
@@ -483,17 +504,18 @@ case class SortMergeJoinExec(
483
504
// The function has the following step:
484
505
// - Step 1: Find the next `streamedRow` with non-null join keys.
485
506
// For `streamedRow` with null join keys (`handleStreamedAnyNull`):
486
- // 1. Inner join: skip the row. `matches` will be cleared later when hitting the
487
- // next `streamedRow` with non-null join keys.
507
+ // 1. Inner and Left Semi join: skip the row. `matches` will be cleared later when
508
+ // hitting the next `streamedRow` with non-null join
509
+ // keys.
488
510
// 2. Left/Right Outer join: clear the previous `matches` if needed, keep the row,
489
511
// and return false.
490
512
//
491
513
// - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`.
492
514
// Clear `matches` if we hit a new `streamedRow`, as we need to find new matches.
493
515
// Use `bufferedRow` to iterate buffered side to put all matched rows into
494
- // `matches`. Return true when getting all matched rows.
516
+ // `matches` (`addRowToBuffer`) . Return true when getting all matched rows.
495
517
// For `streamedRow` without `matches` (`handleStreamedWithoutMatch`):
496
- // 1. Inner join: skip the row.
518
+ // 1. Inner and Left Semi join: skip the row.
497
519
// 2. Left/Right Outer join: keep the row and return false (with `matches` being
498
520
// empty).
499
521
ctx.addNewFunction(" findNextJoinRows" ,
@@ -543,7 +565,7 @@ case class SortMergeJoinExec(
543
565
| $handleStreamedWithoutMatch
544
566
| }
545
567
| } else {
546
- | $matches .add((UnsafeRow) $bufferedRow );
568
+ | $addRowToBuffer
547
569
| $bufferedRow = null;
548
570
| }
549
571
| } while ( $streamedRow != null);
@@ -639,19 +661,22 @@ case class SortMergeJoinExec(
639
661
streamedVars ++ bufferedVars
640
662
case RightOuter =>
641
663
bufferedVars ++ streamedVars
664
+ case LeftSemi =>
665
+ streamedVars
642
666
case x =>
643
667
throw new IllegalArgumentException (
644
668
s " SortMergeJoin.doProduce should not take $x as the JoinType " )
645
669
}
646
670
647
- val (beforeLoop , condCheck) = if (condition.isDefined) {
671
+ val (streamedBeforeLoop , condCheck) = if (condition.isDefined) {
648
672
// Split the code of creating variables based on whether it's used by condition or not.
649
673
val loaded = ctx.freshName(" loaded" )
650
674
val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
651
675
val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
652
676
// Generate code for condition
653
- ctx.currentVars = resultVars
654
- val cond = BindReferences .bindReference(condition.get, output).genCode(ctx)
677
+ ctx.currentVars = streamedVars ++ bufferedVars
678
+ val cond = BindReferences .bindReference(
679
+ condition.get, streamedPlan.output ++ bufferedPlan.output).genCode(ctx)
655
680
// evaluate the columns those used by condition before loop
656
681
val before =
657
682
s """
@@ -674,65 +699,129 @@ case class SortMergeJoinExec(
674
699
|}
675
700
| $bufferedAfter
676
701
""" .stripMargin
677
- (before, checking)
702
+ (before, checking.trim )
678
703
} else {
679
704
(evaluateVariables(streamedVars), " " )
680
705
}
681
706
682
- val thisPlan = ctx.addReferenceObj(" plan" , this )
683
- val eagerCleanup = s " $thisPlan.cleanupResources(); "
684
-
685
- lazy val innerJoin =
707
+ val beforeLoop =
686
708
s """
687
- |while (findNextJoinRows( $streamedInput, $bufferedInput)) {
688
- | ${streamedVarDecl.mkString(" \n " )}
689
- | ${beforeLoop.trim}
690
- | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
691
- | while ( $iterator.hasNext()) {
692
- | InternalRow $bufferedRow = (InternalRow) $iterator.next();
693
- | ${condCheck.trim}
694
- | $numOutput.add(1);
695
- | ${consume(ctx, resultVars)}
696
- | }
697
- | if (shouldStop()) return;
698
- |}
699
- | $eagerCleanup
700
- """ .stripMargin
701
-
702
- lazy val outerJoin = {
703
- val hasOutputRow = ctx.freshName(" hasOutputRow" )
709
+ | ${streamedVarDecl.mkString(" \n " )}
710
+ | ${streamedBeforeLoop.trim}
711
+ |scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
712
+ """ .stripMargin
713
+ val outputRow =
704
714
s """
705
- |while ( $streamedInput.hasNext()) {
706
- | findNextJoinRows( $streamedInput, $bufferedInput);
707
- | ${streamedVarDecl.mkString(" \n " )}
708
- | ${beforeLoop.trim}
709
- | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
710
- | boolean $hasOutputRow = false;
711
- |
712
- | // the last iteration of this loop is to emit an empty row if there is no matched rows.
713
- | while ( $iterator.hasNext() || ! $hasOutputRow) {
714
- | InternalRow $bufferedRow = $iterator.hasNext() ?
715
- | (InternalRow) $iterator.next() : null;
716
- | ${condCheck.trim}
717
- | $hasOutputRow = true;
718
- | $numOutput.add(1);
719
- | ${consume(ctx, resultVars)}
720
- | }
721
- | if (shouldStop()) return;
722
- |}
723
- | $eagerCleanup
715
+ | $numOutput.add(1);
716
+ | ${consume(ctx, resultVars)}
724
717
""" .stripMargin
725
- }
718
+ val findNextJoinRows = s " findNextJoinRows( $streamedInput, $bufferedInput) "
719
+ val thisPlan = ctx.addReferenceObj(" plan" , this )
720
+ val eagerCleanup = s " $thisPlan.cleanupResources(); "
726
721
727
722
joinType match {
728
- case _ : InnerLike => innerJoin
729
- case LeftOuter | RightOuter => outerJoin
723
+ case _ : InnerLike =>
724
+ codegenInner(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, outputRow,
725
+ eagerCleanup)
726
+ case LeftOuter | RightOuter =>
727
+ codegenOuter(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
728
+ ctx.freshName(" hasOutputRow" ), outputRow, eagerCleanup)
729
+ case LeftSemi =>
730
+ codegenSemi(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
731
+ ctx.freshName(" hasOutputRow" ), outputRow, eagerCleanup)
730
732
case x =>
731
733
throw new IllegalArgumentException (
732
734
s " SortMergeJoin.doProduce should not take $x as the JoinType " )
733
735
}
734
736
}
735
737
738
+ /**
739
+ * Generates the code for Inner join.
+ *
+ * The emitted code is an outer loop that keeps running while `findNextJoinRows` evaluates to
+ * true. Each outer iteration first runs `beforeLoop`, then walks `matchIterator`; for every
+ * buffered match it declares `bufferedRow`, runs `conditionCheck` and then `outputRow`.
+ * After the inner loop it returns early if `shouldStop()` is true. `eagerCleanup` is emitted
+ * once, after the outer loop.
+ *
+ * @param findNextJoinRows expression text used as the outer loop condition
+ * @param beforeLoop code emitted at the start of every outer iteration; the caller in this
+ *                   file builds it to declare the streamed variables and the match iterator
+ * @param matchIterator name of the iterator variable over the buffered matches
+ * @param bufferedRow name of the local variable that receives each buffered match
+ * @param conditionCheck code emitted before `outputRow` for every match; the caller passes
+ *                       blank text when there is no join condition
+ * @param outputRow code emitted to output one row
+ * @param eagerCleanup statement emitted after the outer loop finishes
+ * @return the generated code as a string
740
+ */
741
+ private def codegenInner (
742
+ findNextJoinRows : String ,
743
+ beforeLoop : String ,
744
+ matchIterator : String ,
745
+ bufferedRow : String ,
746
+ conditionCheck : String ,
747
+ outputRow : String ,
748
+ eagerCleanup : String ): String = {
749
+ s """
750
+ |while ( $findNextJoinRows) {
751
+ | $beforeLoop
752
+ | while ( $matchIterator.hasNext()) {
753
+ | InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
754
+ | $conditionCheck
755
+ | $outputRow
756
+ | }
757
+ | if (shouldStop()) return;
758
+ |}
759
+ | $eagerCleanup
760
+ """ .stripMargin
761
+ }
762
+
763
+ /**
764
+ * Generates the code for Left or Right Outer join.
+ *
+ * The emitted code loops while `streamedInput.hasNext()` is true. Each outer iteration calls
+ * `findNextJoinRows` as a plain statement (its result is not used), runs `beforeLoop`, and
+ * resets the `hasOutputRow` flag to false. The inner loop keeps going while the match
+ * iterator has rows OR nothing has been output yet, so it runs at least once per outer
+ * iteration; when the iterator is exhausted `bufferedRow` is set to null. `hasOutputRow` is
+ * set to true right before `outputRow`. After the inner loop the code returns early if
+ * `shouldStop()` is true, and `eagerCleanup` is emitted once after the outer loop.
+ *
+ * @param streamedInput name of the streamed-side input whose `hasNext()` drives the outer loop
+ * @param findNextJoinRows call text emitted as a statement at the top of each outer iteration
+ * @param beforeLoop code emitted after `findNextJoinRows` and before the inner loop
+ * @param matchIterator name of the iterator variable over the buffered matches
+ * @param bufferedRow name of the local variable holding the current buffered match, or null
+ * @param conditionCheck code emitted before `hasOutputRow` is set for every inner iteration
+ * @param hasOutputRow name of the boolean local tracking whether a row was already output
+ * @param outputRow code emitted to output one row
+ * @param eagerCleanup statement emitted after the outer loop finishes
+ * @return the generated code as a string
765
+ */
766
+ private def codegenOuter (
767
+ streamedInput : String ,
768
+ findNextJoinRows : String ,
769
+ beforeLoop : String ,
770
+ matchIterator : String ,
771
+ bufferedRow : String ,
772
+ conditionCheck : String ,
773
+ hasOutputRow : String ,
774
+ outputRow : String ,
775
+ eagerCleanup : String ): String = {
776
+ s """
777
+ |while ( $streamedInput.hasNext()) {
778
+ | $findNextJoinRows;
779
+ | $beforeLoop
780
+ | boolean $hasOutputRow = false;
781
+ |
782
+ | // the last iteration of this loop is to emit an empty row if there is no matched rows.
783
+ | while ( $matchIterator.hasNext() || ! $hasOutputRow) {
784
+ | InternalRow $bufferedRow = $matchIterator.hasNext() ?
785
+ | (InternalRow) $matchIterator.next() : null;
786
+ | $conditionCheck
787
+ | $hasOutputRow = true;
788
+ | $outputRow
789
+ | }
790
+ | if (shouldStop()) return;
791
+ |}
792
+ | $eagerCleanup
793
+ """ .stripMargin
794
+ }
795
+
796
+ /**
797
+ * Generates the code for Left Semi join.
+ *
+ * The emitted code is an outer loop that keeps running while `findNextJoinRows` evaluates to
+ * true. Each outer iteration runs `beforeLoop` and resets the `hasOutputRow` flag to false.
+ * The inner loop walks `matchIterator` only while `hasOutputRow` is still false, so it exits
+ * after the first match that reaches `outputRow`. `conditionCheck` runs before the flag is
+ * set (presumably skipping matches that fail the join condition; its text is built by the
+ * caller, so confirm there). After the inner loop the code returns early if `shouldStop()` is
+ * true, and `eagerCleanup` is emitted once after the outer loop.
+ *
+ * @param findNextJoinRows expression text used as the outer loop condition
+ * @param beforeLoop code emitted at the start of every outer iteration
+ * @param matchIterator name of the iterator variable over the buffered matches
+ * @param bufferedRow name of the local variable that receives each buffered match
+ * @param conditionCheck code emitted before `hasOutputRow` is set for every match
+ * @param hasOutputRow name of the boolean local that ends the inner loop once a row is output
+ * @param outputRow code emitted to output one row
+ * @param eagerCleanup statement emitted after the outer loop finishes
+ * @return the generated code as a string
798
+ */
799
+ private def codegenSemi (
800
+ findNextJoinRows : String ,
801
+ beforeLoop : String ,
802
+ matchIterator : String ,
803
+ bufferedRow : String ,
804
+ conditionCheck : String ,
805
+ hasOutputRow : String ,
806
+ outputRow : String ,
807
+ eagerCleanup : String ): String = {
808
+ s """
809
+ |while ( $findNextJoinRows) {
810
+ | $beforeLoop
811
+ | boolean $hasOutputRow = false;
812
+ |
813
+ | while (! $hasOutputRow && $matchIterator.hasNext()) {
814
+ | InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
815
+ | $conditionCheck
816
+ | $hasOutputRow = true;
817
+ | $outputRow
818
+ | }
819
+ | if (shouldStop()) return;
820
+ |}
821
+ | $eagerCleanup
822
+ """ .stripMargin
823
+ }
824
+
736
825
override protected def withNewChildrenInternal (
737
826
newLeft : SparkPlan , newRight : SparkPlan ): SortMergeJoinExec =
738
827
copy(left = newLeft, right = newRight)
@@ -783,8 +872,7 @@ private[joins] class SortMergeJoinScanner(
783
872
private [this ] var matchJoinKey : InternalRow = _
784
873
/** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
785
874
private [this ] val bufferedMatches : ExternalAppendOnlyUnsafeRowArray =
786
- new ExternalAppendOnlyUnsafeRowArray (if (onlyBufferFirstMatch) 1 else inMemoryThreshold,
787
- spillThreshold)
875
+ new ExternalAppendOnlyUnsafeRowArray (inMemoryThreshold, spillThreshold)
788
876
789
877
// Initialization (note: do _not_ want to advance streamed here).
790
878
advancedBufferedToRowWithNullFreeJoinKey()
0 commit comments