Skip to content

Commit 89406d3

Browse files
Yucai YuJkSelf
authored andcommitted
Fix exception: Child of ShuffleQueryStage must be a ShuffleExchange (apache#52)
* Fix exception: Child of ShuffleQueryStage must be a ShuffleExchange * top ShuffleExchange of QueryStage should not be removed anyway * remove unecessary parentheses * check top shuffle exchange for ShuffleQueryStage only * minor comments * improve topShuffleCheck * simplfy codes
1 parent 36a935e commit 89406d3

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeJoin.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,12 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
140140
removeSort(right))
141141

142142
val newChild = queryStage.child.transformDown {
143-
case s: SortMergeJoinExec if (s.fastEquals(smj)) => broadcastJoin
143+
case s: SortMergeJoinExec if s.fastEquals(smj) => broadcastJoin
144144
}
145145

146146
val broadcastSidePlan = buildSide match {
147-
case BuildLeft => (removeSort(left))
148-
case BuildRight => (removeSort(right))
147+
case BuildLeft => removeSort(left)
148+
case BuildRight => removeSort(right)
149149
}
150150
// Local shuffle read less partitions based on broadcastSide's row statistics
151151
joinType match {
@@ -163,8 +163,14 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
163163
case e: ShuffleExchangeExec => e
164164
}.length
165165

166-
if (conf.adaptiveAllowAdditionShuffle || numExchanges == 0 ||
167-
(queryStage.isInstanceOf[ShuffleQueryStage] && numExchanges <= 1)) {
166+
val topShuffleCheck = queryStage match {
167+
case _: ShuffleQueryStage => afterEnsureRequirements.isInstanceOf[ShuffleExchangeExec]
168+
case _ => true
169+
}
170+
val allowAdditionalShuffle = conf.adaptiveAllowAdditionShuffle
171+
val noAdditionalShuffle = (numExchanges == 0) ||
172+
(queryStage.isInstanceOf[ShuffleQueryStage] && numExchanges <= 1)
173+
if (topShuffleCheck && (allowAdditionalShuffle || noAdditionalShuffle)) {
168174
// Update the plan in queryStage
169175
queryStage.child = newChild
170176
broadcastJoin

0 commit comments

Comments
 (0)