Skip to content

Commit 9bdd36e

Browse files
committed
Check sample size and move computeFraction
Check that the sample size is within the supported range. Moved computeFraction into a private util class in util.random
1 parent e3fd6a6 commit 9bdd36e

File tree

4 files changed

+108
-55
lines changed

4 files changed

+108
-55
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ import org.apache.spark.partial.PartialResult
4343
import org.apache.spark.storage.StorageLevel
4444
import org.apache.spark.util.{BoundedPriorityQueue, SerializableHyperLogLog, Utils}
4545
import org.apache.spark.util.collection.OpenHashMap
46-
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}
46+
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
4747

4848
/**
4949
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -400,12 +400,21 @@ abstract class RDD[T: ClassTag](
400400
throw new IllegalArgumentException("Negative number of elements requested")
401401
}
402402

403+
if (!withReplacement && num > initialCount) {
404+
throw new IllegalArgumentException("Cannot create sample larger than the original when " +
405+
"sampling without replacement")
406+
}
407+
403408
if (initialCount == 0) {
404409
return new Array[T](0)
405410
}
406411

407412
if (initialCount > Integer.MAX_VALUE - 1) {
408-
maxSelected = Integer.MAX_VALUE - 1
413+
maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt
414+
if (num > maxSelected) {
415+
throw new IllegalArgumentException("Cannot support a sample size > Integer.MAX_VALUE - " +
416+
"5.0 * math.sqrt(Integer.MAX_VALUE)")
417+
}
409418
} else {
410419
maxSelected = initialCount.toInt
411420
}
@@ -415,7 +424,7 @@ abstract class RDD[T: ClassTag](
415424
total = maxSelected
416425
fraction = multiplier * (maxSelected + 1) / initialCount
417426
} else {
418-
fraction = computeFraction(num, initialCount, withReplacement)
427+
fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement)
419428
total = num
420429
}
421430

@@ -431,35 +440,6 @@ abstract class RDD[T: ClassTag](
431440
Utils.randomizeInPlace(samples, rand).take(total)
432441
}
433442

434-
/**
435-
* Let p = num / total, where num is the sample size and total is the total number of
436-
* datapoints in the RDD. We're trying to compute q > p such that
437-
* - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
438-
* where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total),
439-
* i.e. the failure rate of not having a sufficiently large sample < 0.0001.
440-
* Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
441-
* num > 12, but we need a slightly larger q (9 empirically determined).
442-
* - when sampling without replacement, we're drawing each datapoint with prob_i
443-
* ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
444-
* rate, where success rate is defined the same as in sampling with replacement.
445-
*
446-
* @param num sample size
447-
* @param total size of RDD
448-
* @param withReplacement whether sampling with replacement
449-
* @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
450-
*/
451-
private[rdd] def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
452-
val fraction = num.toDouble / total
453-
if (withReplacement) {
454-
val numStDev = if (num < 12) 9 else 5
455-
fraction + numStDev * math.sqrt(fraction / total)
456-
} else {
457-
val delta = 1e-4
458-
val gamma = - math.log(delta) / total
459-
math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
460-
}
461-
}
462-
463443
/**
464444
* Return the union of this RDD and another one. Any identical elements will appear multiple
465445
* times (use `.distinct()` to eliminate them).
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.util.random
19+
20+
private[spark] object SamplingUtils {

  /**
   * Computes a sampling rate that yields a sufficiently large sample with high probability.
   *
   * Let p = num / total, where num is the sample size and total is the total number of
   * datapoints in the RDD. We're trying to compute q > p such that
   * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
   * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total),
   * i.e. the failure rate of not having a sufficiently large sample < 0.0001.
   * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
   * num > 12, but we need a slightly larger q (9 empirically determined).
   * - when sampling without replacement, we're drawing each datapoint with prob_i
   * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
   * rate, where success rate is defined the same as in sampling with replacement.
   *
   * @param num sample size
   * @param total size of RDD
   * @param withReplacement whether sampling with replacement
   * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
   */
  def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
    // Naive rate p = num / total; the returned rate is inflated above p as described above.
    val p = num.toDouble / total
    if (!withReplacement) {
      // Without replacement: inflate p via the Chernoff-style bound so that the
      // Binomial(total, q) draw reaches num with probability >= 1 - delta.
      val delta = 1e-4
      val gamma = -math.log(delta) / total
      math.min(1, p + gamma + math.sqrt(gamma * gamma + 2 * gamma * p))
    } else {
      // With replacement: add numStDev Poisson standard deviations; 9 is the
      // empirically determined inflation needed for small samples (num < 12).
      val numStDev = if (num < 12) 9 else 5
      p + numStDev * math.sqrt(p / total)
    }
  }
}

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -497,29 +497,6 @@ class RDDSuite extends FunSuite with SharedSparkContext {
497497
assert(sortedTopK === nums.sorted(ord).take(5))
498498
}
499499

500-
test("computeFraction") {
501-
// test that the computed fraction guarantees enough datapoints
502-
// in the sample with a failure rate <= 0.0001
503-
val data = new EmptyRDD[Int](sc)
504-
val n = 100000
505-
506-
for (s <- 1 to 15) {
507-
val frac = data.computeFraction(s, n, true)
508-
val poisson = new PoissonDistribution(frac * n)
509-
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
510-
}
511-
for (s <- List(20, 100, 1000)) {
512-
val frac = data.computeFraction(s, n, true)
513-
val poisson = new PoissonDistribution(frac * n)
514-
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
515-
}
516-
for (s <- List(1, 10, 100, 1000)) {
517-
val frac = data.computeFraction(s, n, false)
518-
val binomial = new BinomialDistribution(n, frac)
519-
assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low")
520-
}
521-
}
522-
523500
test("takeSample") {
524501
val n = 1000000
525502
val data = sc.parallelize(1 to n, 2)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.util.random
19+
20+
import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
21+
import org.scalatest.FunSuite
22+
23+
class SamplingUtilsSuite extends FunSuite {

  test("computeFraction") {
    // Verify that the computed fraction guarantees enough datapoints in the
    // sample with a failure rate of at most 0.0001.
    val n = 100000

    // Sampling with replacement: the sample size follows Pois(frac * n); the
    // 0.0001-quantile of that distribution must already reach the requested size.
    for (sampleSize <- (1 to 15) ++ Seq(20, 100, 1000)) {
      val frac = SamplingUtils.computeFraction(sampleSize, n, true)
      val poisson = new PoissonDistribution(frac * n)
      assert(poisson.inverseCumulativeProbability(0.0001) >= sampleSize, "Computed fraction is too low")
    }

    // Sampling without replacement: the sample size follows Binomial(n, frac).
    for (sampleSize <- Seq(1, 10, 100, 1000)) {
      val frac = SamplingUtils.computeFraction(sampleSize, n, false)
      val binomial = new BinomialDistribution(n, frac)
      assert(binomial.inverseCumulativeProbability(0.0001) * n >= sampleSize, "Computed fraction is too low")
    }
  }
}

0 commit comments

Comments
 (0)