@@ -43,7 +43,7 @@ import org.apache.spark.partial.PartialResult
43
43
import org .apache .spark .storage .StorageLevel
44
44
import org .apache .spark .util .{BoundedPriorityQueue , SerializableHyperLogLog , Utils }
45
45
import org .apache .spark .util .collection .OpenHashMap
46
- import org .apache .spark .util .random .{BernoulliSampler , PoissonSampler }
46
+ import org .apache .spark .util .random .{BernoulliSampler , PoissonSampler , SamplingUtils }
47
47
48
48
/**
49
49
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -400,12 +400,21 @@ abstract class RDD[T: ClassTag](
400
400
throw new IllegalArgumentException (" Negative number of elements requested" )
401
401
}
402
402
403
+ if (! withReplacement && num > initialCount) {
404
+ throw new IllegalArgumentException (" Cannot create sample larger than the original when " +
405
+ " sampling without replacement" )
406
+ }
407
+
403
408
if (initialCount == 0 ) {
404
409
return new Array [T ](0 )
405
410
}
406
411
407
412
if (initialCount > Integer .MAX_VALUE - 1 ) {
408
- maxSelected = Integer .MAX_VALUE - 1
413
+ maxSelected = Integer .MAX_VALUE - (5.0 * math.sqrt(Integer .MAX_VALUE )).toInt
414
+ if (num > maxSelected) {
415
+ throw new IllegalArgumentException (" Cannot support a sample size > Integer.MAX_VALUE - " +
416
+ " 5.0 * math.sqrt(Integer.MAX_VALUE)" )
417
+ }
409
418
} else {
410
419
maxSelected = initialCount.toInt
411
420
}
@@ -415,7 +424,7 @@ abstract class RDD[T: ClassTag](
415
424
total = maxSelected
416
425
fraction = multiplier * (maxSelected + 1 ) / initialCount
417
426
} else {
418
- fraction = computeFraction(num, initialCount, withReplacement)
427
+ fraction = SamplingUtils . computeFraction(num, initialCount, withReplacement)
419
428
total = num
420
429
}
421
430
@@ -431,35 +440,6 @@ abstract class RDD[T: ClassTag](
431
440
Utils .randomizeInPlace(samples, rand).take(total)
432
441
}
433
442
434
- /**
435
- * Let p = num / total, where num is the sample size and total is the total number of
436
- * datapoints in the RDD. We're trying to compute q > p such that
437
- * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
438
- * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total),
439
- * i.e. the failure rate of not having a sufficiently large sample < 0.0001.
440
- * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
441
- * num > 12, but we need a slightly larger q (9 empirically determined).
442
- * - when sampling without replacement, we're drawing each datapoint with prob_i
443
- * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
444
- * rate, where success rate is defined the same as in sampling with replacement.
445
- *
446
- * @param num sample size
447
- * @param total size of RDD
448
- * @param withReplacement whether sampling with replacement
449
- * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
450
- */
451
- private [rdd] def computeFraction (num : Int , total : Long , withReplacement : Boolean ): Double = {
452
- val fraction = num.toDouble / total
453
- if (withReplacement) {
454
- val numStDev = if (num < 12 ) 9 else 5
455
- fraction + numStDev * math.sqrt(fraction / total)
456
- } else {
457
- val delta = 1e-4
458
- val gamma = - math.log(delta) / total
459
- math.min(1 , fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
460
- }
461
- }
462
-
463
443
/**
464
444
* Return the union of this RDD and another one. Any identical elements will appear multiple
465
445
* times (use `.distinct()` to eliminate them).
0 commit comments