Skip to content

Commit ad13ae3

Browse files
Guo Chenzhao authored and luzhonghao committed
Support Left Anti Join in data skew feature (apache#62)
1 parent 933d6ca commit ad13ae3

File tree

2 files changed

+151
-3
lines changed

2 files changed

+151
-3
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleSkewedJoin.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ import org.apache.spark.sql.internal.SQLConf
2929

3030
case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
3131

32-
private val supportedJoinTypes = Inner :: Cross :: LeftSemi :: LeftOuter:: RightOuter :: Nil
32+
private val supportedJoinTypes =
33+
Inner :: Cross :: LeftSemi :: LeftAnti :: LeftOuter :: RightOuter :: Nil
3334

3435
private def isSizeSkewed(size: Long, medianSize: Long): Boolean = {
3536
size > medianSize * conf.adaptiveSkewedFactor &&
@@ -116,7 +117,7 @@ case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
116117
private def supportSplitOnLeftPartition(joinType: JoinType) = joinType != RightOuter
117118

118119
private def supportSplitOnRightPartition(joinType: JoinType) = {
119-
joinType != LeftOuter && joinType != LeftSemi
120+
joinType != LeftOuter && joinType != LeftSemi && joinType != LeftAnti
120121
}
121122

122123
private def handleSkewedJoin(

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
544544
val expectedAnswerForRightOuter =
545545
spark
546546
.range(0, 100)
547-
.flatMap(i => Seq.fill(100)(i))
547+
.flatMap(i => Seq.fill(100)(i))
548548
.selectExpr("0 as key", "value")
549549
checkAnswer(
550550
rightOuterJoin,
@@ -578,6 +578,153 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
578578
}
579579
}
580580

581+
test("adaptive skewed join: left semi/anti join and skewed on right side") {
  val spark = defaultSparkSession
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
  withSparkSession(spark) { spark: SparkSession =>
    // Left side: small (10 rows, keys 0..4). Right side: skewed — every row has key 0.
    val df1 =
      spark
        .range(0, 10, 1, 2)
        .selectExpr("id % 5 as key1", "id as value1")
    val df2 =
      spark
        .range(0, 1000, 1, numInputPartitions)
        .selectExpr("id % 1 as key2", "id as value2")

    val leftSemiJoin =
      df1.join(df2, col("key1") === col("key2"), "left_semi").select(col("key1"), col("value1"))
    val leftAntiJoin =
      df1.join(df2, col("key1") === col("key2"), "left_anti").select(col("key1"), col("value1"))

    // Before execution, each query plan contains exactly one SortMergeJoin.
    val smjBeforeExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftSemi.length === 1)

    // Fixed: previously collected from leftSemiJoin, so the left-anti plan was never checked.
    val smjBeforeExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftAnti.length === 1)

    // Check the answers: semi keeps rows with a matching key (only key 0 exists on the
    // right), anti keeps the complement.
    val expectedAnswerForLeftSemi =
      spark
        .range(0, 10)
        .filter(_ % 5 == 0)
        .selectExpr("id % 5 as key", "id as value")
    checkAnswer(
      leftSemiJoin,
      expectedAnswerForLeftSemi.collect())

    val expectedAnswerForLeftAnti =
      spark
        .range(0, 10)
        .filter(_ % 5 != 0)
        .selectExpr("id % 5 as key", "id as value")
    checkAnswer(
      leftAntiJoin,
      expectedAnswerForLeftAnti.collect())

    // For the left semi join case: during execution the SMJ cannot be translated into any
    // sub-joins, because the skewed side is the right one and a left semi join does not
    // support splitting the right partition.
    val smjAfterExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftSemi.length === 1)

    // For the left anti join case: likewise, the skewed (right) side cannot be split for a
    // left anti join, so the single SMJ remains.
    val smjAfterExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftAnti.length === 1)
  }
}
649+
650+
test("adaptive skewed join: left semi/anti join and skewed on left side") {
  val spark = defaultSparkSession
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
  val MAX_SPLIT = 5
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_MAX_SPLITS.key, MAX_SPLIT)
  withSparkSession(spark) { spark: SparkSession =>
    // Left side: skewed — every row has key 0. Right side: small (10 rows, keys 0..4).
    val df1 =
      spark
        .range(0, 1000, 1, numInputPartitions)
        .selectExpr("id % 1 as key1", "id as value1")
    val df2 =
      spark
        .range(0, 10, 1, 2)
        .selectExpr("id % 5 as key2", "id as value2")

    val leftSemiJoin =
      df1.join(df2, col("key1") === col("key2"), "left_semi").select(col("key1"), col("value1"))
    val leftAntiJoin =
      df1.join(df2, col("key1") === col("key2"), "left_anti").select(col("key1"), col("value1"))

    // Before execution, each query plan contains exactly one SortMergeJoin.
    val smjBeforeExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftSemi.length === 1)

    // Fixed: previously collected from leftSemiJoin, so the left-anti plan was never checked.
    val smjBeforeExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftAnti.length === 1)

    // Check the answers: every left row matches key 0 on the right, so the semi join keeps
    // all 1000 rows and the anti join keeps none.
    val expectedAnswerForLeftSemi =
      spark
        .range(0, 1000)
        .selectExpr("id % 1 as key", "id as value")
    checkAnswer(
      leftSemiJoin,
      expectedAnswerForLeftSemi.collect())

    val expectedAnswerForLeftAnti = Seq.empty
    checkAnswer(
      leftAntiJoin,
      expectedAnswerForLeftAnti)

    // For the left semi join case: during execution the SMJ is changed to a Union of the
    // original SMJ plus MAX_SPLIT sub-joins, because the skewed side is the left one and a
    // left semi join supports splitting the left partition.
    val smjAfterExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftSemi.length === MAX_SPLIT + 1)

    // For the left anti join case: likewise, the skewed (left) side can be split for a left
    // anti join, producing MAX_SPLIT additional sub-joins.
    val smjAfterExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftAnti.length === MAX_SPLIT + 1)

    // Both shuffle inputs should agree that partition 0 is the skewed one.
    val queryStageInputs = leftSemiJoin.queryExecution.executedPlan.collect {
      case q: ShuffleQueryStageInput => q
    }
    assert(queryStageInputs.length === 2)
    assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions)
    assert(queryStageInputs(0).skewedPartitions === Some(Set(0)))

    // Each of the MAX_SPLIT sub-joins reads one skewed input per side.
    val skewedQueryStageInputs = leftSemiJoin.queryExecution.executedPlan.collect {
      case q: SkewedShuffleQueryStageInput => q
    }
    assert(skewedQueryStageInputs.length === MAX_SPLIT * 2)
  }
}
727+
581728
test("row count statistics, compressed") {
582729
val spark = defaultSparkSession
583730
withSparkSession(spark) { spark: SparkSession =>

0 commit comments

Comments
 (0)