Skip to content

Commit 6940010

Browse files
committed
Reservoir sampling implementation.
1 parent 72e9021 commit 6940010

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,48 @@
1717

1818
package org.apache.spark.util.random
1919

20+
import scala.reflect.ClassTag
21+
2022
private[spark] object SamplingUtils {
2123

24+
/**
25+
* Reservoir Sampling implementation.
26+
*
27+
* @param input input size
28+
* @param k reservoir size
29+
* @return (samples, input size)
30+
*/
31+
def reservoirSample[T: ClassTag](input: Iterator[T], k: Int): (Array[T], Int) = {
32+
val reservoir = new Array[T](k)
33+
// Put the first k elements in the reservoir.
34+
var i = 0
35+
while (i < k && input.hasNext) {
36+
val item = input.next()
37+
reservoir(i) = item
38+
i += 1
39+
}
40+
41+
// If we have consumed all the elements, return them. Otherwise do the replacement.
42+
if (i < k) {
43+
// If input size < k, trim the array to return only an array of input size.
44+
val trimReservoir = new Array[T](i)
45+
System.arraycopy(reservoir, 0, trimReservoir, 0, i)
46+
(trimReservoir, i)
47+
} else {
48+
// If input size > k, continue the sampling process.
49+
val rand = new XORShiftRandom
50+
while (input.hasNext) {
51+
val item = input.next()
52+
val replacementIndex = rand.nextInt(i)
53+
if (replacementIndex < k) {
54+
reservoir(replacementIndex) = item
55+
}
56+
i += 1
57+
}
58+
(reservoir, i)
59+
}
60+
}
61+
2262
/**
2363
* Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
2464
* the time.

core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,32 @@
1717

1818
package org.apache.spark.util.random
1919

20+
import scala.util.Random
21+
2022
import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
2123
import org.scalatest.FunSuite
2224

2325
class SamplingUtilsSuite extends FunSuite {
2426

27+
test("reservoirSample") {
28+
val input = Seq.fill(100)(Random.nextInt())
29+
30+
// input size < k
31+
val (sample1, count1) = SamplingUtils.reservoirSample(input.iterator, 150)
32+
assert(count1 === 100)
33+
assert(input === sample1.toSeq)
34+
35+
// input size == k
36+
val (sample2, count2) = SamplingUtils.reservoirSample(input.iterator, 100)
37+
assert(count2 === 100)
38+
assert(input === sample2.toSeq)
39+
40+
// input size > k
41+
val (sample3, count3) = SamplingUtils.reservoirSample(input.iterator, 10)
42+
assert(count3 === 100)
43+
assert(sample3.length === 10)
44+
}
45+
2546
test("computeFraction") {
2647
// test that the computed fraction guarantees enough data points
2748
// in the sample with a failure rate <= 0.0001

0 commit comments

Comments
 (0)