Skip to content

Commit 4dfa2cd

Browse files
author
Davies Liu
committed
refactor
1 parent f866bcf commit 4dfa2cd

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

python/pyspark/rdd.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@
3838
from pyspark.join import python_join, python_left_outer_join, \
3939
python_right_outer_join, python_full_outer_join, python_cogroup
4040
from pyspark.statcounter import StatCounter
41-
from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
41+
from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
4242
from pyspark.storagelevel import StorageLevel
4343
from pyspark.resultiterable import ResultIterable
4444
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -337,7 +337,7 @@ def randomSplit(self, weights, seed=None):
337337
cweights.append(cweights[-1] + w / s)
338338
if seed is None:
339339
seed = random.randint(0, 2 ** 32 - 1)
340-
return [self.mapPartitionsWithIndex(RDDSampler(False, ub, seed, lb).func, True)
340+
return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True)
341341
for lb, ub in zip(cweights, cweights[1:])]
342342

343343
# this is ported from scala/spark/RDD.scala

python/pyspark/rddsampler.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,9 @@ def shuffle(self, vals):
9696

9797
class RDDSampler(RDDSamplerBase):
9898

99-
def __init__(self, withReplacement, fraction, seed=None, lowbound=0.0):
99+
def __init__(self, withReplacement, fraction, seed=None):
100100
RDDSamplerBase.__init__(self, withReplacement, seed)
101101
self._fraction = fraction
102-
self._lowbound = lowbound
103102

104103
def func(self, split, iterator):
105104
if self._withReplacement:
@@ -112,10 +111,23 @@ def func(self, split, iterator):
112111
yield obj
113112
else:
114113
for obj in iterator:
115-
if self._lowbound <= self.getUniformSample(split) < self._fraction:
114+
if self.getUniformSample(split) <= self._fraction:
116115
yield obj
117116

118117

118+
class RDDRangeSampler(RDDSamplerBase):
119+
120+
def __init__(self, lowerBound, upperBound, seed=None):
121+
RDDSamplerBase.__init__(self, False, seed)
122+
self._lowerBound = lowerBound
123+
self._upperBound = upperBound
124+
125+
def func(self, split, iterator):
126+
for obj in iterator:
127+
if self._lowerBound <= self.getUniformSample(split) < self._upperBound:
128+
yield obj
129+
130+
119131
class RDDStratifiedSampler(RDDSamplerBase):
120132

121133
def __init__(self, withReplacement, fractions, seed=None):

0 commit comments

Comments
 (0)