@@ -544,7 +544,7 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
       val expectedAnswerForRightOuter =
         spark
           .range(0, 100)
-          .flatMap(i => Seq.fill(100)(i))
+          .flatMap(i => Seq.fill(100)(i))
           .selectExpr("0 as key", "value")
       checkAnswer(
         rightOuterJoin,
@@ -578,6 +578,153 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
     }
   }

+ test(" adaptive skewed join: left semi/anti join and skewed on right side" ) {
582
+ val spark = defaultSparkSession
583
+ spark.conf.set(SQLConf .ADAPTIVE_EXECUTION_JOIN_ENABLED .key, " false" )
584
+ spark.conf.set(SQLConf .ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED .key, " true" )
585
+ spark.conf.set(SQLConf .ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD .key, 10 )
586
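+    // A partition whose row count exceeds the threshold (10) is treated as skewed; runtime
+    // join conversion is disabled so that only the skewed join handling applies here.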
+    withSparkSession(spark) { spark: SparkSession =>
+      val df1 =
+        spark
+          .range(0, 10, 1, 2)
+          .selectExpr("id % 5 as key1", "id as value1")
+      val df2 =
+        spark
+          .range(0, 1000, 1, numInputPartitions)
+          .selectExpr("id % 1 as key2", "id as value2")
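+      // id % 1 is always 0, so every row of df2 carries key 0 and all 1000 rows land in one
+      // shuffle partition on the right side, far beyond the row-count threshold of 10.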
+
+      val leftSemiJoin =
+        df1.join(df2, col("key1") === col("key2"), "left_semi").select(col("key1"), col("value1"))
+      val leftAntiJoin =
+        df1.join(df2, col("key1") === col("key2"), "left_anti").select(col("key1"), col("value1"))
+
+      // Before execution, each plan contains exactly one SortMergeJoin.
+      val smjBeforeExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
+        case smj: SortMergeJoinExec => smj
+      }
+      assert(smjBeforeExecutionForLeftSemi.length === 1)
+
+      val smjBeforeExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
+        case smj: SortMergeJoinExec => smj
+      }
+      assert(smjBeforeExecutionForLeftAnti.length === 1)
+
+      // Check the answers.
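+      // df1 holds keys 0 to 4 (id % 5) while df2 only holds key 0, so the left semi join
+      // keeps exactly the df1 rows with id % 5 == 0 and the left anti join keeps the rest.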
+      val expectedAnswerForLeftSemi =
+        spark
+          .range(0, 10)
+          .filter(_ % 5 == 0)
+          .selectExpr("id % 5 as key", "id as value")
+      checkAnswer(
+        leftSemiJoin,
+        expectedAnswerForLeftSemi.collect())
+
+      val expectedAnswerForLeftAnti =
+        spark
+          .range(0, 10)
+          .filter(_ % 5 != 0)
+          .selectExpr("id % 5 as key", "id as value")
+      checkAnswer(
+        leftAntiJoin,
+        expectedAnswerForLeftAnti.collect())
+
+      // For the left semi join case: during execution, the SMJ cannot be translated into any
+      // sub-joins because the skewed side is on the right while the join type is left semi
+      // (the two do not correspond: splitting the right side and unioning the sub-join
+      // results would break semi join semantics).
+      val smjAfterExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
+        case smj: SortMergeJoinExec => smj
+      }
+      assert(smjAfterExecutionForLeftSemi.length === 1)
+
+      // For the left anti join case: during execution, the SMJ cannot be translated into any
+      // sub-joins because the skewed side is on the right while the join type is left anti
+      // (the two do not correspond).
+      val smjAfterExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
+        case smj: SortMergeJoinExec => smj
+      }
+      assert(smjAfterExecutionForLeftAnti.length === 1)
+    }
+  }
649
+
+  test("adaptive skewed join: left semi/anti join and skewed on left side") {
+    val spark = defaultSparkSession
+    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
+    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
+    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
+    val MAX_SPLIT = 5
+    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_MAX_SPLITS.key, MAX_SPLIT)
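+    // MAX_SPLIT caps how many pieces a skewed partition is split into, so a skewed SMJ is
+    // expected to expand into MAX_SPLIT sub-joins plus the original join.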
+    withSparkSession(spark) { spark: SparkSession =>
+      val df1 =
+        spark
+          .range(0, 1000, 1, numInputPartitions)
+          .selectExpr("id % 1 as key1", "id as value1")
+      val df2 =
+        spark
+          .range(0, 10, 1, 2)
+          .selectExpr("id % 5 as key2", "id as value2")
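+      // The skew is now on the left: key1 is always 0 (id % 1), so all 1000 rows of df1 fall
+      // into a single shuffle partition, while df2 stays small with keys 0 to 4.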
+
+      val leftSemiJoin =
+        df1.join(df2, col("key1") === col("key2"), "left_semi").select(col("key1"), col("value1"))
+      val leftAntiJoin =
+        df1.join(df2, col("key1") === col("key2"), "left_anti").select(col("key1"), col("value1"))
+
+      // Before execution, each plan contains exactly one SortMergeJoin.
+      val smjBeforeExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
+        case smj: SortMergeJoinExec => smj
+      }
+      assert(smjBeforeExecutionForLeftSemi.length === 1)
+
+      val smjBeforeExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
+        case smj: SortMergeJoinExec => smj
+      }
+      assert(smjBeforeExecutionForLeftAnti.length === 1)
+
+      // Check the answers.
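+      // Every df1 row has key 0 and df2 contains key 0, so the left semi join returns all of
+      // df1 and the left anti join returns no rows.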
+      val expectedAnswerForLeftSemi =
+        spark
+          .range(0, 1000)
+          .selectExpr("id % 1 as key", "id as value")
+      checkAnswer(
+        leftSemiJoin,
+        expectedAnswerForLeftSemi.collect())
+
+      val expectedAnswerForLeftAnti = Seq.empty
+      checkAnswer(
+        leftAntiJoin,
+        expectedAnswerForLeftAnti)
+
+      // For the left semi join case: during execution, the SMJ is changed to a Union of one
+      // regular SMJ plus MAX_SPLIT sub-joins, because the skewed side is on the left and the
+      // join type is left semi (the two correspond).
+      val smjAfterExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
+        case smj: SortMergeJoinExec => smj
+      }
+      assert(smjAfterExecutionForLeftSemi.length === MAX_SPLIT + 1)
+
+      // For the left anti join case: during execution, the SMJ is changed to a Union of one
+      // regular SMJ plus MAX_SPLIT sub-joins, because the skewed side is on the left and the
+      // join type is left anti (the two correspond).
+      val smjAfterExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
+        case smj: SortMergeJoinExec => smj
+      }
+      assert(smjAfterExecutionForLeftAnti.length === MAX_SPLIT + 1)
+
+      val queryStageInputs = leftSemiJoin.queryExecution.executedPlan.collect {
+        case q: ShuffleQueryStageInput => q
+      }
+      assert(queryStageInputs.length === 2)
+      assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions)
+      assert(queryStageInputs(0).skewedPartitions === Some(Set(0)))
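+      // Partition 0 holds the hot key (id % 1 == 0). Both inputs of the remaining regular SMJ
+      // mark it as skewed, so it is excluded there and handled by the sub-joins instead.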
+
+      val skewedQueryStageInputs = leftSemiJoin.queryExecution.executedPlan.collect {
+        case q: SkewedShuffleQueryStageInput => q
+      }
+      assert(skewedQueryStageInputs.length === MAX_SPLIT * 2)
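+      // Each of the MAX_SPLIT sub-joins reads one skewed shuffle input per child, giving
+      // MAX_SPLIT * 2 SkewedShuffleQueryStageInputs in total.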
+    }
+  }
+
   test("row count statistics, compressed") {
     val spark = defaultSparkSession
     withSparkSession(spark) { spark: SparkSession =>