Skip to content

Commit 51649f5

Browse files
author
Davies Liu
committed
remove numpy in RDDSampler
1 parent 78bf997 commit 51649f5

File tree

2 files changed

+30
-66
lines changed

2 files changed

+30
-66
lines changed

python/pyspark/rdd.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,13 @@ def distinct(self, numPartitions=None):
310310

311311
def sample(self, withReplacement, fraction, seed=None):
312312
"""
313-
Return a sampled subset of this RDD (relies on numpy and falls back
314-
on default random generator if numpy is unavailable).
313+
Return a sampled subset of this RDD.
314+
315+
>>> rdd = sc.parallelize(range(100), 4)
316+
>>> rdd.sample(True, 0.1, 27).count()
317+
10
318+
>>> rdd.sample(False, 0.1, 81).count()
319+
10
315320
"""
316321
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
317322
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
@@ -343,8 +348,7 @@ def randomSplit(self, weights, seed=None):
343348
# this is ported from scala/spark/RDD.scala
344349
def takeSample(self, withReplacement, num, seed=None):
345350
"""
346-
Return a fixed-size sampled subset of this RDD (currently requires
347-
numpy).
351+
Return a fixed-size sampled subset of this RDD.
348352
349353
>>> rdd = sc.parallelize(range(0, 10))
350354
>>> len(rdd.takeSample(True, 20, 1))

python/pyspark/rddsampler.py

Lines changed: 22 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -22,76 +22,34 @@
2222
class RDDSamplerBase(object):
2323

2424
def __init__(self, withReplacement, seed=None):
25-
try:
26-
import numpy
27-
self._use_numpy = True
28-
except ImportError:
29-
print >> sys.stderr, (
30-
"NumPy does not appear to be installed. "
31-
"Falling back to default random generator for sampling.")
32-
self._use_numpy = False
33-
34-
self._seed = seed if seed is not None else random.randint(0, 2 ** 32 - 1)
25+
self._seed = seed if seed is not None else random.randint(0, sys.maxint)
3526
self._withReplacement = withReplacement
3627
self._random = None
37-
self._split = None
38-
self._rand_initialized = False
3928

4029
def initRandomGenerator(self, split):
41-
if self._use_numpy:
42-
import numpy
43-
self._random = numpy.random.RandomState(self._seed ^ split)
44-
else:
45-
self._random = random.Random(self._seed ^ split)
30+
self._random = random.Random(self._seed ^ split)
4631

4732
# mixing because the initial seeds are close to each other
4833
for _ in xrange(10):
4934
self._random.randint(0, 1)
5035

51-
self._split = split
52-
self._rand_initialized = True
53-
54-
def getUniformSample(self, split):
55-
if not self._rand_initialized or split != self._split:
56-
self.initRandomGenerator(split)
57-
58-
if self._use_numpy:
59-
return self._random.random_sample()
60-
else:
61-
return self._random.uniform(0.0, 1.0)
62-
63-
def getPoissonSample(self, split, mean):
64-
if not self._rand_initialized or split != self._split:
65-
self.initRandomGenerator(split)
36+
def getUniformSample(self):
37+
return self._random.random()
6638

67-
if self._use_numpy:
68-
return self._random.poisson(mean)
69-
else:
70-
# here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by
71-
# drawing a sequence of numbers delta_j ~ Exp(mean)
72-
num_arrivals = 1
73-
cur_time = 0.0
39+
def getPoissonSample(self, mean):
40+
# here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by
41+
# drawing a sequence of numbers delta_j ~ Exp(mean)
42+
num_arrivals = 0
43+
cur_time = self._random.expovariate(mean)
7444

45+
while cur_time < 1.0:
7546
cur_time += self._random.expovariate(mean)
47+
num_arrivals += 1
7648

77-
if cur_time > 1.0:
78-
return 0
49+
return num_arrivals
7950

80-
while(cur_time <= 1.0):
81-
cur_time += self._random.expovariate(mean)
82-
num_arrivals += 1
83-
84-
return (num_arrivals - 1)
85-
86-
def shuffle(self, vals):
87-
if self._random is None:
88-
self.initRandomGenerator(0) # this should only ever called on the master so
89-
# the split does not matter
90-
91-
if self._use_numpy:
92-
self._random.shuffle(vals)
93-
else:
94-
self._random.shuffle(vals, self._random.random)
51+
def func(self, split, iterator):
52+
raise NotImplementedError
9553

9654

9755
class RDDSampler(RDDSamplerBase):
@@ -101,31 +59,32 @@ def __init__(self, withReplacement, fraction, seed=None):
10159
self._fraction = fraction
10260

10361
def func(self, split, iterator):
62+
self.initRandomGenerator(split)
10463
if self._withReplacement:
10564
for obj in iterator:
10665
# For large datasets, the expected number of occurrences of each element in
10766
# a sample with replacement is Poisson(frac). We use that to get a count for
10867
# each element.
109-
count = self.getPoissonSample(split, mean=self._fraction)
68+
count = self.getPoissonSample(self._fraction)
11069
for _ in range(0, count):
11170
yield obj
11271
else:
11372
for obj in iterator:
114-
if self.getUniformSample(split) <= self._fraction:
73+
if self.getUniformSample() <= self._fraction:
11574
yield obj
11675

11776

11877
class RDDRangeSampler(RDDSamplerBase):
11978

12079
def __init__(self, lowerBound, upperBound, seed=None):
12180
RDDSamplerBase.__init__(self, False, seed)
122-
self._use_numpy = False # no performance gain from numpy
12381
self._lowerBound = lowerBound
12482
self._upperBound = upperBound
12583

12684
def func(self, split, iterator):
85+
self.initRandomGenerator(split)
12786
for obj in iterator:
128-
if self._lowerBound <= self.getUniformSample(split) < self._upperBound:
87+
if self._lowerBound <= self.getUniformSample() < self._upperBound:
12988
yield obj
13089

13190

@@ -136,15 +95,16 @@ def __init__(self, withReplacement, fractions, seed=None):
13695
self._fractions = fractions
13796

13897
def func(self, split, iterator):
98+
self.initRandomGenerator(split)
13999
if self._withReplacement:
140100
for key, val in iterator:
141101
# For large datasets, the expected number of occurrences of each element in
142102
# a sample with replacement is Poisson(frac). We use that to get a count for
143103
# each element.
144-
count = self.getPoissonSample(split, mean=self._fractions[key])
104+
count = self.getPoissonSample(self._fractions[key])
145105
for _ in range(0, count):
146106
yield key, val
147107
else:
148108
for key, val in iterator:
149-
if self.getUniformSample(split) <= self._fractions[key]:
109+
if self.getUniformSample() <= self._fractions[key]:
150110
yield key, val

0 commit comments

Comments
 (0)