Commit 6a6bfce

Fix issue related to RangePartitioning:
- We now defensively copy before computing the partition bounds, which is necessary in order to get accurate sampling.
- We now pass the actual partitioner into needToCopyObjectsBeforeShuffle(), which guards against the fact that RangePartitioner may produce a shuffle with fewer than `numPartitions` partitions.
1 parent ad006a4 commit 6a6bfce
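
A note on the first fix: RangePartitioner determines its bounds by sampling keys, and Spark SQL operators may reuse a single mutable row per partition, so buffering references during that sampling pass would record the last-written object many times. A minimal sketch of the hazard in plain Scala (no Spark dependency; `MutableKey` is a hypothetical stand-in for a reused mutable row):

object MutableRowSamplingHazard {
  // Hypothetical stand-in for a mutable row that an operator reuses.
  final class MutableKey(var value: Int)

  def main(args: Array[String]): Unit = {
    val reused = new MutableKey(0)
    // Simulates an iterator that updates and re-emits one shared object.
    val rows: Iterator[MutableKey] = (1 to 5).iterator.map { i =>
      reused.value = i
      reused
    }
    // "Sampling" without copying buffers five aliases of the same object:
    println(rows.toArray.map(_.value).toSeq)   // List(5, 5, 5, 5, 5)

    val reused2 = new MutableKey(0)
    val rows2 = (1 to 5).iterator.map { i => reused2.value = i; reused2 }
    // Defensive copies preserve the distinct key values:
    println(rows2.map(k => new MutableKey(k.value)).toArray.map(_.value).toSeq)
    // List(1, 2, 3, 4, 5)
  }
}

The diff below applies exactly this kind of copy when building `rddForSampling`.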

File tree

1 file changed: 37 additions, 28 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner}
+import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.{SQLContext, Row}
@@ -81,21 +81,25 @@ case class Exchange(
    *
    * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue.
    *
-   * @param numPartitions the number of output partitions produced by the shuffle
+   * @param partitioner the partitioner for the shuffle
    * @param serializer the serializer that will be used to write rows
    * @return true if rows should be copied before being shuffled, false otherwise
    */
   private def needToCopyObjectsBeforeShuffle(
-      numPartitions: Int,
+      partitioner: Partitioner,
       serializer: Serializer): Boolean = {
+    // Note: even though we only use the partitioner's `numPartitions` field, we require it to be
+    // passed instead of directly passing the number of partitions in order to guard against
+    // corner-cases where a partitioner constructed with `numPartitions` partitions may output
+    // fewer partitions (like RangePartitioner, for example).
     if (newOrdering.nonEmpty) {
       // If a new ordering is required, then records will be sorted with Spark's `ExternalSorter`,
       // which requires a defensive copy.
       true
     } else if (sortBasedShuffleOn) {
       // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
       // However, there are two special cases where we can avoid the copy, described below:
-      if (numPartitions <= bypassMergeThreshold) {
+      if (partitioner.numPartitions <= bypassMergeThreshold) {
         // If the number of output partitions is sufficiently small, then Spark will fall back to
         // the old hash-based shuffle write path which doesn't buffer deserialized records.
         // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
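
An aside on the corner case the new comment names: `RangePartitioner` derives its `rangeBounds` from sampled keys, and its `numPartitions` is the number of bounds plus one, so data with few distinct keys yields fewer partitions than requested. A hedged sketch against the Spark 1.x core API (assumes a local Spark runtime; the object name and sample values are illustrative):

import org.apache.spark.{RangePartitioner, SparkConf, SparkContext}

object FewerPartitionsThanRequested {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local").setAppName("range-partitioner-sketch"))
    // Only two distinct keys, but ask for ten partitions.
    val pairs = sc.parallelize(Seq((1, "a"), (2, "b"), (1, "c"), (2, "d")))
    val part = new RangePartitioner(10, pairs)
    // With at most one range bound, the actual partition count is at most 2,
    // so a copy decision keyed to the *requested* count of 10 could be wrong.
    println(s"requested = 10, actual = ${part.numPartitions}")
    sc.stop()
  }
}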
@@ -177,8 +181,9 @@ case class Exchange(
         val keySchema = expressions.map(_.dataType).toArray
         val valueSchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(keySchema, valueSchema, numPartitions)
+        val part = new HashPartitioner(numPartitions)

-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
+        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
           child.execute().mapPartitions { iter =>
             val hashExpressions = newMutableProjection(expressions, child.output)()
             iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -190,55 +195,59 @@ case class Exchange(
             iter.map(r => mutablePair.update(hashExpressions(r), r))
           }
         }
-        val part = new HashPartitioner(numPartitions)
-        val shuffled =
-          if (newOrdering.nonEmpty) {
-            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
-          } else {
-            new ShuffledRDD[Row, Row, Row](rdd, part)
-          }
+        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+        if (newOrdering.nonEmpty) {
+          shuffled.setKeyOrdering(keyOrdering)
+        }
         shuffled.setSerializer(serializer)
         shuffled.map(_._2)

       case RangePartitioning(sortingExpressions, numPartitions) =>
         val keySchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(keySchema, null, numPartitions)

-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
-          child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
+        val childRdd = child.execute()
+        val part: Partitioner = {
+          // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+          // partition bounds. To get accurate samples, we need to copy the mutable keys.
+          val rddForSampling = childRdd.mapPartitions { iter =>
+            val mutablePair = new MutablePair[Row, Null]()
+            iter.map(row => mutablePair.update(row.copy(), null))
+          }
+          // TODO: RangePartitioner should take an Ordering.
+          implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+          new RangePartitioner(numPartitions, rddForSampling, ascending = true)
+        }
+
+        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+          childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
         } else {
-          child.execute().mapPartitions { iter =>
-            val mutablePair = new MutablePair[Row, Null](null, null)
+          childRdd.mapPartitions { iter =>
+            val mutablePair = new MutablePair[Row, Null]()
             iter.map(row => mutablePair.update(row, null))
           }
         }

-        // TODO: RangePartitioner should take an Ordering.
-        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
-
-        val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled =
-          if (newOrdering.nonEmpty) {
-            new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
-          } else {
-            new ShuffledRDD[Row, Null, Null](rdd, part)
-          }
+        val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+        if (newOrdering.nonEmpty) {
+          shuffled.setKeyOrdering(keyOrdering)
+        }
         shuffled.setSerializer(serializer)
         shuffled.map(_._1)

       case SinglePartition =>
         val valueSchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(null, valueSchema, 1)
+        val partitioner = new HashPartitioner(1)

-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions = 1, serializer)) {
+        val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
           child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) }
         } else {
           child.execute().mapPartitions { iter =>
             val mutablePair = new MutablePair[Null, Row]()
             iter.map(r => mutablePair.update(null, r))
           }
         }
-        val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
         shuffled.setSerializer(serializer)
         shuffled.map(_._2)
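
A closing note on the restructuring that appears twice in this diff: the old code duplicated the `ShuffledRDD` constructor in both branches of an if/else just to chain `setKeyOrdering`. Because `setKeyOrdering` mutates the RDD's shuffle configuration in place (returning `this` only for chaining), the RDD can be built once and the ordering attached conditionally. A self-contained sketch of the same pattern (plain Scala; `Shuffle` is a toy stand-in, not the Spark class):

// The setter mutates and returns this, so construction happens once and
// optional configuration attaches conditionally afterwards.
final class Shuffle(val numPartitions: Int) {
  private var keyOrdering: Option[Ordering[Int]] = None
  def setKeyOrdering(ord: Ordering[Int]): Shuffle = { keyOrdering = Some(ord); this }
  override def toString = s"Shuffle($numPartitions, ordered=${keyOrdering.isDefined})"
}

object BuilderPatternSketch {
  def main(args: Array[String]): Unit = {
    val needsOrdering = args.nonEmpty   // stand-in for newOrdering.nonEmpty
    val shuffle = new Shuffle(8)        // constructed exactly once
    if (needsOrdering) {
      shuffle.setKeyOrdering(Ordering.Int)  // mutates in place; return value ignored
    }
    println(shuffle)
  }
}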
