Skip to content

Commit 0c247db

Browse files
committed
SPARK-1438 RDD language APIs to support optional seed in RDD methods sample/takeSample
1 parent fb98488 commit 0c247db

File tree

11 files changed

+58
-18
lines changed

11 files changed

+58
-18
lines changed

core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
133133
/**
134134
* Return a sampled subset of this RDD.
135135
*/
136-
def sample(withReplacement: Boolean, fraction: JDouble, seed: Int): JavaDoubleRDD =
136+
def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD =
137+
sample(withReplacement, fraction, System.nanoTime)
138+
139+
/**
140+
* Return a sampled subset of this RDD.
141+
*/
142+
def sample(withReplacement: Boolean, fraction: JDouble, seed: Long): JavaDoubleRDD =
137143
fromRDD(srdd.sample(withReplacement, fraction, seed))
138144

139145
/**

core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
119119
/**
120120
* Return a sampled subset of this RDD.
121121
*/
122-
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
122+
def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
123+
sample(withReplacement, fraction, System.nanoTime)
124+
125+
/**
126+
* Return a sampled subset of this RDD.
127+
*/
128+
def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
123129
new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
124130

125131
/**

core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,13 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
9898
/**
9999
* Return a sampled subset of this RDD.
100100
*/
101-
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
101+
def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] =
102+
sample(withReplacement, fraction, System.nanoTime)
103+
104+
/**
105+
* Return a sampled subset of this RDD.
106+
*/
107+
def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] =
102108
wrapRDD(rdd.sample(withReplacement, fraction, seed))
103109

104110
/**

core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
394394
new java.util.ArrayList(arr)
395395
}
396396

397-
def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
397+
def takeSample(withReplacement: Boolean, num: Int): JList[T] =
398+
takeSample(withReplacement, num, System.nanoTime)
399+
400+
def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = {
398401
import scala.collection.JavaConversions._
399402
val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
400403
new java.util.ArrayList(arr)

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ abstract class RDD[T: ClassTag](
321321
/**
322322
* Return a sampled subset of this RDD.
323323
*/
324-
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
324+
def sample(withReplacement: Boolean, fraction: Double, seed: Long = System.nanoTime): RDD[T] = {
325325
require(fraction >= 0.0, "Invalid fraction value: " + fraction)
326326
if (withReplacement) {
327327
new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
@@ -346,7 +346,7 @@ abstract class RDD[T: ClassTag](
346346
}.toArray
347347
}
348348

349-
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
349+
def takeSample(withReplacement: Boolean, num: Int, seed: Long = System.nanoTime): Array[T] = {
350350
var fraction = 0.0
351351
var total = 0
352352
val multiplier = 3.0

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,12 @@ class RDDSuite extends FunSuite with SharedSparkContext {
466466
test("takeSample") {
467467
val data = sc.parallelize(1 to 100, 2)
468468

469+
for (num <- List(5,20,100)) {
470+
val sample = data.takeSample(withReplacement=false, num=num)
471+
assert(sample.size === num) // Got exactly num elements
472+
assert(sample.toSet.size === num) // Elements are distinct
473+
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
474+
}
469475
for (seed <- 1 to 5) {
470476
val sample = data.takeSample(withReplacement=false, 20, seed)
471477
assert(sample.size === 20) // Got exactly 20 elements
@@ -483,6 +489,19 @@ class RDDSuite extends FunSuite with SharedSparkContext {
483489
assert(sample.size === 20) // Got exactly 20 elements
484490
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
485491
}
492+
{
493+
val sample = data.takeSample(withReplacement=true, num=20)
494+
assert(sample.size === 20) // Got exactly 20 elements
495+
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
496+
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
497+
}
498+
{
499+
val sample = data.takeSample(withReplacement=true, num=100)
500+
assert(sample.size === 100) // Got exactly 100 elements
501+
// Chance of getting all distinct elements is astronomically low, so test that we got < 100
502+
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
503+
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
504+
}
486505
for (seed <- 1 to 5) {
487506
val sample = data.takeSample(withReplacement=true, 100, seed)
488507
assert(sample.size === 100) // Got exactly 100 elements

python/pyspark/rdd.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from threading import Thread
3131
import warnings
3232
import heapq
33+
import random
3334

3435
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
3536
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -332,7 +333,7 @@ def distinct(self):
332333
.reduceByKey(lambda x, _: x) \
333334
.map(lambda (x, _): x)
334335

335-
def sample(self, withReplacement, fraction, seed):
336+
def sample(self, withReplacement, fraction, seed=None):
336337
"""
337338
Return a sampled subset of this RDD (relies on numpy and falls back
338339
on default random generator if numpy is unavailable).
@@ -344,7 +345,7 @@ def sample(self, withReplacement, fraction, seed):
344345
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
345346

346347
# this is ported from scala/spark/RDD.scala
347-
def takeSample(self, withReplacement, num, seed):
348+
def takeSample(self, withReplacement, num, seed=None):
348349
"""
349350
Return a fixed-size sampled subset of this RDD (currently requires numpy).
350351
@@ -381,13 +382,11 @@ def takeSample(self, withReplacement, num, seed):
381382
# If the first sample didn't turn out large enough, keep trying to take samples;
382383
# this shouldn't happen often because we use a big multiplier for their initial size.
383384
# See: scala/spark/RDD.scala
385+
random.seed(seed)
384386
while len(samples) < total:
385-
if seed > sys.maxint - 2:
386-
seed = -1
387-
seed += 1
388-
samples = self.sample(withReplacement, fraction, seed).collect()
387+
samples = self.sample(withReplacement, fraction, random.randint(0,sys.maxint)).collect()
389388

390-
sampler = RDDSampler(withReplacement, fraction, seed+1)
389+
sampler = RDDSampler(withReplacement, fraction, random.randint(0,sys.maxint))
391390
sampler.shuffle(samples)
392391
return samples[0:total]
393392

python/pyspark/rddsampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import random
2020

2121
class RDDSampler(object):
22-
def __init__(self, withReplacement, fraction, seed):
22+
def __init__(self, withReplacement, fraction, seed=None):
2323
try:
2424
import numpy
2525
self._use_numpy = True

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
168168
def references = Set.empty
169169
}
170170

171-
case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: LogicalPlan)
171+
case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
172172
extends UnaryNode {
173173

174174
def output = child.output

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,11 @@ class SchemaRDD(
256256
* @group Query
257257
*/
258258
@Experimental
259+
override
259260
def sample(
260-
fraction: Double,
261261
withReplacement: Boolean = true,
262-
seed: Int = (math.random * 1000).toInt) =
262+
fraction: Double,
263+
seed: Long) =
263264
new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan))
264265

265266
/**

0 commit comments

Comments
 (0)