
Commit db58a55

add unit tests

1 parent a6e35d6 commit db58a55

2 files changed: +33 −2 lines


core/src/main/scala/org/apache/spark/Partitioner.scala

Lines changed: 5 additions & 2 deletions
@@ -108,6 +108,8 @@ class RangePartitioner[K : Ordering : ClassTag, V](
 
   private var ordering = implicitly[Ordering[K]]
 
+  @transient private[spark] var singlePass = true // for unit tests
+
   // An array of upper bounds for the first (partitions - 1) partitions
   private var rangeBounds: Array[K] = {
     if (partitions == 1) {
@@ -116,7 +118,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
       // This is the sample size we need to have roughly balanced output partitions.
       val sampleSize = 20.0 * partitions
       // Assume the input partitions are roughly balanced and over-sample a little bit.
-      val sampleSizePerPartition = math.ceil(5.0 * sampleSize / rdd.partitions.size).toInt
+      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
       val shift = rdd.id
       val classTagK = classTag[K]
       val sketch = rdd.mapPartitionsWithIndex { (idx, iter) =>
@@ -149,9 +151,10 @@ class RangePartitioner[K : Ordering : ClassTag, V](
         }
       }
       if (imbalancedPartitions.nonEmpty) {
+        singlePass = false
         val sampleFunc: (TaskContext, Iterator[Product2[K, V]]) => Array[K] = { (context, iter) =>
           val random = new XORShiftRandom(byteswap32(context.partitionId - shift))
-          iter.map(_._1).filter(t => random.nextDouble() < fraction).toArray
+          iter.map(_._1).filter(t => random.nextDouble() < fraction).toArray(classTagK)
         }
         val weight = (1.0 / fraction).toFloat
         val resultHandler: (Int, Array[K]) => Unit = { (_, sample) =>
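
For intuition, the second hunk tunes the up-front sampling: the partitioner targets roughly 20 samples per output partition and, assuming roughly balanced input, over-samples each input partition, with the factor lowered here from 5x to 3x. A standalone sketch of that arithmetic follows (the helper name and example numbers are illustrative, not part of the commit):

// Sketch of the sampleSizePerPartition arithmetic above (illustrative only).
object SampleSizeSketch {
  def sampleSizePerPartition(outputPartitions: Int, inputPartitions: Int): Int = {
    val sampleSize = 20.0 * outputPartitions // target ~20 samples per output partition
    // Over-sample each input partition 3x (lowered from 5x in this commit),
    // on the assumption that input partitions are roughly balanced.
    math.ceil(3.0 * sampleSize / inputPartitions).toInt
  }

  def main(args: Array[String]): Unit = {
    // e.g. 40 output partitions over 20 input partitions -> 120 sampled keys per input partition
    println(sampleSizePerPartition(40, 20))
  }
}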

core/src/test/scala/org/apache/spark/PartitioningSuite.scala

Lines changed: 28 additions & 0 deletions
@@ -100,6 +100,34 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     partitioner.getPartition(Row(100))
   }
 
+  test("RangePartitioner should run only one job if data is roughly balanced") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(5000 * i)((random.nextDouble() + i, i))
+    }.cache()
+    for (numPartitions <- Seq(10, 20, 40)) {
+      val partitioner = new RangePartitioner(numPartitions, rdd)
+      assert(partitioner.numPartitions === numPartitions)
+      assert(partitioner.singlePass === true)
+      val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
+      assert(counts.max < 2.0 * counts.min)
+    }
+  }
+
+  test("RangePartitioner should work well on unbalanced data") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(20 * i * i * i)((random.nextDouble() + i, i))
+    }.cache()
+    for (numPartitions <- Seq(2, 4, 8)) {
+      val partitioner = new RangePartitioner(numPartitions, rdd)
+      assert(partitioner.numPartitions === numPartitions)
+      assert(partitioner.singlePass === false)
+      val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
+      assert(counts.max < 2.0 * counts.min)
+    }
+  }
+
   test("HashPartitioner not equal to RangePartitioner") {
     val rdd = sc.parallelize(1 to 10).map(x => (x, x))
     val rangeP2 = new RangePartitioner(2, rdd)
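
For context on what these tests exercise, a minimal usage sketch (assumes a spark-shell style SparkContext named sc; the data is made up): RangePartitioner samples the RDD once to pick range bounds and only launches a second sampling job when some input partitions turn out imbalanced, which the new singlePass flag makes observable to the tests.

// Minimal usage sketch (assumes a running SparkContext `sc`; data is illustrative).
import org.apache.spark.RangePartitioner

val pairs = sc.parallelize(0 until 100000).map(i => (i % 1000, i)) // (key, value) pairs
val partitioner = new RangePartitioner(8, pairs) // samples `pairs` to pick 7 upper bounds
val partitioned = pairs.partitionBy(partitioner) // each partition holds a contiguous key range
// Partition sizes should come out roughly balanced, which is what the tests above assert.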
