Skip to content

Commit 0a9b3e3

Browse files
committed
"reviewer comment addressed"
1 parent f80f270 commit 0a9b3e3

File tree

5 files changed

+24
-30
lines changed

5 files changed

+24
-30
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ abstract class RDD[T: ClassTag](
391391
seed: Long = Utils.random.nextLong): Array[T] = {
392392
var fraction = 0.0
393393
var total = 0
394+
val numStDev = 10.0
394395
val initialCount = this.count()
395396

396397
if (num < 0) {
@@ -406,15 +407,15 @@ abstract class RDD[T: ClassTag](
406407
"sampling without replacement")
407408
}
408409

409-
if (initialCount > Integer.MAX_VALUE - 1) {
410-
val maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt
410+
if (initialCount > Int.MaxValue - 1) {
411+
val maxSelected = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
411412
if (num > maxSelected) {
412-
throw new IllegalArgumentException("Cannot support a sample size > Integer.MAX_VALUE - " +
413-
"5.0 * math.sqrt(Integer.MAX_VALUE)")
413+
throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
414+
s"$numStDev * math.sqrt(Int.MaxValue)")
414415
}
415416
}
416417

417-
fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement)
418+
fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, withReplacement)
418419
total = num
419420

420421
val rand = new Random(seed)

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ private[spark] object SamplingUtils {
3131
* ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
3232
* rate, where success rate is defined the same as in sampling with replacement.
3333
*
34-
* @param num sample size
34+
* @param sampleSizeLowerBound lower bound on the sample size the returned rate must guarantee
3535
* @param total size of RDD
3636
* @param withReplacement whether sampling with replacement
3737
* @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
3838
*/
39-
def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
40-
val fraction = num.toDouble / total
39+
def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, withReplacement: Boolean): Double = {
40+
val fraction = sampleSizeLowerBound.toDouble / total
4141
if (withReplacement) {
42-
val numStDev = if (num < 12) 9 else 5
42+
val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
4343
fraction + numStDev * math.sqrt(fraction / total)
4444
} else {
4545
val delta = 1e-4

core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,17 @@ class SamplingUtilsSuite extends FunSuite{
2828
val n = 100000
2929

3030
for (s <- 1 to 15) {
31-
val frac = SamplingUtils.computeFraction(s, n, true)
31+
val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
3232
val poisson = new PoissonDistribution(frac * n)
3333
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
3434
}
3535
for (s <- List(20, 100, 1000)) {
36-
val frac = SamplingUtils.computeFraction(s, n, true)
36+
val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
3737
val poisson = new PoissonDistribution(frac * n)
3838
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
3939
}
4040
for (s <- List(1, 10, 100, 1000)) {
41-
val frac = SamplingUtils.computeFraction(s, n, false)
41+
val frac = SamplingUtils.computeFractionForSampleSize(s, n, false)
4242
val binomial = new BinomialDistribution(n, frac)
4343
assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low")
4444
}

pom.xml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -256,11 +256,6 @@
256256
<artifactId>commons-codec</artifactId>
257257
<version>1.5</version>
258258
</dependency>
259-
<dependency>
260-
<groupId>org.apache.commons</groupId>
261-
<artifactId>commons-math3</artifactId>
262-
<version>3.3</version>
263-
</dependency>
264259
<dependency>
265260
<groupId>com.google.code.findbugs</groupId>
266261
<artifactId>jsr305</artifactId>

python/pyspark/rdd.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -366,27 +366,25 @@ def takeSample(self, withReplacement, num, seed=None):
366366

367367
fraction = 0.0
368368
total = 0
369-
multiplier = 3.0
369+
numStDev = 10.0
370370
initialCount = self.count()
371-
maxSelected = 0
372371

373-
if (num < 0):
372+
if num < 0:
374373
raise ValueError
375374

376-
if (initialCount == 0):
375+
if initialCount == 0:
377376
return list()
378377

378+
if (not withReplacement) and num > initialCount:
379+
raise ValueError
380+
379381
if initialCount > sys.maxint - 1:
380-
maxSelected = sys.maxint - 1
381-
else:
382-
maxSelected = initialCount
382+
maxSelected = sys.maxint - int(numStDev * sqrt(sys.maxint))
383+
if num > maxSelected:
384+
raise ValueError
383385

384-
if num > initialCount and not withReplacement:
385-
total = maxSelected
386-
fraction = multiplier * (maxSelected + 1) / initialCount
387-
else:
388-
fraction = self._computeFraction(num, initialCount, withReplacement)
389-
total = num
386+
fraction = self._computeFraction(num, initialCount, withReplacement)
387+
total = num
390388

391389
samples = self.sample(withReplacement, fraction, seed).collect()
392390

0 commit comments

Comments
 (0)