Skip to content

Commit 444e750

Browse files
committed
edge cases
1 parent 3de882b commit 444e750

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,13 +390,15 @@ abstract class RDD[T: ClassTag](
390390
num: Int,
391391
seed: Long = Utils.random.nextLong): Array[T] = {
392392
val numStDev = 10.0
393-
val initialCount = this.count()
394393

395394
if (num < 0) {
396395
throw new IllegalArgumentException("Negative number of elements requested")
396+
} else if (num == 0) {
397+
return new Array[T](0)
397398
}
398399

399-
if (initialCount == 0 || num == 0) {
400+
val initialCount = this.count()
401+
if (initialCount == 0) {
400402
return new Array[T](0)
401403
}
402404

@@ -407,7 +409,7 @@ abstract class RDD[T: ClassTag](
407409
}
408410

409411
val rand = new Random(seed)
410-
if (!withReplacement && num > initialCount) {
412+
if (!withReplacement && num >= initialCount) {
411413
return Utils.randomizeInPlace(this.collect(), rand)
412414
}
413415

python/pyspark/rdd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ def takeSample(self, withReplacement, num, seed=None):
370370
>>> len(rdd.takeSample(False, 15, 3))
371371
10
372372
"""
373+
numStDev = 10.0
373374

374375
if num < 0:
375376
raise ValueError("Sample size cannot be negative.")
@@ -388,7 +389,6 @@ def takeSample(self, withReplacement, num, seed=None):
388389
rand.shuffle(samples)
389390
return samples
390391

391-
numStDev = 10.0
392392
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
393393
if num > maxSampleSize:
394394
raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)

0 commit comments

Comments
 (0)