Skip to content

Commit 7f22fa8

Browse files
Davies Liumengxr
authored and committed
[SPARK-4327] [PySpark] Python API for RDD.randomSplit()
``` pyspark.RDD.randomSplit(self, weights, seed=None) Randomly splits this RDD with the provided weights. :param weights: weights for splits, will be normalized if they don't sum to 1 :param seed: random seed :return: split RDDs in a list >>> rdd = sc.parallelize(range(10), 1) >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11) >>> rdd1.collect() [3, 6] >>> rdd2.collect() [0, 5, 7] >>> rdd3.collect() [1, 2, 4, 8, 9] ``` Author: Davies Liu <[email protected]> Closes apache#3193 from davies/randomSplit and squashes the following commits: 78bf997 [Davies Liu] fix tests, do not use numpy in randomSplit, no performance gain f5fdf63 [Davies Liu] fix bug with int in weights 4dfa2cd [Davies Liu] refactor f866bcf [Davies Liu] remove unneeded change c7a2007 [Davies Liu] switch to python implementation 95a48ac [Davies Liu] Merge branch 'master' of github.com:apache/spark into randomSplit 0d9b256 [Davies Liu] refactor 1715ee3 [Davies Liu] address comments 41fce54 [Davies Liu] randomSplit()
1 parent bb46046 commit 7f22fa8

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

python/pyspark/rdd.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import warnings
2929
import heapq
3030
import bisect
31-
from random import Random
31+
import random
3232
from math import sqrt, log, isinf, isnan
3333

3434
from pyspark.accumulators import PStatsParam
@@ -38,7 +38,7 @@
3838
from pyspark.join import python_join, python_left_outer_join, \
3939
python_right_outer_join, python_full_outer_join, python_cogroup
4040
from pyspark.statcounter import StatCounter
41-
from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
41+
from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
4242
from pyspark.storagelevel import StorageLevel
4343
from pyspark.resultiterable import ResultIterable
4444
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -316,6 +316,30 @@ def sample(self, withReplacement, fraction, seed=None):
316316
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
317317
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
318318

319+
def randomSplit(self, weights, seed=None):
    """
    Randomly splits this RDD with the provided weights.

    :param weights: weights for splits, will be normalized if they don't sum to 1
    :param seed: random seed
    :return: split RDDs in a list

    >>> rdd = sc.parallelize(range(5), 1)
    >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
    >>> rdd1.collect()
    [1, 3]
    >>> rdd2.collect()
    [0, 2, 4]
    """
    # Cumulative normalized boundaries: 0.0, w1/total, (w1+w2)/total, ..., 1.0
    total = float(sum(weights))
    bounds = [0.0]
    for weight in weights:
        bounds.append(bounds[-1] + weight / total)
    # All splits must share one seed so each element lands in exactly one range.
    if seed is None:
        seed = random.randint(0, 2 ** 32 - 1)
    splits = []
    for lower, upper in zip(bounds, bounds[1:]):
        sampler = RDDRangeSampler(lower, upper, seed)
        splits.append(self.mapPartitionsWithIndex(sampler.func, True))
    return splits
342+
319343
# this is ported from scala/spark/RDD.scala
320344
def takeSample(self, withReplacement, num, seed=None):
321345
"""
@@ -341,7 +365,7 @@ def takeSample(self, withReplacement, num, seed=None):
341365
if initialCount == 0:
342366
return []
343367

344-
rand = Random(seed)
368+
rand = random.Random(seed)
345369

346370
if (not withReplacement) and num >= initialCount:
347371
# shuffle current RDD and return

python/pyspark/rddsampler.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,20 @@ def func(self, split, iterator):
115115
yield obj
116116

117117

118+
class RDDRangeSampler(RDDSamplerBase):
    """Keep items whose per-item uniform draw falls in [lowerBound, upperBound)."""

    def __init__(self, lowerBound, upperBound, seed=None):
        RDDSamplerBase.__init__(self, False, seed)
        self._use_numpy = False  # no performance gain from numpy
        self._lowerBound = lowerBound
        self._upperBound = upperBound

    def func(self, split, iterator):
        # Hoist the bounds out of the loop; one uniform draw decides each item.
        lb = self._lowerBound
        ub = self._upperBound
        for obj in iterator:
            draw = self.getUniformSample(split)
            if lb <= draw < ub:
                yield obj
131+
118132
class RDDStratifiedSampler(RDDSamplerBase):
119133

120134
def __init__(self, withReplacement, fractions, seed=None):

0 commit comments

Comments
 (0)