
Commit 16a3a3e

mengxr authored and conviva-zz committed
[SPARK-2568] RangePartitioner should run only one job if data is balanced
As of Spark 1.0, RangePartitioner goes through the data twice: once to compute the count and once to do the sampling. As a result, to do sortByKey, Spark goes through the data three times (once to count, once to sample, and once to sort). `RangePartitioner` should go through the data only once, collecting samples from the input partitions as well as counting. If the data is balanced, this gives us a good sketch. If we see big partitions, we re-sample from them in order to collect enough items.

The downside is that we need to collect more from each partition in the first pass. An alternative solution is to cache the intermediate result and decide afterwards whether to fetch the data.

Author: Xiangrui Meng <[email protected]>
Author: Reynold Xin <[email protected]>

Closes apache#1562 from mengxr/range-partitioner and squashes the following commits:

6cc2551 [Xiangrui Meng] change foreach to for
eb39b08 [Xiangrui Meng] Merge branch 'master' into range-partitioner
eb95dd8 [Xiangrui Meng] separate sketching and determining bounds impl
c436d30 [Xiangrui Meng] fix binary metrics unit tests
db58a55 [Xiangrui Meng] add unit tests
a6e35d6 [Xiangrui Meng] minor update
60be09e [Xiangrui Meng] remove importance sampler
9ee9992 [Xiangrui Meng] update range partitioner to run only one job on roughly balanced data
cc12f47 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into range-part
06ac2ec [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into range-part
17bcbf3 [Reynold Xin] Added seed.
badf20d [Reynold Xin] Renamed the method.
6940010 [Reynold Xin] Reservoir sampling implementation.
1 parent f07c1e3 commit 16a3a3e
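
The heart of the change is folding counting and sampling into a single pass: each partition reports its exact size together with a fixed-size reservoir sample. The following is a minimal, self-contained sketch of that idea in plain Scala, with local collections standing in for RDD partitions; this toy reservoirSampleAndCount only illustrates the technique and is not Spark's SamplingUtils implementation:

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.util.Random

// Toy reservoir sampling (Algorithm R): one pass over an iterator yields both
// a uniform fixed-size sample and the exact element count.
def reservoirSampleAndCount[T: ClassTag](
    iter: Iterator[T], k: Int, seed: Long): (Array[T], Long) = {
  val rng = new Random(seed)
  val reservoir = ArrayBuffer.empty[T]
  var n = 0L
  while (iter.hasNext) {
    val item = iter.next()
    n += 1
    if (reservoir.size < k) {
      reservoir += item
    } else {
      // Keep the n-th item with probability k/n by overwriting a random slot.
      val j = rng.nextInt(n.toInt) // sketch only; ignores counts above Int.MaxValue
      if (j < k) reservoir(j) = item
    }
  }
  (reservoir.toArray, n)
}

// One "job" over all partitions: each reports (partitionId, count, sample),
// so the total count comes for free with the samples.
val parts = Seq.tabulate(4)(i => Seq.fill(100 * (i + 1))(Random.nextDouble()))
val sketched = parts.zipWithIndex.map { case (part, idx) =>
  val (sample, n) = reservoirSampleAndCount(part.iterator, k = 10, seed = idx.toLong)
  (idx, n, sample)
}
val numItems = sketched.map(_._2).sum // 1000, with no separate counting pass

With exact per-partition counts in hand, the partitioner can spot imbalanced partitions (those whose count times the sampling fraction exceeds the per-partition cap) and launch a second, targeted pass only over them.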

3 files changed: +171 −19

core/src/main/scala/org/apache/spark/Partitioner.scala

Lines changed: 106 additions & 15 deletions
@@ -19,11 +19,15 @@ package org.apache.spark
 
 import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
 
-import scala.reflect.ClassTag
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.{ClassTag, classTag}
+import scala.util.hashing.byteswap32
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.util.{CollectionsUtils, Utils}
+import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}
 
 /**
  * An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -103,26 +107,49 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     private var ascending: Boolean = true)
   extends Partitioner {
 
+  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
+  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
+
   private var ordering = implicitly[Ordering[K]]
 
   // An array of upper bounds for the first (partitions - 1) partitions
   private var rangeBounds: Array[K] = {
-    if (partitions == 1) {
-      Array()
+    if (partitions <= 1) {
+      Array.empty
     } else {
-      val rddSize = rdd.count()
-      val maxSampleSize = partitions * 20.0
-      val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
-      val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted
-      if (rddSample.length == 0) {
-        Array()
+      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
+      val sampleSize = math.min(20.0 * partitions, 1e6)
+      // Assume the input partitions are roughly balanced and over-sample a little bit.
+      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
+      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
+      if (numItems == 0L) {
+        Array.empty
       } else {
-        val bounds = new Array[K](partitions - 1)
-        for (i <- 0 until partitions - 1) {
-          val index = (rddSample.length - 1) * (i + 1) / partitions
-          bounds(i) = rddSample(index)
+        // If a partition contains much more than the average number of items, we re-sample from it
+        // to ensure that enough items are collected from that partition.
+        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
+        val candidates = ArrayBuffer.empty[(K, Float)]
+        val imbalancedPartitions = mutable.Set.empty[Int]
+        sketched.foreach { case (idx, n, sample) =>
+          if (fraction * n > sampleSizePerPartition) {
+            imbalancedPartitions += idx
+          } else {
+            // The weight is 1 over the sampling probability.
+            val weight = (n.toDouble / sample.size).toFloat
+            for (key <- sample) {
+              candidates += ((key, weight))
+            }
+          }
         }
-        bounds
+        if (imbalancedPartitions.nonEmpty) {
+          // Re-sample imbalanced partitions with the desired sampling probability.
+          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
+          val seed = byteswap32(-rdd.id - 1)
+          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
+          val weight = (1.0 / fraction).toFloat
+          candidates ++= reSampled.map(x => (x, weight))
+        }
+        RangePartitioner.determineBounds(candidates, partitions)
       }
     }
   }
@@ -212,3 +239,67 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     }
   }
 }
+
+private[spark] object RangePartitioner {
+
+  /**
+   * Sketches the input RDD via reservoir sampling on each partition.
+   *
+   * @param rdd the input RDD to sketch
+   * @param sampleSizePerPartition max sample size per partition
+   * @return (total number of items, an array of (partitionId, number of items, sample))
+   */
+  def sketch[K : ClassTag](
+      rdd: RDD[K],
+      sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
+    val shift = rdd.id
+    // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
+    val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
+      val seed = byteswap32(idx ^ (shift << 16))
+      val (sample, n) = SamplingUtils.reservoirSampleAndCount(
+        iter, sampleSizePerPartition, seed)
+      Iterator((idx, n, sample))
+    }.collect()
+    val numItems = sketched.map(_._2.toLong).sum
+    (numItems, sketched)
+  }
+
+  /**
+   * Determines the bounds for range partitioning from candidates with weights indicating how many
+   * items each represents. Usually this is 1 over the probability used to sample this candidate.
+   *
+   * @param candidates unordered candidates with weights
+   * @param partitions number of partitions
+   * @return selected bounds
+   */
+  def determineBounds[K : Ordering : ClassTag](
+      candidates: ArrayBuffer[(K, Float)],
+      partitions: Int): Array[K] = {
+    val ordering = implicitly[Ordering[K]]
+    val ordered = candidates.sortBy(_._1)
+    val numCandidates = ordered.size
+    val sumWeights = ordered.map(_._2.toDouble).sum
+    val step = sumWeights / partitions
+    var cumWeight = 0.0
+    var target = step
+    val bounds = ArrayBuffer.empty[K]
+    var i = 0
+    var j = 0
+    var previousBound = Option.empty[K]
+    while ((i < numCandidates) && (j < partitions - 1)) {
+      val (key, weight) = ordered(i)
+      cumWeight += weight
+      if (cumWeight > target) {
+        // Skip duplicate values.
+        if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
+          bounds += key
+          target += step
+          j += 1
+          previousBound = Some(key)
+        }
+      }
+      i += 1
+    }
+    bounds.toArray
+  }
+}
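
To make the weighted-candidate logic concrete, here is a standalone trace of the cumulative-weight walk in determineBounds, inlined because the real method is private[spark], and simplified by dropping the duplicate-key check; the values match the new PartitioningSuite test below:

import scala.collection.mutable.ArrayBuffer

// Candidates as (key, weight); weight 2.0f means the key stands in for about
// twice as many original items as a weight-1.0f key.
val candidates = ArrayBuffer(
  (0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f))
val partitions = 3
val ordered = candidates.sortBy(_._1) // 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 1.0
val step = ordered.map(_._2.toDouble).sum / partitions // 10.0 / 3 = 3.33...
var cumWeight = 0.0
var target = step
val bounds = ArrayBuffer.empty[Double]
for ((key, weight) <- ordered if bounds.size < partitions - 1) {
  cumWeight += weight
  if (cumWeight > target) {
    bounds += key
    target += step
  }
}
// The cumulative weight first exceeds 3.33 at key 0.4 and first exceeds 6.67
// at key 0.7, so the two bounds split the weight into three roughly equal parts.
assert(bounds == ArrayBuffer(0.4, 0.7))

Each bound is the first key whose cumulative weight crosses a multiple of sumWeights / partitions, so heavily weighted keys (those sampled from large partitions at a low rate) pull the bounds toward themselves.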

core/src/test/scala/org/apache/spark/PartitioningSuite.scala

Lines changed: 60 additions & 4 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark
 
+import scala.collection.mutable.ArrayBuffer
 import scala.math.abs
 
 import org.scalatest.{FunSuite, PrivateMethodTester}
@@ -52,14 +53,12 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester {
 
     assert(p2 === p2)
     assert(p4 === p4)
-    assert(p2 != p4)
-    assert(p4 != p2)
+    assert(p2 === p4)
     assert(p4 === anotherP4)
     assert(anotherP4 === p4)
     assert(descendingP2 === descendingP2)
     assert(descendingP4 === descendingP4)
-    assert(descendingP2 != descendingP4)
-    assert(descendingP4 != descendingP2)
+    assert(descendingP2 === descendingP4)
     assert(p2 != descendingP2)
     assert(p4 != descendingP4)
     assert(descendingP2 != p2)
@@ -102,6 +101,63 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester {
     partitioner.getPartition(Row(100))
   }
 
+  test("RangPartitioner.sketch") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(i)(random.nextDouble())
+    }.cache()
+    val sampleSizePerPartition = 10
+    val (count, sketched) = RangePartitioner.sketch(rdd, sampleSizePerPartition)
+    assert(count === rdd.count())
+    sketched.foreach { case (idx, n, sample) =>
+      assert(n === idx)
+      assert(sample.size === math.min(n, sampleSizePerPartition))
+    }
+  }
+
+  test("RangePartitioner.determineBounds") {
+    assert(RangePartitioner.determineBounds(ArrayBuffer.empty[(Int, Float)], 10).isEmpty,
+      "Bounds on an empty candidates set should be empty.")
+    val candidates = ArrayBuffer(
+      (0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f))
+    assert(RangePartitioner.determineBounds(candidates, 3) === Array(0.4, 0.7))
+  }
+
+  test("RangePartitioner should run only one job if data is roughly balanced") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(5000 * i)((random.nextDouble() + i, i))
+    }.cache()
+    for (numPartitions <- Seq(10, 20, 40)) {
+      val partitioner = new RangePartitioner(numPartitions, rdd)
+      assert(partitioner.numPartitions === numPartitions)
+      val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
+      assert(counts.max < 3.0 * counts.min)
+    }
+  }
+
+  test("RangePartitioner should work well on unbalanced data") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(20 * i * i * i)((random.nextDouble() + i, i))
+    }.cache()
+    for (numPartitions <- Seq(2, 4, 8)) {
+      val partitioner = new RangePartitioner(numPartitions, rdd)
+      assert(partitioner.numPartitions === numPartitions)
+      val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
+      assert(counts.max < 3.0 * counts.min)
+    }
+  }
+
+  test("RangePartitioner should return a single partition for empty RDDs") {
+    val empty1 = sc.emptyRDD[(Int, Double)]
+    val partitioner1 = new RangePartitioner(0, empty1)
+    assert(partitioner1.numPartitions === 1)
+    val empty2 = sc.makeRDD(0 until 2, 2).flatMap(i => Seq.empty[(Int, Double)])
+    val partitioner2 = new RangePartitioner(2, empty2)
+    assert(partitioner2.numPartitions === 1)
+  }
+
   test("HashPartitioner not equal to RangePartitioner") {
     val rdd = sc.parallelize(1 to 10).map(x => (x, x))
     val rangeP2 = new RangePartitioner(2, rdd)
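
Note that the "only one job" property is not asserted directly above; the tests check that the resulting partitions stay balanced. One rough way to observe the job count, assuming an interactive shell with a live SparkContext named sc, is a SparkListener; this is an illustrative sketch, not part of the commit:

import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.RangePartitioner
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

// Count the jobs launched while constructing a RangePartitioner. On roughly
// balanced data this should be 1 (the sketch job); heavily skewed data may
// add a second job to re-sample the imbalanced partitions.
val jobs = new AtomicInteger(0)
sc.addSparkListener(new SparkListener {
  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
    jobs.incrementAndGet()
  }
})
val pairs = sc.parallelize(0 until 100000, 10).map(i => (i, i))
val before = jobs.get()
new RangePartitioner(10, pairs) // triggers the sketch
Thread.sleep(1000) // crude: listener delivery is asynchronous
println(s"jobs launched: ${jobs.get() - before}") // expect 1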

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 5 additions & 0 deletions
@@ -613,6 +613,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("sort an empty RDD") {
+    val data = sc.emptyRDD[Int]
+    assert(data.sortBy(x => x).collect() === Array.empty)
+  }
+
   test("sortByKey") {
     val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))
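
End to end, the change is invisible at the API level; only the number of jobs behind a sort drops. A minimal usage sketch, again assuming a live SparkContext sc:

// sortByKey builds a RangePartitioner internally; with this change the
// range-bound estimation costs one job instead of two, so a full
// sortByKey-plus-collect runs two jobs rather than three.
val pairs = sc.parallelize(Seq((5, "A"), (3, "B"), (9, "C"), (1, "D")), 2)
val sorted = pairs.sortByKey(ascending = true)
assert(sorted.collect().map(_._1).sameElements(Array(1, 3, 5, 9)))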
618623
