Skip to content

Commit 944a10c

Browse files
committed
[SPARK-2145] Add lower bound on sampling rate
to guarantee sampling performance
1 parent 0214a76 commit 944a10c

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
217217
* the RDD to guarantee sample size with a 99.99% confidence; when sampling with replacement, we
218218
* need two additional passes over the RDD to guarantee sample size with a 99.99% confidence.
219219
*
220+
* Note that if the sampling rate for any stratum is < 1e-10, we throw an exception, because
221+
* at rates that low the RNG's limited quality could prevent the sample from ever being created.
222+
*
220223
* @param withReplacement whether to sample with or without replacement
221224
* @param fractionByKey function mapping key to sampling rate
222225
* @param seed seed for the random number generator
@@ -227,6 +230,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
227230
fractionByKey: K => Double,
228231
seed: Long = Utils.random.nextLong,
229232
exact: Boolean = true): RDD[(K, V)]= {
233+
234+
require(fractionByKey.asInstanceOf[Map[K, Double]].forall({case(k, v) => v >= 1e-10}),
235+
"Unable to support sampling rates < 1e-10.")
236+
230237
if (withReplacement) {
231238
val counts = if (exact) Some(this.countByKey()) else None
232239
val samplingFunc =

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,11 +350,13 @@ abstract class RDD[T: ClassTag](
350350

351351
/**
352352
* Return a sampled subset of this RDD.
353+
*
354+
* Note: a `fraction` below 1e-10 is not supported and is rejected as an invalid value.
353355
*/
354356
def sample(withReplacement: Boolean,
355357
fraction: Double,
356358
seed: Long = Utils.random.nextLong): RDD[T] = {
357-
require(fraction >= 0.0, "Invalid fraction value: " + fraction)
359+
require(fraction >= 1e-10, "Invalid fraction value: " + fraction)
358360
if (withReplacement) {
359361
new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
360362
} else {

0 commit comments

Comments
 (0)