Skip to content

Commit 1b9001f

Browse files
committed
[SPARK-3409][SQL] Avoid pulling in Exchange operator itself in Exchange's closures.
This is a small optimization that moves the `sortBasedShuffleOn` check outside the closures, so the closures no longer need to capture the entire Exchange operator object. Author: Reynold Xin <[email protected]> Closes #2282 from rxin/SPARK-3409 and squashes the following commits: 1de3f88 [Reynold Xin] [SPARK-3409][SQL] Avoid pulling in Exchange operator itself in Exchange's closures.
1 parent 9422c4e commit 1b9001f

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 21 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -36,25 +36,23 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
3636

3737
override def outputPartitioning = newPartitioning
3838

39-
def output = child.output
39+
override def output = child.output
4040

4141
/** We must copy rows when sort based shuffle is on */
4242
protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
4343

44-
def execute() = attachTree(this , "execute") {
44+
override def execute() = attachTree(this , "execute") {
4545
newPartitioning match {
4646
case HashPartitioning(expressions, numPartitions) =>
4747
// TODO: Eliminate redundant expressions in grouping key and value.
48-
val rdd = child.execute().mapPartitions { iter =>
49-
if (sortBasedShuffleOn) {
50-
@transient val hashExpressions =
51-
newProjection(expressions, child.output)
52-
48+
val rdd = if (sortBasedShuffleOn) {
49+
child.execute().mapPartitions { iter =>
50+
val hashExpressions = newProjection(expressions, child.output)
5351
iter.map(r => (hashExpressions(r), r.copy()))
54-
} else {
55-
@transient val hashExpressions =
56-
newMutableProjection(expressions, child.output)()
57-
52+
}
53+
} else {
54+
child.execute().mapPartitions { iter =>
55+
val hashExpressions = newMutableProjection(expressions, child.output)()
5856
val mutablePair = new MutablePair[Row, Row]()
5957
iter.map(r => mutablePair.update(hashExpressions(r), r))
6058
}
@@ -65,28 +63,29 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
6563
shuffled.map(_._2)
6664

6765
case RangePartitioning(sortingExpressions, numPartitions) =>
68-
// TODO: RangePartitioner should take an Ordering.
69-
implicit val ordering = new RowOrdering(sortingExpressions, child.output)
70-
71-
val rdd = child.execute().mapPartitions { iter =>
72-
if (sortBasedShuffleOn) {
73-
iter.map(row => (row.copy(), null))
74-
} else {
66+
val rdd = if (sortBasedShuffleOn) {
67+
child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
68+
} else {
69+
child.execute().mapPartitions { iter =>
7570
val mutablePair = new MutablePair[Row, Null](null, null)
7671
iter.map(row => mutablePair.update(row, null))
7772
}
7873
}
74+
75+
// TODO: RangePartitioner should take an Ordering.
76+
implicit val ordering = new RowOrdering(sortingExpressions, child.output)
77+
7978
val part = new RangePartitioner(numPartitions, rdd, ascending = true)
8079
val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
8180
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
8281

8382
shuffled.map(_._1)
8483

8584
case SinglePartition =>
86-
val rdd = child.execute().mapPartitions { iter =>
87-
if (sortBasedShuffleOn) {
88-
iter.map(r => (null, r.copy()))
89-
} else {
85+
val rdd = if (sortBasedShuffleOn) {
86+
child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) }
87+
} else {
88+
child.execute().mapPartitions { iter =>
9089
val mutablePair = new MutablePair[Null, Row]()
9190
iter.map(r => mutablePair.update(null, r))
9291
}

0 commit comments

Comments
 (0)