@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner}
+import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.{SQLContext, Row}
@@ -81,21 +81,25 @@ case class Exchange(
    *
    * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue.
    *
-   * @param numPartitions the number of output partitions produced by the shuffle
+   * @param partitioner the partitioner for the shuffle
    * @param serializer the serializer that will be used to write rows
    * @return true if rows should be copied before being shuffled, false otherwise
    */
   private def needToCopyObjectsBeforeShuffle(
-      numPartitions: Int,
+      partitioner: Partitioner,
       serializer: Serializer): Boolean = {
+    // Note: even though we only use the partitioner's `numPartitions` field, we require it to be
+    // passed instead of directly passing the number of partitions in order to guard against
+    // corner-cases where a partitioner constructed with `numPartitions` partitions may output
+    // fewer partitions (like RangePartitioner, for example).
     if (newOrdering.nonEmpty) {
       // If a new ordering is required, then records will be sorted with Spark's `ExternalSorter`,
       // which requires a defensive copy.
       true
     } else if (sortBasedShuffleOn) {
       // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
       // However, there are two special cases where we can avoid the copy, described below:
-      if (numPartitions <= bypassMergeThreshold) {
+      if (partitioner.numPartitions <= bypassMergeThreshold) {
         // If the number of output partitions is sufficiently small, then Spark will fall back to
         // the old hash-based shuffle write path which doesn't buffer deserialized records.
         // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
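The guard described in the new comment is the heart of this change. Below is a minimal sketch of the corner case, assuming only a local SparkContext and the plain Spark core API (the `local[2]` master and app name are illustrative, nothing here comes from this patch): a RangePartitioner asked for N partitions can report fewer when the sampled keys cannot support N distinct range bounds.

```scala
import org.apache.spark.{RangePartitioner, SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("range-demo"))

// Only three distinct keys, but we ask for ten partitions.
val pairs = sc.parallelize(Seq(1, 2, 3).map(k => (k, k)))
val part = new RangePartitioner(10, pairs)

// numPartitions is derived from the computed range bounds, so it can come
// out well below the requested ten here. Passing the partitioner (rather
// than the raw Int) lets needToCopyObjectsBeforeShuffle see the real count.
println(part.numPartitions)
```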
@@ -177,8 +181,9 @@ case class Exchange(
         val keySchema = expressions.map(_.dataType).toArray
         val valueSchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(keySchema, valueSchema, numPartitions)
+        val part = new HashPartitioner(numPartitions)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
+        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
           child.execute().mapPartitions { iter =>
             val hashExpressions = newMutableProjection(expressions, child.output)()
             iter.map(r => (hashExpressions(r).copy(), r.copy()))
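For the HashPartitioning and SinglePartition branches the refactoring is behavior-preserving, since a HashPartitioner always reports exactly the count it was constructed with. A quick check against the plain Spark API (not part of the patch):

```scala
import org.apache.spark.HashPartitioner

val part = new HashPartitioner(200)
assert(part.numPartitions == 200)           // exact, unlike RangePartitioner
assert(part.getPartition("some key") < 200) // keys hash into [0, 200)
```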
@@ -190,55 +195,59 @@ case class Exchange(
             iter.map(r => mutablePair.update(hashExpressions(r), r))
           }
         }
-        val part = new HashPartitioner(numPartitions)
-        val shuffled =
-          if (newOrdering.nonEmpty) {
-            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
-          } else {
-            new ShuffledRDD[Row, Row, Row](rdd, part)
-          }
+        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+        if (newOrdering.nonEmpty) {
+          shuffled.setKeyOrdering(keyOrdering)
+        }
         shuffled.setSerializer(serializer)
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
         val keySchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(keySchema, null, numPartitions)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
-          child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
+        val childRdd = child.execute()
+        val part: Partitioner = {
+          // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+          // partition bounds. To get accurate samples, we need to copy the mutable keys.
+          val rddForSampling = childRdd.mapPartitions { iter =>
+            val mutablePair = new MutablePair[Row, Null]()
+            iter.map(row => mutablePair.update(row.copy(), null))
+          }
+          // TODO: RangePartitioner should take an Ordering.
+          implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+          new RangePartitioner(numPartitions, rddForSampling, ascending = true)
+        }
+
+        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+          childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
         } else {
-          child.execute().mapPartitions { iter =>
-            val mutablePair = new MutablePair[Row, Null](null, null)
+          childRdd.mapPartitions { iter =>
+            val mutablePair = new MutablePair[Row, Null]()
             iter.map(row => mutablePair.update(row, null))
           }
         }
 
-        // TODO: RangePartitioner should take an Ordering.
-        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
-
-        val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled =
-          if (newOrdering.nonEmpty) {
-            new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
-          } else {
-            new ShuffledRDD[Row, Null, Null](rdd, part)
-          }
+        val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+        if (newOrdering.nonEmpty) {
+          shuffled.setKeyOrdering(keyOrdering)
+        }
         shuffled.setSerializer(serializer)
         shuffled.map(_._1)
 
       case SinglePartition =>
         val valueSchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(null, valueSchema, 1)
+        val partitioner = new HashPartitioner(1)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions = 1, serializer)) {
+        val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
           child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) }
         } else {
           child.execute().mapPartitions { iter =>
             val mutablePair = new MutablePair[Null, Row]()
             iter.map(r => mutablePair.update(null, r))
           }
         }
-        val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
         shuffled.setSerializer(serializer)
         shuffled.map(_._2)
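The `rddForSampling` introduced above copies each row before handing it to RangePartitioner because the sampler buffers references while the surrounding operators reuse mutable rows. A self-contained sketch of that hazard, using a hypothetical `MutableBox` as a stand-in for Spark's reused mutable `Row` (illustration only, not the patch's code):

```scala
// Stand-in for a mutable row that an operator reuses across an iterator.
final class MutableBox(var value: Int)

val box = new MutableBox(0)
val reused = (1 to 5).iterator.map { i => box.value = i; box }
// Buffering references, as a sampler does, observes only the final value:
println(reused.toArray.map(_.value).mkString(","))   // prints 5,5,5,5,5

// Copying per element (what row.copy() achieves) preserves each sample:
val copied = (1 to 5).iterator.map(i => new MutableBox(i))
println(copied.toArray.map(_.value).mkString(","))   // prints 1,2,3,4,5
```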